Browse Source

AST+Eval: Prevent copying lists where unneeded

master
Riyyi 2 years ago
parent
commit
d34ab1efab
  1. 9
      src/ast.cpp
  2. 1
      src/ast.h
  3. 6
      src/eval-special-form.cpp
  4. 8
      src/eval.cpp
  5. 48
      src/functions.cpp

9
src/ast.cpp

@ -31,6 +31,15 @@ void Collection::add(ValuePtr node)
m_nodes.push_back(node); m_nodes.push_back(node);
} }
void Collection::addFront(ValuePtr node)
{
if (node == nullptr) {
return;
}
m_nodes.push_front(node);
}
// ----------------------------------------- // -----------------------------------------
List::List(const std::list<ValuePtr>& nodes) List::List(const std::list<ValuePtr>& nodes)

1
src/ast.h

@ -61,6 +61,7 @@ public:
virtual ~Collection() = default; virtual ~Collection() = default;
void add(ValuePtr node); void add(ValuePtr node);
void addFront(ValuePtr node);
size_t size() const { return m_nodes.size(); } size_t size() const { return m_nodes.size(); }
bool empty() const { return m_nodes.size() == 0; } bool empty() const { return m_nodes.size() == 0; }

6
src/eval-special-form.cpp

@ -243,11 +243,11 @@ static ValuePtr evalQuasiQuoteImpl(ValuePtr ast)
ValuePtr result = makePtr<List>(); ValuePtr result = makePtr<List>();
auto nodes = std::static_pointer_cast<Collection>(ast)->nodes(); const auto& nodes = std::static_pointer_cast<Collection>(ast)->nodes();
// `() or `(1 ~2 3) or `(1 ~@(list 2 2 2) 3) // `() or `(1 ~2 3) or `(1 ~@(list 2 2 2) 3)
for (auto it = nodes.rbegin(); it != nodes.rend(); ++it) { for (auto it = nodes.crbegin(); it != nodes.crend(); ++it) {
const auto elt = *it; const auto& elt = *it;
const auto splice_unquote = startsWith(elt, "splice-unquote"); // (list 2 2 2) const auto splice_unquote = startsWith(elt, "splice-unquote"); // (list 2 2 2)
if (splice_unquote) { if (splice_unquote) {

8
src/eval.cpp

@ -167,8 +167,8 @@ ValuePtr Eval::evalAst(ValuePtr ast, EnvironmentPtr env)
else if (is<Collection>(ast_raw_ptr)) { else if (is<Collection>(ast_raw_ptr)) {
std::shared_ptr<Collection> result = nullptr; std::shared_ptr<Collection> result = nullptr;
(is<List>(ast_raw_ptr)) ? result = makePtr<List>() : result = makePtr<Vector>(); (is<List>(ast_raw_ptr)) ? result = makePtr<List>() : result = makePtr<Vector>();
auto nodes = std::static_pointer_cast<Collection>(ast)->nodes(); const auto& nodes = std::static_pointer_cast<Collection>(ast)->nodes();
for (auto node : nodes) { for (const auto& node : nodes) {
m_ast_stack.push(node); m_ast_stack.push(node);
m_env_stack.push(env); m_env_stack.push(env);
ValuePtr eval_node = evalImpl(); ValuePtr eval_node = evalImpl();
@ -181,8 +181,8 @@ ValuePtr Eval::evalAst(ValuePtr ast, EnvironmentPtr env)
} }
else if (is<HashMap>(ast_raw_ptr)) { else if (is<HashMap>(ast_raw_ptr)) {
auto result = makePtr<HashMap>(); auto result = makePtr<HashMap>();
auto elements = std::static_pointer_cast<HashMap>(ast)->elements(); const auto& elements = std::static_pointer_cast<HashMap>(ast)->elements();
for (auto& element : elements) { for (const auto& element : elements) {
m_ast_stack.push(element.second); m_ast_stack.push(element.second);
m_env_stack.push(env); m_env_stack.push(env);
ValuePtr element_node = evalImpl(); ValuePtr element_node = evalImpl();

48
src/functions.cpp

@ -237,15 +237,15 @@ ADD_FUNCTION(
[&equal](ValuePtr lhs, ValuePtr rhs) -> bool { [&equal](ValuePtr lhs, ValuePtr rhs) -> bool {
if ((is<List>(lhs.get()) || is<Vector>(lhs.get())) if ((is<List>(lhs.get()) || is<Vector>(lhs.get()))
&& (is<List>(rhs.get()) || is<Vector>(rhs.get()))) { && (is<List>(rhs.get()) || is<Vector>(rhs.get()))) {
auto lhs_nodes = std::static_pointer_cast<Collection>(lhs)->nodes(); const auto& lhs_nodes = std::static_pointer_cast<Collection>(lhs)->nodes();
auto rhs_nodes = std::static_pointer_cast<Collection>(rhs)->nodes(); const auto& rhs_nodes = std::static_pointer_cast<Collection>(rhs)->nodes();
if (lhs_nodes.size() != rhs_nodes.size()) { if (lhs_nodes.size() != rhs_nodes.size()) {
return false; return false;
} }
auto lhs_it = lhs_nodes.begin(); auto lhs_it = lhs_nodes.cbegin();
auto rhs_it = rhs_nodes.begin(); auto rhs_it = rhs_nodes.cbegin();
for (; lhs_it != lhs_nodes.end(); ++lhs_it, ++rhs_it) { for (; lhs_it != lhs_nodes.end(); ++lhs_it, ++rhs_it) {
if (!equal(*lhs_it, *rhs_it)) { if (!equal(*lhs_it, *rhs_it)) {
return false; return false;
@ -256,8 +256,8 @@ ADD_FUNCTION(
} }
if (is<HashMap>(lhs.get()) && is<HashMap>(rhs.get())) { if (is<HashMap>(lhs.get()) && is<HashMap>(rhs.get())) {
auto lhs_nodes = std::static_pointer_cast<HashMap>(lhs)->elements(); const auto& lhs_nodes = std::static_pointer_cast<HashMap>(lhs)->elements();
auto rhs_nodes = std::static_pointer_cast<HashMap>(rhs)->elements(); const auto& rhs_nodes = std::static_pointer_cast<HashMap>(rhs)->elements();
if (lhs_nodes.size() != rhs_nodes.size()) { if (lhs_nodes.size() != rhs_nodes.size()) {
return false; return false;
@ -265,7 +265,7 @@ ADD_FUNCTION(
for (const auto& [key, value] : lhs_nodes) { for (const auto& [key, value] : lhs_nodes) {
auto it = rhs_nodes.find(key); auto it = rhs_nodes.find(key);
if (it == rhs_nodes.end() || !equal(value, it->second)) { if (it == rhs_nodes.cend() || !equal(value, it->second)) {
return false; return false;
} }
} }
@ -384,8 +384,7 @@ ADD_FUNCTION(
VALUE_CAST(atom, Atom, nodes.front()); VALUE_CAST(atom, Atom, nodes.front());
auto callable = *std::next(nodes.begin()); VALUE_CAST(callable, Callable, (*std::next(nodes.begin())));
IS_VALUE(Callable, callable);
// Remove atom and function from the argument list, add atom value // Remove atom and function from the argument list, add atom value
nodes.pop_front(); nodes.pop_front();
@ -405,7 +404,7 @@ ADD_FUNCTION(
return atom->reset(value); return atom->reset(value);
}); });
// (cons 1 (list 2 3)) // (cons 1 (list 2 3)) -> (1 2 3)
ADD_FUNCTION( ADD_FUNCTION(
"cons", "cons",
{ {
@ -413,13 +412,13 @@ ADD_FUNCTION(
VALUE_CAST(collection, Collection, (*std::next(nodes.begin()))); VALUE_CAST(collection, Collection, (*std::next(nodes.begin())));
auto result_nodes = collection->nodes(); auto result = makePtr<List>(collection->nodes());
result_nodes.push_front(nodes.front()); result->addFront(nodes.front());
return makePtr<List>(result_nodes); return result;
}); });
// (concat (list 1) (list 2 3)) // (concat (list 1) (list 2 3)) -> (1 2 3)
ADD_FUNCTION( ADD_FUNCTION(
"concat", "concat",
{ {
@ -434,7 +433,7 @@ ADD_FUNCTION(
return makePtr<List>(result_nodes); return makePtr<List>(result_nodes);
}); });
// (vec (list 1 2 3)) // (vec (list 1 2 3)) -> [1 2 3]
ADD_FUNCTION( ADD_FUNCTION(
"vec", "vec",
{ {
@ -449,7 +448,7 @@ ADD_FUNCTION(
return makePtr<Vector>(collection->nodes()); return makePtr<Vector>(collection->nodes());
}); });
// (nth (list 1 2 3) 0) // (nth (list 1 2 3) 0) -> 1
ADD_FUNCTION( ADD_FUNCTION(
"nth", "nth",
{ {
@ -457,7 +456,7 @@ ADD_FUNCTION(
VALUE_CAST(collection, Collection, nodes.front()); VALUE_CAST(collection, Collection, nodes.front());
VALUE_CAST(number_node, Number, (*std::next(nodes.begin()))); VALUE_CAST(number_node, Number, (*std::next(nodes.begin())));
auto collection_nodes = collection->nodes(); const auto& collection_nodes = collection->nodes();
auto index = (size_t)number_node->number(); auto index = (size_t)number_node->number();
if (number_node->number() < 0 || index >= collection_nodes.size()) { if (number_node->number() < 0 || index >= collection_nodes.size()) {
@ -465,7 +464,7 @@ ADD_FUNCTION(
return nullptr; return nullptr;
} }
auto result = collection_nodes.begin(); auto result = collection_nodes.cbegin();
for (size_t i = 0; i < index; ++i) { for (size_t i = 0; i < index; ++i) {
result++; result++;
} }
@ -473,7 +472,7 @@ ADD_FUNCTION(
return *result; return *result;
}); });
// (first (list 1 2 3)) // (first (list 1 2 3)) -> 1
ADD_FUNCTION( ADD_FUNCTION(
"first", "first",
{ {
@ -485,12 +484,11 @@ ADD_FUNCTION(
} }
VALUE_CAST(collection, Collection, nodes.front()); VALUE_CAST(collection, Collection, nodes.front());
auto collection_nodes = collection->nodes();
return (collection_nodes.empty()) ? makePtr<Constant>() : collection_nodes.front(); return (collection->empty()) ? makePtr<Constant>() : collection->front();
}); });
// (rest (list 1 2 3)) // (rest (list 1 2 3)) -> (2 3)
ADD_FUNCTION( ADD_FUNCTION(
"rest", "rest",
{ {
@ -550,19 +548,19 @@ ADD_FUNCTION(
VALUE_CAST(callable, Callable, nodes.front()); VALUE_CAST(callable, Callable, nodes.front());
VALUE_CAST(collection, Collection, nodes.back()); VALUE_CAST(collection, Collection, nodes.back());
auto collection_nodes = collection->nodes(); const auto& collection_nodes = collection->nodes();
auto result = makePtr<List>(); auto result = makePtr<List>();
if (is<Function>(callable.get())) { if (is<Function>(callable.get())) {
auto function = std::static_pointer_cast<Function>(callable)->function(); auto function = std::static_pointer_cast<Function>(callable)->function();
for (auto node : collection_nodes) { for (const auto& node : collection_nodes) {
result->add(function({ node })); result->add(function({ node }));
} }
} }
else { else {
auto lambda = std::static_pointer_cast<Lambda>(callable); auto lambda = std::static_pointer_cast<Lambda>(callable);
for (auto node : collection_nodes) { for (const auto& node : collection_nodes) {
result->add(eval(lambda->body(), Environment::create(lambda, { node }))); result->add(eval(lambda->body(), Environment::create(lambda, { node })));
} }
} }

Loading…
Cancel
Save