Browse Source

Main+Eval: Implement tail call optimization (TCO)

master
Riyyi 1 year ago
parent
commit
7c62d65d72
  1. 4
      CMakeLists.txt
  2. 238
      src/eval.cpp
  3. 3
      src/eval.h
  4. 2
      src/step4_if_fn_do.cpp
  5. 136
      src/step5_tco.cpp

4
CMakeLists.txt

@ -99,3 +99,7 @@ add_dependencies(test3 ${PROJECT})
add_custom_target(test4 add_custom_target(test4
COMMAND env STEP=step_env MAL_IMPL=js ../vendor/mal/runtest.py --deferrable --optional ../vendor/mal/tests/step4_if_fn_do.mal -- ./${PROJECT}) COMMAND env STEP=step_env MAL_IMPL=js ../vendor/mal/runtest.py --deferrable --optional ../vendor/mal/tests/step4_if_fn_do.mal -- ./${PROJECT})
add_dependencies(test4 ${PROJECT}) add_dependencies(test4 ${PROJECT})
add_custom_target(test5
COMMAND env STEP=step_env MAL_IMPL=js ../vendor/mal/runtest.py --deferrable --optional ../vendor/mal/tests/step5_tco.mal -- ./${PROJECT})
add_dependencies(test5 ${PROJECT})

238
src/eval.cpp

