Browse Source

Eval: Implement tail call optimization (TCO) via stack iteration

master
Riyyi 1 year ago
parent
commit
4d3c2a4ca2
  1. 65
      src/eval.cpp
  2. 6
      src/eval.h

65
src/eval.cpp

@ -66,14 +66,16 @@ Eval::Eval(ASTNodePtr ast, EnvironmentPtr env)
} \
\
std::string key = std::static_pointer_cast<Symbol>(*it)->symbol(); \
ASTNodePtr value = evalImpl(*std::next(it), let_env); \
m_ast_stack.push(*std::next(it)); \
m_env_stack.push(let_env); \
ASTNodePtr value = evalImpl(); \
let_env->set(key, value); \
} \
\
/* TODO: Remove limitation of 3 arguments */ \
/* Eval all values in this new env, return last sexp of the result */ \
ast = second_argument; \
env = let_env; \
m_ast_stack.push(second_argument); \
m_env_stack.push(let_env); \
continue; /* TCO */ \
}
@ -86,11 +88,14 @@ Eval::Eval(ASTNodePtr ast, EnvironmentPtr env)
\
/* Evaluate all nodes except the last */ \
for (auto it = nodes.begin(); it != std::prev(nodes.end(), 1); ++it) { \
evalImpl(*it, env); \
m_ast_stack.push(*it); \
m_env_stack.push(env); \
evalImpl(); \
} \
\
/* Eval last node */ \
ast = nodes.back(); \
m_ast_stack.push(nodes.back()); \
m_env_stack.push(env); \
continue; /* TCO */ \
}
@ -105,34 +110,52 @@ Eval::Eval(ASTNodePtr ast, EnvironmentPtr env)
auto second_argument = *std::next(nodes.begin()); \
auto third_argument = (nodes.size() == 3) ? *std::next(std::next(nodes.begin())) : makePtr<Value>(Value::Nil); \
\
auto first_evaluated = evalImpl(first_argument, env); \
m_ast_stack.push(first_argument); \
m_env_stack.push(env); \
auto first_evaluated = evalImpl(); \
if (!is<Value>(first_evaluated.get()) \
|| std::static_pointer_cast<Value>(first_evaluated)->state() == Value::True) { \
ast = second_argument; \
m_ast_stack.push(second_argument); \
m_env_stack.push(env); \
continue; /* TCO */ \
} \
else { \
ast = third_argument; \
m_ast_stack.push(third_argument); \
m_env_stack.push(env); \
continue; /* TCO */ \
} \
}
void Eval::eval()
{
m_ast = evalImpl(m_ast, m_env);
m_ast_stack = std::stack<ASTNodePtr>();
m_env_stack = std::stack<EnvironmentPtr>();
m_ast_stack.push(m_ast);
m_env_stack.push(m_env);
m_ast = evalImpl();
}
ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env)
ASTNodePtr Eval::evalImpl()
{
ASTNodePtr ast = nullptr;
EnvironmentPtr env = nullptr;
while (true) {
if (ast == nullptr) {
if (m_ast_stack.size() == 0) {
return nullptr;
}
if (env == nullptr) {
env = m_env;
if (m_env_stack.size() == 0) {
m_env_stack.push(m_env);
}
ast = m_ast_stack.top();
env = m_env_stack.top();
m_ast_stack.pop();
m_env_stack.pop();
if (!is<List>(ast.get())) {
return evalAst(ast, env);
}
@ -180,8 +203,8 @@ ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env)
// cdr
evaluated_nodes.pop_front();
ast = lambda->body();
env = Environment::create(lambda, evaluated_nodes);
m_ast_stack.push(lambda->body());
m_env_stack.push(Environment::create(lambda, evaluated_nodes));
continue; // TCO
}
@ -210,7 +233,9 @@ ASTNodePtr Eval::evalAst(ASTNodePtr ast, EnvironmentPtr env)
(is<List>(ast_raw_ptr)) ? result = makePtr<List>() : result = makePtr<Vector>();
auto nodes = std::static_pointer_cast<Collection>(ast)->nodes();
for (auto node : nodes) {
ASTNodePtr eval_node = evalImpl(node, env);
m_ast_stack.push(node);
m_env_stack.push(env);
ASTNodePtr eval_node = evalImpl();
if (eval_node == nullptr) {
return nullptr;
}
@ -222,7 +247,9 @@ ASTNodePtr Eval::evalAst(ASTNodePtr ast, EnvironmentPtr env)
auto result = makePtr<HashMap>();
auto elements = std::static_pointer_cast<HashMap>(ast)->elements();
for (auto& element : elements) {
ASTNodePtr element_node = evalImpl(element.second, env);
m_ast_stack.push(element.second);
m_env_stack.push(env);
ASTNodePtr element_node = evalImpl();
if (element_node == nullptr) {
return nullptr;
}
@ -251,7 +278,9 @@ ASTNodePtr Eval::evalDef(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env)
}
std::string symbol = std::static_pointer_cast<Symbol>(first_argument)->symbol();
ASTNodePtr value = evalImpl(second_argument, env);
m_ast_stack.push(second_argument);
m_env_stack.push(env);
ASTNodePtr value = evalImpl();
// Dont overwrite symbols after an error
if (Error::the().hasAnyError()) {

6
src/eval.h

@ -7,6 +7,7 @@
#pragma once
#include <list>
#include <stack>
#include "environment.h"
#include "forward.h"
@ -25,7 +26,7 @@ public:
ASTNodePtr ast() const { return m_ast; }
private:
ASTNodePtr evalImpl(ASTNodePtr ast, EnvironmentPtr env);
ASTNodePtr evalImpl();
ASTNodePtr evalAst(ASTNodePtr ast, EnvironmentPtr env);
ASTNodePtr evalDef(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
ASTNodePtr evalFn(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
@ -33,6 +34,9 @@ private:
ASTNodePtr m_ast;
EnvironmentPtr m_env;
std::stack<ASTNodePtr> m_ast_stack;
std::stack<EnvironmentPtr> m_env_stack;
};
} // namespace blaze

Loading…
Cancel
Save