From 7c62d65d72a2cf14b684aeb031e885e22366c6a7 Mon Sep 17 00:00:00 2001 From: Riyyi Date: Mon, 3 Apr 2023 13:54:32 +0200 Subject: [PATCH] Main+Eval: Implement tail call optimization (TCO) --- CMakeLists.txt | 4 + src/eval.cpp | 288 ++++++++++++++++++++++------------------- src/eval.h | 3 - src/step4_if_fn_do.cpp | 2 +- src/step5_tco.cpp | 136 +++++++++++++++++++ 5 files changed, 294 insertions(+), 139 deletions(-) create mode 100644 src/step5_tco.cpp diff --git a/CMakeLists.txt b/CMakeLists.txt index cf5cc68..a2848cb 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -99,3 +99,7 @@ add_dependencies(test3 ${PROJECT}) add_custom_target(test4 COMMAND env STEP=step_env MAL_IMPL=js ../vendor/mal/runtest.py --deferrable --optional ../vendor/mal/tests/step4_if_fn_do.mal -- ./${PROJECT}) add_dependencies(test4 ${PROJECT}) + +add_custom_target(test5 + COMMAND env STEP=step_env MAL_IMPL=js ../vendor/mal/runtest.py --deferrable --optional ../vendor/mal/tests/step5_tco.mal -- ./${PROJECT}) +add_dependencies(test5 ${PROJECT}) diff --git a/src/eval.cpp b/src/eval.cpp index ec12c7c..cdcadc2 100644 --- a/src/eval.cpp +++ b/src/eval.cpp @@ -25,6 +25,98 @@ Eval::Eval(ASTNodePtr ast, EnvironmentPtr env) { } +// ----------------------------------------- + +#define EVAL_LET(ast, nodes, env) \ + { \ + if (nodes.size() != 2) { \ + Error::the().add(format("wrong number of arguments: let*, {}", nodes.size())); \ + return nullptr; \ + } \ + \ + auto first_argument = *nodes.begin(); \ + auto second_argument = *std::next(nodes.begin()); \ + \ + /* First argument needs to be a List or Vector */ \ + if (!is(first_argument.get())) { \ + Error::the().add(format("wrong argument type: collection, '{}'", first_argument)); \ + return nullptr; \ + } \ + \ + /* Get the nodes out of the List or Vector */ \ + std::list binding_nodes; \ + auto bindings = std::static_pointer_cast(first_argument); \ + binding_nodes = bindings->nodes(); \ + \ + /* List or Vector needs to have an even number of elements */ \ + size_t count = binding_nodes.size(); \ + if (count % 2 != 0) { \ + Error::the().add(format("wrong number of arguments: {}, {}", "let* bindings", count)); \ + return nullptr; \ + } \ + \ + /* Create new environment */ \ + auto let_env = Environment::create(env); \ + \ + for (auto it = binding_nodes.begin(); it != binding_nodes.end(); std::advance(it, 2)) { \ + /* First element needs to be a Symbol */ \ + if (!is(*it->get())) { \ + Error::the().add(format("wrong argument type: symbol, '{}'", *it)); \ + return nullptr; \ + } \ + \ + std::string key = std::static_pointer_cast(*it)->symbol(); \ + ASTNodePtr value = evalImpl(*std::next(it), let_env); \ + let_env->set(key, value); \ + } \ + \ + /* TODO: Remove limitation of 3 arguments */ \ + /* Eval all values in this new env, return last sexp of the result */ \ + ast = second_argument; \ + env = let_env; \ + continue; /* TCO */ \ + } + +#define EVAL_DO(ast, nodes, env) \ + { \ + if (nodes.size() == 0) { \ + Error::the().add(format("wrong number of arguments: do, {}", nodes.size())); \ + return nullptr; \ + } \ + \ + /* Evaluate all nodes except the last */ \ + for (auto it = nodes.begin(); it != std::prev(nodes.end(), 1); ++it) { \ + evalImpl(*it, env); \ + } \ + \ + /* Eval last node */ \ + ast = nodes.back(); \ + continue; /* TCO */ \ + } + +#define EVAL_IF(ast, nodes, env) \ + { \ + if (nodes.size() != 2 && nodes.size() != 3) { \ + Error::the().add(format("wrong number of arguments: if, {}", nodes.size())); \ + return nullptr; \ + } \ + \ + auto first_argument = *nodes.begin(); \ + auto second_argument = *std::next(nodes.begin()); \ + auto third_argument = (nodes.size() == 3) ? *std::next(std::next(nodes.begin())) : makePtr(Value::Nil); \ + \ + auto first_evaluated = evalImpl(first_argument, env); \ + if (!is(first_evaluated.get()) \ + || std::static_pointer_cast(first_evaluated)->state() == Value::True) { \ + ast = second_argument; \ + continue; /* TCO */ \ + } \ + else { \ + ast = third_argument; \ + continue; /* TCO */ \ + } \ + } + void Eval::eval() { m_ast = evalImpl(m_ast, m_env); @@ -32,44 +124,70 @@ void Eval::eval() ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env) { - if (ast == nullptr || env == nullptr) { - return nullptr; - } + while (true) { + if (ast == nullptr) { + return nullptr; + } - if (!is(ast.get())) { - return evalAst(ast, env); - } + if (env == nullptr) { + env = m_env; + } - auto list = std::static_pointer_cast(ast); + if (!is(ast.get())) { + return evalAst(ast, env); + } - if (list->empty()) { - return ast; - } + auto list = std::static_pointer_cast(ast); - // Environment - auto nodes = list->nodes(); - if (is(nodes.front().get())) { - auto symbol = std::static_pointer_cast(nodes.front())->symbol(); - nodes.pop_front(); - if (symbol == "def!") { - return evalDef(nodes, env); + if (list->empty()) { + return ast; } - if (symbol == "let*") { - return evalLet(nodes, env); - } - if (symbol == "do") { - return evalDo(nodes, env); + + // Special forms + auto nodes = list->nodes(); + if (is(nodes.front().get())) { + auto symbol = std::static_pointer_cast(nodes.front())->symbol(); + nodes.pop_front(); + if (symbol == "def!") { + return evalDef(nodes, env); + } + if (symbol == "let*") { + EVAL_LET(ast, nodes, env); + } + if (symbol == "do") { + EVAL_DO(ast, nodes, env); + } + if (symbol == "if") { + EVAL_IF(ast, nodes, env); + } + if (symbol == "fn*") { + return evalFn(nodes, env); + } } - if (symbol == "if") { - return evalIf(nodes, env); + + auto evaluated_list = std::static_pointer_cast(evalAst(ast, env)); + + if (evaluated_list == nullptr) { + return nullptr; } - if (symbol == "fn*") { - return evalFn(nodes, env); + + // Regular list + if (is(evaluated_list->nodes().front().get())) { + auto evaluated_nodes = evaluated_list->nodes(); + + // car + auto lambda = std::static_pointer_cast(evaluated_nodes.front()); + // cdr + evaluated_nodes.pop_front(); + + ast = lambda->body(); + env = Environment::create(lambda, evaluated_nodes); + continue; // TCO } - } - // Function call - return apply(std::static_pointer_cast(evalAst(ast, env))); + // Function call + return apply(evaluated_list); + } } ASTNodePtr Eval::evalAst(ASTNodePtr ast, EnvironmentPtr env) @@ -144,93 +262,8 @@ ASTNodePtr Eval::evalDef(const std::list& nodes, EnvironmentPtr env) return env->set(symbol, value); } -ASTNodePtr Eval::evalLet(const std::list& nodes, EnvironmentPtr env) -{ - if (nodes.size() != 2) { - Error::the().add(format("wrong number of arguments: let*, {}", nodes.size())); - return nullptr; - } - - auto first_argument = *nodes.begin(); - auto second_argument = *std::next(nodes.begin()); - - // First argument needs to be a List or Vector - if (!is(first_argument.get())) { - Error::the().add(format("wrong argument type: collection, '{}'", first_argument)); - return nullptr; - } - - // Get the nodes out of the List or Vector - std::list binding_nodes; - auto bindings = std::static_pointer_cast(first_argument); - binding_nodes = bindings->nodes(); - - // List or Vector needs to have an even number of elements - size_t count = binding_nodes.size(); - if (count % 2 != 0) { - Error::the().add(format("wrong number of arguments: {}, {}", "let* bindings", count)); - return nullptr; - } - - // Create new environment - auto let_env = Environment::create(env); - - for (auto it = binding_nodes.begin(); it != binding_nodes.end(); std::advance(it, 2)) { - // First element needs to be a Symbol - if (!is(*it->get())) { - Error::the().add(format("wrong argument type: symbol, '{}'", *it)); - return nullptr; - } - - std::string key = std::static_pointer_cast(*it)->symbol(); - ASTNodePtr value = evalImpl(*std::next(it), let_env); - let_env->set(key, value); - } - - // TODO: Remove limitation of 3 arguments - // Eval all values in this new env, return last sexp of the result - return evalImpl(second_argument, let_env); -} - -ASTNodePtr Eval::evalDo(const std::list& nodes, EnvironmentPtr env) -{ - if (nodes.size() == 0) { - Error::the().add(format("wrong number of arguments: do, {}", nodes.size())); - return nullptr; - } - - // Evaluate all nodes except the last - for (auto it = nodes.begin(); it != std::prev(nodes.end(), 1); ++it) { - evalImpl(*it, env); - } - - // Eval and return last node - return evalImpl(nodes.back(), env); -} - -ASTNodePtr Eval::evalIf(const std::list& nodes, EnvironmentPtr env) -{ - if (nodes.size() != 2 && nodes.size() != 3) { - Error::the().add(format("wrong number of arguments: if, {}", nodes.size())); - return nullptr; - } - - auto first_argument = *nodes.begin(); - auto second_argument = *std::next(nodes.begin()); - auto third_argument = (nodes.size() == 3) ? *std::next(std::next(nodes.begin())) : makePtr(Value::Nil); - - auto first_evaluated = evalImpl(first_argument, env); - if (!is(first_evaluated.get()) - || std::static_pointer_cast(first_evaluated)->state() == Value::True) { - return evalImpl(second_argument, env); - } - else { - return evalImpl(third_argument, env); - } -} - -#define ARG_COUNT_CHECK(name, size, comparison) \ - if (size comparison) { \ +#define ARG_COUNT_CHECK(name, comparison, size) \ + if (comparison) { \ Error::the().add(format("wrong number of arguments: {}, {}", name, size)); \ return nullptr; \ } @@ -247,7 +280,7 @@ ASTNodePtr Eval::evalIf(const std::list& nodes, EnvironmentPtr env) ASTNodePtr Eval::evalFn(const std::list& nodes, EnvironmentPtr env) { - ARG_COUNT_CHECK("fn*", nodes.size(), != 2); + ARG_COUNT_CHECK("fn*", nodes.size() != 2, nodes.size()); auto first_argument = *nodes.begin(); auto second_argument = *std::next(nodes.begin()); @@ -273,32 +306,17 @@ ASTNodePtr Eval::apply(std::shared_ptr evaluated_list) auto nodes = evaluated_list->nodes(); - if (!is(nodes.front().get()) && !is(nodes.front().get())) { + if (!is(nodes.front().get())) { Error::the().add(format("invalid function: {}", nodes.front())); return nullptr; } - // Function - - if (is(nodes.front().get())) { - // car - auto function = std::static_pointer_cast(nodes.front())->function(); - // cdr - nodes.pop_front(); - - return function(nodes); - } - - // Lambda - // car - auto lambda = std::static_pointer_cast(nodes.front()); + auto function = std::static_pointer_cast(nodes.front())->function(); // cdr nodes.pop_front(); - auto lambda_env = Environment::create(lambda, nodes); - - return evalImpl(lambda->body(), lambda_env); + return function(nodes); } } // namespace blaze diff --git a/src/eval.h b/src/eval.h index f40a91a..b160ca3 100644 --- a/src/eval.h +++ b/src/eval.h @@ -28,9 +28,6 @@ private: ASTNodePtr evalImpl(ASTNodePtr ast, EnvironmentPtr env); ASTNodePtr evalAst(ASTNodePtr ast, EnvironmentPtr env); ASTNodePtr evalDef(const std::list& nodes, EnvironmentPtr env); - ASTNodePtr evalLet(const std::list& nodes, EnvironmentPtr env); - ASTNodePtr evalDo(const std::list& nodes, EnvironmentPtr env); - ASTNodePtr evalIf(const std::list& nodes, EnvironmentPtr env); ASTNodePtr evalFn(const std::list& nodes, EnvironmentPtr env); ASTNodePtr apply(std::shared_ptr evaluated_list); diff --git a/src/step4_if_fn_do.cpp b/src/step4_if_fn_do.cpp index 85b548a..2dfeaca 100644 --- a/src/step4_if_fn_do.cpp +++ b/src/step4_if_fn_do.cpp @@ -17,7 +17,7 @@ #include "readline.h" #include "settings.h" -#if 1 +#if 0 static blaze::EnvironmentPtr s_outer_env = blaze::Environment::create(); static auto cleanup(int signal) -> void; diff --git a/src/step5_tco.cpp b/src/step5_tco.cpp new file mode 100644 index 0000000..492eaf0 --- /dev/null +++ b/src/step5_tco.cpp @@ -0,0 +1,136 @@ +/* + * Copyright (C) 2023 Riyyi + * + * SPDX-License-Identifier: MIT + */ + +#include // std::signal +#include // std::exit +#include +#include + +#include "ruc/argparser.h" +#include "ruc/format/color.h" + +#include "ast.h" +#include "environment.h" +#include "error.h" +#include "eval.h" +#include "forward.h" +#include "lexer.h" +#include "printer.h" +#include "reader.h" +#include "readline.h" +#include "settings.h" + +#if 1 +static blaze::EnvironmentPtr s_outer_env = blaze::Environment::create(); + +static auto cleanup(int signal) -> void; +static auto installLambdas(blaze::EnvironmentPtr env) -> void; +static auto rep(std::string_view input, blaze::EnvironmentPtr env) -> std::string; +static auto read(std::string_view input) -> blaze::ASTNodePtr; +static auto eval(blaze::ASTNodePtr ast, blaze::EnvironmentPtr env) -> blaze::ASTNodePtr; +static auto print(blaze::ASTNodePtr exp) -> std::string; + +auto main(int argc, char* argv[]) -> int +{ + bool dump_lexer = false; + bool dump_reader = false; + bool pretty_print = false; + std::string_view history_path = "~/.mal-history"; + + // CLI arguments + ruc::ArgParser arg_parser; + arg_parser.addOption(dump_lexer, 'l', "dump-lexer", nullptr, nullptr); + arg_parser.addOption(dump_reader, 'r', "dump-reader", nullptr, nullptr); + arg_parser.addOption(pretty_print, 'c', "color", nullptr, nullptr); + arg_parser.addOption(history_path, 'h', "history", nullptr, nullptr, nullptr, ruc::ArgParser::Required::Yes); + arg_parser.parse(argc, argv); + + // Set settings + blaze::Settings::the().set("dump-lexer", dump_lexer ? "1" : "0"); + blaze::Settings::the().set("dump-reader", dump_reader ? "1" : "0"); + blaze::Settings::the().set("pretty-print", pretty_print ? "1" : "0"); + + // Signal callbacks + std::signal(SIGINT, cleanup); + std::signal(SIGTERM, cleanup); + + installFunctions(s_outer_env); + installLambdas(s_outer_env); + + blaze::Readline readline(pretty_print, history_path); + + std::string input; + while (readline.get(input)) { + if (pretty_print) { + print("\033[0m"); + } + print("{}\n", rep(input, s_outer_env)); + } + + if (pretty_print) { + print("\033[0m"); + } + + return 0; +} + +static auto cleanup(int signal) -> void +{ + print("\033[0m\n"); + std::exit(signal); +} + +static std::string_view lambdaTable[] = { + "(def! not (fn* (cond) (if cond false true)))", +}; + +static auto installLambdas(blaze::EnvironmentPtr env) -> void +{ + for (auto function : lambdaTable) { + rep(function, env); + } +} + +static auto rep(std::string_view input, blaze::EnvironmentPtr env) -> std::string +{ + blaze::Error::the().clearErrors(); + blaze::Error::the().setInput(input); + + return print(eval(read(input), env)); +} + +static auto read(std::string_view input) -> blaze::ASTNodePtr +{ + blaze::Lexer lexer(input); + lexer.tokenize(); + if (blaze::Settings::the().get("dump-lexer") == "1") { + lexer.dump(); + } + + blaze::Reader reader(std::move(lexer.tokens())); + reader.read(); + if (blaze::Settings::the().get("dump-reader") == "1") { + reader.dump(); + } + + return reader.node(); +} + +static auto eval(blaze::ASTNodePtr ast, blaze::EnvironmentPtr env) -> blaze::ASTNodePtr +{ + blaze::Eval eval(ast, env); + eval.eval(); + + return eval.ast(); +} + +static auto print(blaze::ASTNodePtr exp) -> std::string +{ + blaze::Printer printer; + + return printer.print(exp, true); +} +#endif