@ -25,6 +25,98 @@ Eval::Eval(ASTNodePtr ast, EnvironmentPtr env)
{ {
} }
// -----------------------------------------
#define EVAL_LET(ast, nodes, env) \
{ \
if (nodes.size() != 2) { \
Error::the().add(format("wrong number of arguments: let*, {}", nodes.size())); \
return nullptr; \
} \
\
auto first_argument = *nodes.begin(); \
auto second_argument = *std::next(nodes.begin()); \
\
/* First argument needs to be a List or Vector */ \
if (!is<Collection>(first_argument.get())) { \
Error::the().add(format("wrong argument type: collection, '{}'", first_argument)); \
return nullptr; \
} \
\
/* Get the nodes out of the List or Vector */ \
std::list<ASTNodePtr> binding_nodes; \
auto bindings = std::static_pointer_cast<Collection>(first_argument); \
binding_nodes = bindings->nodes(); \
\
/* List or Vector needs to have an even number of elements */ \
size_t count = binding_nodes.size(); \
if (count % 2 != 0) { \
Error::the().add(format("wrong number of arguments: {}, {}", "let* bindings", count)); \
return nullptr; \
} \
\
/* Create new environment */ \
auto let_env = Environment::create(env); \
\
for (auto it = binding_nodes.begin(); it != binding_nodes.end(); std::advance(it, 2)) { \
/* First element needs to be a Symbol */ \
if (!is<Symbol>(*it->get())) { \
Error::the().add(format("wrong argument type: symbol, '{}'", *it)); \
return nullptr; \
} \
\
std::string key = std::static_pointer_cast<Symbol>(*it)->symbol(); \
ASTNodePtr value = evalImpl(*std::next(it), let_env); \
let_env->set(key, value); \
} \
\
/* TODO: Remove limitation of 3 arguments */ \
/* Eval all values in this new env, return last sexp of the result */ \
ast = second_argument; \
env = let_env; \
continue; /* TCO */ \
}
#define EVAL_DO(ast, nodes, env) \
{ \
if (nodes.size() == 0) { \
Error::the().add(format("wrong number of arguments: do, {}", nodes.size())); \
return nullptr; \
} \
\
/* Evaluate all nodes except the last */ \
for (auto it = nodes.begin(); it != std::prev(nodes.end(), 1); ++it) { \
evalImpl(*it, env); \
} \
\
/* Eval last node */ \
ast = nodes.back(); \
continue; /* TCO */ \
}
#define EVAL_IF(ast, nodes, env) \
{ \
if (nodes.size() != 2 && nodes.size() != 3) { \
Error::the().add(format("wrong number of arguments: if, {}", nodes.size())); \
return nullptr; \
} \
\
auto first_argument = *nodes.begin(); \
auto second_argument = *std::next(nodes.begin()); \
auto third_argument = (nodes.size() == 3) ? *std::next(std::next(nodes.begin())) : makePtr<Value>(Value::Nil); \
\
auto first_evaluated = evalImpl(first_argument, env); \
if (!is<Value>(first_evaluated.get()) \
|| std::static_pointer_cast<Value>(first_evaluated)->state() == Value::True) { \
ast = second_argument; \
continue; /* TCO */ \
} \
else { \
ast = third_argument; \
continue; /* TCO */ \
} \
}
void Eval::eval() void Eval::eval()
{ {
m_ast = evalImpl(m_ast, m_env); m_ast = evalImpl(m_ast, m_env);
@ -32,10 +124,15 @@ void Eval::eval()
ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env) ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env)
{ {
if (ast == nullptr || env == nullptr) { while (true) {
if (ast == nullptr) {
return nullptr; return nullptr;
} }
if (env == nullptr) {
env = m_env;
}
if (!is<List>(ast.get())) { if (!is<List>(ast.get())) {
return evalAst(ast, env); return evalAst(ast, env);
} }
@ -46,7 +143,7 @@ ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env)
return ast; return ast;
} }
// Environment // Special forms
auto nodes = list->nodes(); auto nodes = list->nodes();
if (is<Symbol>(nodes.front().get())) { if (is<Symbol>(nodes.front().get())) {
auto symbol = std::static_pointer_cast<Symbol>(nodes.front())->symbol(); auto symbol = std::static_pointer_cast<Symbol>(nodes.front())->symbol();
@ -55,21 +152,42 @@ ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env)
return evalDef(nodes, env); return evalDef(nodes, env);
} }
if (symbol == "let*") { if (symbol == "let*") {
return evalLet(nodes, env); EVAL_LET(ast, nodes, env);
} }
if (symbol == "do") { if (symbol == "do") {
return evalDo(nodes, env); EVAL_DO(ast, nodes, env);
} }
if (symbol == "if") { if (symbol == "if") {
return evalIf(nodes, env); EVAL_IF(ast, nodes, env);
} }
if (symbol == "fn*") { if (symbol == "fn*") {
return evalFn(nodes, env); return evalFn(nodes, env);
} }
} }
auto evaluated_list = std::static_pointer_cast<List>(evalAst(ast, env));
if (evaluated_list == nullptr) {
return nullptr;
}
// Regular list
if (is<Lambda>(evaluated_list->nodes().front().get())) {
auto evaluated_nodes = evaluated_list->nodes();
// car
auto lambda = std::static_pointer_cast<Lambda>(evaluated_nodes.front());
// cdr
evaluated_nodes.pop_front();
ast = lambda->body();
env = Environment::create(lambda, evaluated_nodes);
continue; // TCO
}
// Function call // Function call
return apply(std::static_pointer_cast<List>(evalAst(ast, env))); return apply(evaluated_list);
}
} }
ASTNodePtr Eval::evalAst(ASTNodePtr ast, EnvironmentPtr env) ASTNodePtr Eval::evalAst(ASTNodePtr ast, EnvironmentPtr env)
@ -144,93 +262,8 @@ ASTNodePtr Eval::evalDef(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env)
return env->set(symbol, value); return env->set(symbol, value);
} }
ASTNodePtr Eval::evalLet(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env) #define ARG_COUNT_CHECK(name, comparison, size) \
{ if (comparison) { \
if (nodes.size() != 2) {
Error::the().add(format("wrong number of arguments: let*, {}", nodes.size()));
return nullptr;
}
auto first_argument = *nodes.begin();
auto second_argument = *std::next(nodes.begin());
// First argument needs to be a List or Vector
if (!is<Collection>(first_argument.get())) {
Error::the().add(format("wrong argument type: collection, '{}'", first_argument));
return nullptr;
}
// Get the nodes out of the List or Vector
std::list<ASTNodePtr> binding_nodes;
auto bindings = std::static_pointer_cast<Collection>(first_argument);
binding_nodes = bindings->nodes();
// List or Vector needs to have an even number of elements
size_t count = binding_nodes.size();
if (count % 2 != 0) {
Error::the().add(format("wrong number of arguments: {}, {}", "let* bindings", count));
return nullptr;
}
// Create new environment
auto let_env = Environment::create(env);
for (auto it = binding_nodes.begin(); it != binding_nodes.end(); std::advance(it, 2)) {
// First element needs to be a Symbol
if (!is<Symbol>(*it->get())) {
Error::the().add(format("wrong argument type: symbol, '{}'", *it));
return nullptr;
}
std::string key = std::static_pointer_cast<Symbol>(*it)->symbol();
ASTNodePtr value = evalImpl(*std::next(it), let_env);
let_env->set(key, value);
}
// TODO: Remove limitation of 3 arguments
// Eval all values in this new env, return last sexp of the result
return evalImpl(second_argument, let_env);
}
ASTNodePtr Eval::evalDo(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env)
{
if (nodes.size() == 0) {
Error::the().add(format("wrong number of arguments: do, {}", nodes.size()));
return nullptr;
}
// Evaluate all nodes except the last
for (auto it = nodes.begin(); it != std::prev(nodes.end(), 1); ++it) {
evalImpl(*it, env);
}
// Eval and return last node
return evalImpl(nodes.back(), env);
}
ASTNodePtr Eval::evalIf(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env)
{
if (nodes.size() != 2 && nodes.size() != 3) {
Error::the().add(format("wrong number of arguments: if, {}", nodes.size()));
return nullptr;
}
auto first_argument = *nodes.begin();
auto second_argument = *std::next(nodes.begin());
auto third_argument = (nodes.size() == 3) ? *std::next(std::next(nodes.begin())) : makePtr<Value>(Value::Nil);
auto first_evaluated = evalImpl(first_argument, env);
if (!is<Value>(first_evaluated.get())
|| std::static_pointer_cast<Value>(first_evaluated)->state() == Value::True) {
return evalImpl(second_argument, env);
}
else {
return evalImpl(third_argument, env);
}
}
#define ARG_COUNT_CHECK(name, size, comparison) \
if (size comparison) { \
Error::the().add(format("wrong number of arguments: {}, {}", name, size)); \ Error::the().add(format("wrong number of arguments: {}, {}", name, size)); \
return nullptr; \ return nullptr; \
} }
@ -247,7 +280,7 @@ ASTNodePtr Eval::evalIf(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env)
ASTNodePtr Eval::evalFn(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env) ASTNodePtr Eval::evalFn(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env)
{ {
ARG_COUNT_CHECK("fn*", nodes.size(), != 2); ARG_COUNT_CHECK("fn*", nodes.size() != 2, nodes.size());
auto first_argument = *nodes.begin(); auto first_argument = *nodes.begin();
auto second_argument = *std::next(nodes.begin()); auto second_argument = *std::next(nodes.begin());
@ -273,14 +306,11 @@ ASTNodePtr Eval::apply(std::shared_ptr<List> evaluated_list)
auto nodes = evaluated_list->nodes(); auto nodes = evaluated_list->nodes();
if (!is<Function>(nodes.front().get()) && !is<Lambda>(nodes.front().get())) { if (!is<Function>(nodes.front().get())) {
Error::the().add(format("invalid function: {}", nodes.front())); Error::the().add(format("invalid function: {}", nodes.front()));
return nullptr; return nullptr;
} }
// Function
if (is<Function>(nodes.front().get())) {
// car // car
auto function = std::static_pointer_cast<Function>(nodes.front())->function(); auto function = std::static_pointer_cast<Function>(nodes.front())->function();
// cdr // cdr
@ -289,16 +319,4 @@ ASTNodePtr Eval::apply(std::shared_ptr<List> evaluated_list)
return function(nodes); return function(nodes);
} }
// Lambda
// car
auto lambda = std::static_pointer_cast<Lambda>(nodes.front());
// cdr
nodes.pop_front();
auto lambda_env = Environment::create(lambda, nodes);
return evalImpl(lambda->body(), lambda_env);
}
} // namespace blaze } // namespace blaze

