From d34ab1efab6680b706072bc66be3766564ff2d2a Mon Sep 17 00:00:00 2001 From: Riyyi Date: Sun, 23 Apr 2023 14:21:11 +0200 Subject: [PATCH] AST+Eval: Prevent copying lists where unneeded --- src/ast.cpp | 9 ++++++++ src/ast.h | 1 + src/eval-special-form.cpp | 6 ++--- src/eval.cpp | 8 +++---- src/functions.cpp | 48 +++++++++++++++++++-------------------- 5 files changed, 40 insertions(+), 32 deletions(-) diff --git a/src/ast.cpp b/src/ast.cpp index 56547ba..c365374 100644 --- a/src/ast.cpp +++ b/src/ast.cpp @@ -31,6 +31,15 @@ void Collection::add(ValuePtr node) m_nodes.push_back(node); } +void Collection::addFront(ValuePtr node) +{ + if (node == nullptr) { + return; + } + + m_nodes.push_front(node); +} + // ----------------------------------------- List::List(const std::list& nodes) diff --git a/src/ast.h b/src/ast.h index 5357287..26d6222 100644 --- a/src/ast.h +++ b/src/ast.h @@ -61,6 +61,7 @@ public: virtual ~Collection() = default; void add(ValuePtr node); + void addFront(ValuePtr node); size_t size() const { return m_nodes.size(); } bool empty() const { return m_nodes.size() == 0; } diff --git a/src/eval-special-form.cpp b/src/eval-special-form.cpp index cbe6d0c..0c8bb1f 100644 --- a/src/eval-special-form.cpp +++ b/src/eval-special-form.cpp @@ -243,11 +243,11 @@ static ValuePtr evalQuasiQuoteImpl(ValuePtr ast) ValuePtr result = makePtr(); - auto nodes = std::static_pointer_cast(ast)->nodes(); + const auto& nodes = std::static_pointer_cast(ast)->nodes(); // `() or `(1 ~2 3) or `(1 ~@(list 2 2 2) 3) - for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) { - const auto elt = *it; + for (auto it = nodes.crbegin(); it != nodes.crend(); ++it) { + const auto& elt = *it; const auto splice_unquote = startsWith(elt, "splice-unquote"); // (list 2 2 2) if (splice_unquote) { diff --git a/src/eval.cpp b/src/eval.cpp index 25d5d5c..c8002c3 100644 --- a/src/eval.cpp +++ b/src/eval.cpp @@ -167,8 +167,8 @@ ValuePtr Eval::evalAst(ValuePtr ast, EnvironmentPtr env) else if (is(ast_raw_ptr)) { std::shared_ptr result = nullptr; (is(ast_raw_ptr)) ? result = makePtr() : result = makePtr(); - auto nodes = std::static_pointer_cast(ast)->nodes(); - for (auto node : nodes) { + const auto& nodes = std::static_pointer_cast(ast)->nodes(); + for (const auto& node : nodes) { m_ast_stack.push(node); m_env_stack.push(env); ValuePtr eval_node = evalImpl(); @@ -181,8 +181,8 @@ ValuePtr Eval::evalAst(ValuePtr ast, EnvironmentPtr env) } else if (is(ast_raw_ptr)) { auto result = makePtr(); - auto elements = std::static_pointer_cast(ast)->elements(); - for (auto& element : elements) { + const auto& elements = std::static_pointer_cast(ast)->elements(); + for (const auto& element : elements) { m_ast_stack.push(element.second); m_env_stack.push(env); ValuePtr element_node = evalImpl(); diff --git a/src/functions.cpp b/src/functions.cpp index 69b0312..4b6a221 100644 --- a/src/functions.cpp +++ b/src/functions.cpp @@ -237,15 +237,15 @@ ADD_FUNCTION( [&equal](ValuePtr lhs, ValuePtr rhs) -> bool { if ((is(lhs.get()) || is(lhs.get())) && (is(rhs.get()) || is(rhs.get()))) { - auto lhs_nodes = std::static_pointer_cast(lhs)->nodes(); - auto rhs_nodes = std::static_pointer_cast(rhs)->nodes(); + const auto& lhs_nodes = std::static_pointer_cast(lhs)->nodes(); + const auto& rhs_nodes = std::static_pointer_cast(rhs)->nodes(); if (lhs_nodes.size() != rhs_nodes.size()) { return false; } - auto lhs_it = lhs_nodes.begin(); - auto rhs_it = rhs_nodes.begin(); + auto lhs_it = lhs_nodes.cbegin(); + auto rhs_it = rhs_nodes.cbegin(); for (; lhs_it != lhs_nodes.end(); ++lhs_it, ++rhs_it) { if (!equal(*lhs_it, *rhs_it)) { return false; @@ -256,8 +256,8 @@ ADD_FUNCTION( } if (is(lhs.get()) && is(rhs.get())) { - auto lhs_nodes = std::static_pointer_cast(lhs)->elements(); - auto rhs_nodes = std::static_pointer_cast(rhs)->elements(); + const auto& lhs_nodes = std::static_pointer_cast(lhs)->elements(); + const auto& rhs_nodes = std::static_pointer_cast(rhs)->elements(); if (lhs_nodes.size() != rhs_nodes.size()) { return false; @@ -265,7 +265,7 @@ ADD_FUNCTION( for (const auto& [key, value] : lhs_nodes) { auto it = rhs_nodes.find(key); - if (it == rhs_nodes.end() || !equal(value, it->second)) { + if (it == rhs_nodes.cend() || !equal(value, it->second)) { return false; } } @@ -384,8 +384,7 @@ ADD_FUNCTION( VALUE_CAST(atom, Atom, nodes.front()); - auto callable = *std::next(nodes.begin()); - IS_VALUE(Callable, callable); + VALUE_CAST(callable, Callable, (*std::next(nodes.begin()))); // Remove atom and function from the argument list, add atom value nodes.pop_front(); @@ -405,7 +404,7 @@ ADD_FUNCTION( return atom->reset(value); }); -// (cons 1 (list 2 3)) +// (cons 1 (list 2 3)) -> (1 2 3) ADD_FUNCTION( "cons", { @@ -413,13 +412,13 @@ ADD_FUNCTION( VALUE_CAST(collection, Collection, (*std::next(nodes.begin()))); - auto result_nodes = collection->nodes(); - result_nodes.push_front(nodes.front()); + auto result = makePtr(collection->nodes()); + result->addFront(nodes.front()); - return makePtr(result_nodes); + return result; }); -// (concat (list 1) (list 2 3)) +// (concat (list 1) (list 2 3)) -> (1 2 3) ADD_FUNCTION( "concat", { @@ -434,7 +433,7 @@ ADD_FUNCTION( return makePtr(result_nodes); }); -// (vec (list 1 2 3)) +// (vec (list 1 2 3)) -> [1 2 3] ADD_FUNCTION( "vec", { @@ -449,7 +448,7 @@ ADD_FUNCTION( return makePtr(collection->nodes()); }); -// (nth (list 1 2 3) 0) +// (nth (list 1 2 3) 0) -> 1 ADD_FUNCTION( "nth", { @@ -457,7 +456,7 @@ ADD_FUNCTION( VALUE_CAST(collection, Collection, nodes.front()); VALUE_CAST(number_node, Number, (*std::next(nodes.begin()))); - auto collection_nodes = collection->nodes(); + const auto& collection_nodes = collection->nodes(); auto index = (size_t)number_node->number(); if (number_node->number() < 0 || index >= collection_nodes.size()) { @@ -465,7 +464,7 @@ ADD_FUNCTION( return nullptr; } - auto result = collection_nodes.begin(); + auto result = collection_nodes.cbegin(); for (size_t i = 0; i < index; ++i) { result++; } @@ -473,7 +472,7 @@ ADD_FUNCTION( return *result; }); -// (first (list 1 2 3)) +// (first (list 1 2 3)) -> 1 ADD_FUNCTION( "first", { @@ -485,12 +484,11 @@ ADD_FUNCTION( } VALUE_CAST(collection, Collection, nodes.front()); - auto collection_nodes = collection->nodes(); - return (collection_nodes.empty()) ? makePtr() : collection_nodes.front(); + return (collection->empty()) ? makePtr() : collection->front(); }); -// (rest (list 1 2 3)) +// (rest (list 1 2 3)) -> (2 3) ADD_FUNCTION( "rest", { @@ -550,19 +548,19 @@ ADD_FUNCTION( VALUE_CAST(callable, Callable, nodes.front()); VALUE_CAST(collection, Collection, nodes.back()); - auto collection_nodes = collection->nodes(); + const auto& collection_nodes = collection->nodes(); auto result = makePtr(); if (is(callable.get())) { auto function = std::static_pointer_cast(callable)->function(); - for (auto node : collection_nodes) { + for (const auto& node : collection_nodes) { result->add(function({ node })); } } else { auto lambda = std::static_pointer_cast(callable); - for (auto node : collection_nodes) { + for (const auto& node : collection_nodes) { result->add(eval(lambda->body(), Environment::create(lambda, { node }))); } }