Browse Source

Eval: Implement tail call optimization (TCO) via stack iteration

master
Riyyi 1 year ago
parent
commit
4d3c2a4ca2
  1. 65
      src/eval.cpp
  2. 6
      src/eval.h

65
src/eval.cpp

@ -66,14 +66,16 @@ Eval::Eval(ASTNodePtr ast, EnvironmentPtr env)
} \ } \
\ \
std::string key = std::static_pointer_cast<Symbol>(*it)->symbol(); \ std::string key = std::static_pointer_cast<Symbol>(*it)->symbol(); \
ASTNodePtr value = evalImpl(*std::next(it), let_env); \ m_ast_stack.push(*std::next(it)); \
m_env_stack.push(let_env); \
ASTNodePtr value = evalImpl(); \
let_env->set(key, value); \ let_env->set(key, value); \
} \ } \
\ \
/* TODO: Remove limitation of 3 arguments */ \ /* TODO: Remove limitation of 3 arguments */ \
/* Eval all values in this new env, return last sexp of the result */ \ /* Eval all values in this new env, return last sexp of the result */ \
ast = second_argument; \ m_ast_stack.push(second_argument); \
env = let_env; \ m_env_stack.push(let_env); \
continue; /* TCO */ \ continue; /* TCO */ \
} }
@ -86,11 +88,14 @@ Eval::Eval(ASTNodePtr ast, EnvironmentPtr env)
\ \
/* Evaluate all nodes except the last */ \ /* Evaluate all nodes except the last */ \
for (auto it = nodes.begin(); it != std::prev(nodes.end(), 1); ++it) { \ for (auto it = nodes.begin(); it != std::prev(nodes.end(), 1); ++it) { \
evalImpl(*it, env); \ m_ast_stack.push(*it); \
m_env_stack.push(env); \
evalImpl(); \
} \ } \
\ \
/* Eval last node */ \ /* Eval last node */ \
ast = nodes.back(); \ m_ast_stack.push(nodes.back()); \
m_env_stack.push(env); \
continue; /* TCO */ \ continue; /* TCO */ \
} }
@ -105,34 +110,52 @@ Eval::Eval(ASTNodePtr ast, EnvironmentPtr env)
auto second_argument = *std::next(nodes.begin()); \ auto second_argument = *std::next(nodes.begin()); \
auto third_argument = (nodes.size() == 3) ? *std::next(std::next(nodes.begin())) : makePtr<Value>(Value::Nil); \ auto third_argument = (nodes.size() == 3) ? *std::next(std::next(nodes.begin())) : makePtr<Value>(Value::Nil); \
\ \
auto first_evaluated = evalImpl(first_argument, env); \ m_ast_stack.push(first_argument); \
m_env_stack.push(env); \
auto first_evaluated = evalImpl(); \
if (!is<Value>(first_evaluated.get()) \ if (!is<Value>(first_evaluated.get()) \
|| std::static_pointer_cast<Value>(first_evaluated)->state() == Value::True) { \ || std::static_pointer_cast<Value>(first_evaluated)->state() == Value::True) { \
ast = second_argument; \ m_ast_stack.push(second_argument); \
m_env_stack.push(env); \
continue; /* TCO */ \ continue; /* TCO */ \
} \ } \
else { \ else { \
ast = third_argument; \ m_ast_stack.push(third_argument); \
m_env_stack.push(env); \
continue; /* TCO */ \ continue; /* TCO */ \
} \ } \
} }
void Eval::eval() void Eval::eval()
{ {
m_ast = evalImpl(m_ast, m_env); m_ast_stack = std::stack<ASTNodePtr>();
m_env_stack = std::stack<EnvironmentPtr>();
m_ast_stack.push(m_ast);
m_env_stack.push(m_env);
m_ast = evalImpl();
} }
ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env) ASTNodePtr Eval::evalImpl()
{ {
ASTNodePtr ast = nullptr;
EnvironmentPtr env = nullptr;
while (true) { while (true) {
if (ast == nullptr) { if (m_ast_stack.size() == 0) {
return nullptr; return nullptr;
} }
if (env == nullptr) { if (m_env_stack.size() == 0) {
env = m_env; m_env_stack.push(m_env);
} }
ast = m_ast_stack.top();
env = m_env_stack.top();
m_ast_stack.pop();
m_env_stack.pop();
if (!is<List>(ast.get())) { if (!is<List>(ast.get())) {
return evalAst(ast, env); return evalAst(ast, env);
} }
@ -180,8 +203,8 @@ ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env)
// cdr // cdr
evaluated_nodes.pop_front(); evaluated_nodes.pop_front();
ast = lambda->body(); m_ast_stack.push(lambda->body());
env = Environment::create(lambda, evaluated_nodes); m_env_stack.push(Environment::create(lambda, evaluated_nodes));
continue; // TCO continue; // TCO
} }
@ -210,7 +233,9 @@ ASTNodePtr Eval::evalAst(ASTNodePtr ast, EnvironmentPtr env)
(is<List>(ast_raw_ptr)) ? result = makePtr<List>() : result = makePtr<Vector>(); (is<List>(ast_raw_ptr)) ? result = makePtr<List>() : result = makePtr<Vector>();
auto nodes = std::static_pointer_cast<Collection>(ast)->nodes(); auto nodes = std::static_pointer_cast<Collection>(ast)->nodes();
for (auto node : nodes) { for (auto node : nodes) {
ASTNodePtr eval_node = evalImpl(node, env); m_ast_stack.push(node);
m_env_stack.push(env);
ASTNodePtr eval_node = evalImpl();
if (eval_node == nullptr) { if (eval_node == nullptr) {
return nullptr; return nullptr;
} }
@ -222,7 +247,9 @@ ASTNodePtr Eval::evalAst(ASTNodePtr ast, EnvironmentPtr env)
auto result = makePtr<HashMap>(); auto result = makePtr<HashMap>();
auto elements = std::static_pointer_cast<HashMap>(ast)->elements(); auto elements = std::static_pointer_cast<HashMap>(ast)->elements();
for (auto& element : elements) { for (auto& element : elements) {
ASTNodePtr element_node = evalImpl(element.second, env); m_ast_stack.push(element.second);
m_env_stack.push(env);
ASTNodePtr element_node = evalImpl();
if (element_node == nullptr) { if (element_node == nullptr) {
return nullptr; return nullptr;
} }
@ -251,7 +278,9 @@ ASTNodePtr Eval::evalDef(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env)
} }
std::string symbol = std::static_pointer_cast<Symbol>(first_argument)->symbol(); std::string symbol = std::static_pointer_cast<Symbol>(first_argument)->symbol();
ASTNodePtr value = evalImpl(second_argument, env); m_ast_stack.push(second_argument);
m_env_stack.push(env);
ASTNodePtr value = evalImpl();
// Dont overwrite symbols after an error // Dont overwrite symbols after an error
if (Error::the().hasAnyError()) { if (Error::the().hasAnyError()) {

6
src/eval.h

@ -7,6 +7,7 @@
#pragma once #pragma once
#include <list> #include <list>
#include <stack>
#include "environment.h" #include "environment.h"
#include "forward.h" #include "forward.h"
@ -25,7 +26,7 @@ public:
ASTNodePtr ast() const { return m_ast; } ASTNodePtr ast() const { return m_ast; }
private: private:
ASTNodePtr evalImpl(ASTNodePtr ast, EnvironmentPtr env); ASTNodePtr evalImpl();
ASTNodePtr evalAst(ASTNodePtr ast, EnvironmentPtr env); ASTNodePtr evalAst(ASTNodePtr ast, EnvironmentPtr env);
ASTNodePtr evalDef(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env); ASTNodePtr evalDef(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
ASTNodePtr evalFn(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env); ASTNodePtr evalFn(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
@ -33,6 +34,9 @@ private:
ASTNodePtr m_ast; ASTNodePtr m_ast;
EnvironmentPtr m_env; EnvironmentPtr m_env;
std::stack<ASTNodePtr> m_ast_stack;
std::stack<EnvironmentPtr> m_env_stack;
}; };
} // namespace blaze } // namespace blaze

Loading…
Cancel
Save