From 4d3c2a4ca2805e1528065a959c7354996380d882 Mon Sep 17 00:00:00 2001 From: Riyyi Date: Mon, 3 Apr 2023 23:02:04 +0200 Subject: [PATCH] Eval: Implement tail call optimization (TCO) via stack iteration --- src/eval.cpp | 65 +++++++++++++++++++++++++++++++++++++--------------- src/eval.h | 6 ++++- 2 files changed, 52 insertions(+), 19 deletions(-) diff --git a/src/eval.cpp b/src/eval.cpp index cdcadc2..1d8de65 100644 --- a/src/eval.cpp +++ b/src/eval.cpp @@ -66,14 +66,16 @@ Eval::Eval(ASTNodePtr ast, EnvironmentPtr env) } \ \ std::string key = std::static_pointer_cast(*it)->symbol(); \ - ASTNodePtr value = evalImpl(*std::next(it), let_env); \ + m_ast_stack.push(*std::next(it)); \ + m_env_stack.push(let_env); \ + ASTNodePtr value = evalImpl(); \ let_env->set(key, value); \ } \ \ /* TODO: Remove limitation of 3 arguments */ \ /* Eval all values in this new env, return last sexp of the result */ \ - ast = second_argument; \ - env = let_env; \ + m_ast_stack.push(second_argument); \ + m_env_stack.push(let_env); \ continue; /* TCO */ \ } @@ -86,11 +88,14 @@ Eval::Eval(ASTNodePtr ast, EnvironmentPtr env) \ /* Evaluate all nodes except the last */ \ for (auto it = nodes.begin(); it != std::prev(nodes.end(), 1); ++it) { \ - evalImpl(*it, env); \ + m_ast_stack.push(*it); \ + m_env_stack.push(env); \ + evalImpl(); \ } \ \ /* Eval last node */ \ - ast = nodes.back(); \ + m_ast_stack.push(nodes.back()); \ + m_env_stack.push(env); \ continue; /* TCO */ \ } @@ -105,34 +110,52 @@ Eval::Eval(ASTNodePtr ast, EnvironmentPtr env) auto second_argument = *std::next(nodes.begin()); \ auto third_argument = (nodes.size() == 3) ? *std::next(std::next(nodes.begin())) : makePtr(Value::Nil); \ \ - auto first_evaluated = evalImpl(first_argument, env); \ + m_ast_stack.push(first_argument); \ + m_env_stack.push(env); \ + auto first_evaluated = evalImpl(); \ if (!is(first_evaluated.get()) \ || std::static_pointer_cast(first_evaluated)->state() == Value::True) { \ - ast = second_argument; \ + m_ast_stack.push(second_argument); \ + m_env_stack.push(env); \ continue; /* TCO */ \ } \ else { \ - ast = third_argument; \ + m_ast_stack.push(third_argument); \ + m_env_stack.push(env); \ continue; /* TCO */ \ } \ } void Eval::eval() { - m_ast = evalImpl(m_ast, m_env); + m_ast_stack = std::stack(); + m_env_stack = std::stack(); + m_ast_stack.push(m_ast); + m_env_stack.push(m_env); + + m_ast = evalImpl(); } -ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env) +ASTNodePtr Eval::evalImpl() { + ASTNodePtr ast = nullptr; + EnvironmentPtr env = nullptr; + while (true) { - if (ast == nullptr) { + if (m_ast_stack.size() == 0) { return nullptr; } - if (env == nullptr) { - env = m_env; + if (m_env_stack.size() == 0) { + m_env_stack.push(m_env); } + ast = m_ast_stack.top(); + env = m_env_stack.top(); + + m_ast_stack.pop(); + m_env_stack.pop(); + if (!is(ast.get())) { return evalAst(ast, env); } @@ -180,8 +203,8 @@ ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env) // cdr evaluated_nodes.pop_front(); - ast = lambda->body(); - env = Environment::create(lambda, evaluated_nodes); + m_ast_stack.push(lambda->body()); + m_env_stack.push(Environment::create(lambda, evaluated_nodes)); continue; // TCO } @@ -210,7 +233,9 @@ ASTNodePtr Eval::evalAst(ASTNodePtr ast, EnvironmentPtr env) (is(ast_raw_ptr)) ? result = makePtr() : result = makePtr(); auto nodes = std::static_pointer_cast(ast)->nodes(); for (auto node : nodes) { - ASTNodePtr eval_node = evalImpl(node, env); + m_ast_stack.push(node); + m_env_stack.push(env); + ASTNodePtr eval_node = evalImpl(); if (eval_node == nullptr) { return nullptr; } @@ -222,7 +247,9 @@ ASTNodePtr Eval::evalAst(ASTNodePtr ast, EnvironmentPtr env) auto result = makePtr(); auto elements = std::static_pointer_cast(ast)->elements(); for (auto& element : elements) { - ASTNodePtr element_node = evalImpl(element.second, env); + m_ast_stack.push(element.second); + m_env_stack.push(env); + ASTNodePtr element_node = evalImpl(); if (element_node == nullptr) { return nullptr; } @@ -251,7 +278,9 @@ ASTNodePtr Eval::evalDef(const std::list& nodes, EnvironmentPtr env) } std::string symbol = std::static_pointer_cast(first_argument)->symbol(); - ASTNodePtr value = evalImpl(second_argument, env); + m_ast_stack.push(second_argument); + m_env_stack.push(env); + ASTNodePtr value = evalImpl(); // Dont overwrite symbols after an error if (Error::the().hasAnyError()) { diff --git a/src/eval.h b/src/eval.h index b160ca3..873107b 100644 --- a/src/eval.h +++ b/src/eval.h @@ -7,6 +7,7 @@ #pragma once #include +#include #include "environment.h" #include "forward.h" @@ -25,7 +26,7 @@ public: ASTNodePtr ast() const { return m_ast; } private: - ASTNodePtr evalImpl(ASTNodePtr ast, EnvironmentPtr env); + ASTNodePtr evalImpl(); ASTNodePtr evalAst(ASTNodePtr ast, EnvironmentPtr env); ASTNodePtr evalDef(const std::list& nodes, EnvironmentPtr env); ASTNodePtr evalFn(const std::list& nodes, EnvironmentPtr env); @@ -33,6 +34,9 @@ private: ASTNodePtr m_ast; EnvironmentPtr m_env; + + std::stack m_ast_stack; + std::stack m_env_stack; }; } // namespace blaze