3
src/eval.h

@ -28,9 +28,6 @@ private:
ASTNodePtr evalImpl(ASTNodePtr ast, EnvironmentPtr env); ASTNodePtr evalImpl(ASTNodePtr ast, EnvironmentPtr env);
ASTNodePtr evalAst(ASTNodePtr ast, EnvironmentPtr env); ASTNodePtr evalAst(ASTNodePtr ast, EnvironmentPtr env);
ASTNodePtr evalDef(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env); ASTNodePtr evalDef(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
ASTNodePtr evalLet(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
ASTNodePtr evalDo(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
ASTNodePtr evalIf(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
ASTNodePtr evalFn(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env); ASTNodePtr evalFn(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
ASTNodePtr apply(std::shared_ptr<List> evaluated_list); ASTNodePtr apply(std::shared_ptr<List> evaluated_list);

2
src/step4_if_fn_do.cpp

@ -17,7 +17,7 @@
#include "readline.h" #include "readline.h"
#include "settings.h" #include "settings.h"
#if 1 #if 0
static blaze::EnvironmentPtr s_outer_env = blaze::Environment::create(); static blaze::EnvironmentPtr s_outer_env = blaze::Environment::create();
static auto cleanup(int signal) -> void; static auto cleanup(int signal) -> void;

136
src/step5_tco.cpp

@ -0,0 +1,136 @@
/*
* Copyright (C) 2023 Riyyi
*
* SPDX-License-Identifier: MIT
*/
#include <csignal> // std::signal
#include <cstdlib> // std::exit
#include <string>
#include <string_view>
#include "ruc/argparser.h"
#include "ruc/format/color.h"
#include "ast.h"
#include "environment.h"
#include "error.h"
#include "eval.h"
#include "forward.h"
#include "lexer.h"
#include "printer.h"
#include "reader.h"
#include "readline.h"
#include "settings.h"
#if 1
static blaze::EnvironmentPtr s_outer_env = blaze::Environment::create();
static auto cleanup(int signal) -> void;
static auto installLambdas(blaze::EnvironmentPtr env) -> void;
static auto rep(std::string_view input, blaze::EnvironmentPtr env) -> std::string;
static auto read(std::string_view input) -> blaze::ASTNodePtr;
static auto eval(blaze::ASTNodePtr ast, blaze::EnvironmentPtr env) -> blaze::ASTNodePtr;
static auto print(blaze::ASTNodePtr exp) -> std::string;
auto main(int argc, char* argv[]) -> int
{
bool dump_lexer = false;
bool dump_reader = false;
bool pretty_print = false;
std::string_view history_path = "~/.mal-history";
// CLI arguments
ruc::ArgParser arg_parser;
arg_parser.addOption(dump_lexer, 'l', "dump-lexer", nullptr, nullptr);
arg_parser.addOption(dump_reader, 'r', "dump-reader", nullptr, nullptr);
arg_parser.addOption(pretty_print, 'c', "color", nullptr, nullptr);
arg_parser.addOption(history_path, 'h', "history", nullptr, nullptr, nullptr, ruc::ArgParser::Required::Yes);
arg_parser.parse(argc, argv);
// Set settings
blaze::Settings::the().set("dump-lexer", dump_lexer ? "1" : "0");
blaze::Settings::the().set("dump-reader", dump_reader ? "1" : "0");
blaze::Settings::the().set("pretty-print", pretty_print ? "1" : "0");
// Signal callbacks
std::signal(SIGINT, cleanup);
std::signal(SIGTERM, cleanup);
installFunctions(s_outer_env);
installLambdas(s_outer_env);
blaze::Readline readline(pretty_print, history_path);
std::string input;
while (readline.get(input)) {
if (pretty_print) {
print("\033[0m");
}
print("{}\n", rep(input, s_outer_env));
}
if (pretty_print) {
print("\033[0m");
}
return 0;
}
static auto cleanup(int signal) -> void
{
print("\033[0m\n");
std::exit(signal);
}
static std::string_view lambdaTable[] = {
"(def! not (fn* (cond) (if cond false true)))",
};
static auto installLambdas(blaze::EnvironmentPtr env) -> void
{
for (auto function : lambdaTable) {
rep(function, env);
}
}
static auto rep(std::string_view input, blaze::EnvironmentPtr env) -> std::string
{
blaze::Error::the().clearErrors();
blaze::Error::the().setInput(input);
return print(eval(read(input), env));
}
static auto read(std::string_view input) -> blaze::ASTNodePtr
{
blaze::Lexer lexer(input);
lexer.tokenize();
if (blaze::Settings::the().get("dump-lexer") == "1") {
lexer.dump();
}
blaze::Reader reader(std::move(lexer.tokens()));
reader.read();
if (blaze::Settings::the().get("dump-reader") == "1") {
reader.dump();
}
return reader.node();
}
static auto eval(blaze::ASTNodePtr ast, blaze::EnvironmentPtr env) -> blaze::ASTNodePtr
{
blaze::Eval eval(ast, env);
eval.eval();
return eval.ast();
}
static auto print(blaze::ASTNodePtr exp) -> std::string
{
blaze::Printer printer;
return printer.print(exp, true);
}
#endif
Loading…
Cancel
Save