diff --git a/src/ast.cpp b/src/ast.cpp index c8f05d3..2431b10 100644 --- a/src/ast.cpp +++ b/src/ast.cpp @@ -9,6 +9,8 @@ #include #include "ast.h" +#include "environment.h" +#include "forward.h" #include "printer.h" #include "types.h" @@ -63,8 +65,17 @@ Value::Value(State state) // ----------------------------------------- -Function::Function(Lambda lambda) - : m_lambda(lambda) +Function::Function(FunctionType function) + : m_function(function) +{ +} + +// ----------------------------------------- + +Lambda::Lambda(std::vector bindings, ASTNodePtr body, EnvironmentPtr env) + : m_bindings(bindings) + , m_body(body) + , m_env(env) { } @@ -74,7 +85,6 @@ Function::Function(Lambda lambda) void Formatter::format(Builder& builder, blaze::ASTNodePtr value) const { - // printf("ASDJASJKDASJKDNAJK\n"); blaze::Printer printer; return Formatter::format(builder, printer.printNoErrorCheck(value)); } diff --git a/src/ast.h b/src/ast.h index da35f40..bfe85c8 100644 --- a/src/ast.h +++ b/src/ast.h @@ -15,13 +15,13 @@ #include #include // typeid #include +#include #include "ruc/format/formatter.h" -namespace blaze { +#include "forward.h" -class ASTNode; -typedef std::shared_ptr ASTNodePtr; +namespace blaze { class ASTNode { public: @@ -42,6 +42,7 @@ public: virtual bool isValue() const { return false; } virtual bool isSymbol() const { return false; } virtual bool isFunction() const { return false; } + virtual bool isLambda() const { return false; } protected: ASTNode() {} @@ -199,19 +200,38 @@ private: // ----------------------------------------- -using Lambda = std::function)>; +using FunctionType = std::function)>; class Function final : public ASTNode { public: - explicit Function(Lambda lambda); + explicit Function(FunctionType function); virtual ~Function() = default; virtual bool isFunction() const override { return true; } - Lambda lambda() const { return m_lambda; } + FunctionType function() const { return m_function; } private: - Lambda m_lambda; + FunctionType m_function; +}; + +// ----------------------------------------- + +class Lambda final : public ASTNode { +public: + Lambda(std::vector bindings, ASTNodePtr body, EnvironmentPtr env); + virtual ~Lambda() = default; + + virtual bool isLambda() const override { return true; } + + std::vector bindings() const { return m_bindings; } + ASTNodePtr body() const { return m_body; } + EnvironmentPtr env() const { return m_env; } + +private: + std::vector m_bindings; + ASTNodePtr m_body; + EnvironmentPtr m_env; }; // ----------------------------------------- @@ -254,6 +274,9 @@ inline bool ASTNode::fastIs() const { return isSymbol(); } template<> inline bool ASTNode::fastIs() const { return isFunction(); } + +template<> +inline bool ASTNode::fastIs() const { return isLambda(); } // clang-format on } // namespace blaze diff --git a/src/environment.cpp b/src/environment.cpp index 0b8990e..569f882 100644 --- a/src/environment.cpp +++ b/src/environment.cpp @@ -8,12 +8,41 @@ #include "ast.h" #include "environment.h" +#include "error.h" +#include "forward.h" namespace blaze { -Environment::Environment(EnvironmentPtr outer) - : m_outer(outer) +EnvironmentPtr Environment::create() { + return std::shared_ptr(new Environment); +} + +EnvironmentPtr Environment::create(EnvironmentPtr outer) +{ + auto env = create(); + + env->m_outer = outer; + + return env; +} + +EnvironmentPtr Environment::create(EnvironmentPtr outer, std::vector bindings, std::list arguments) +{ + auto env = create(outer); + + if (bindings.size() != arguments.size()) { + Error::the().addError(format("wrong number of arguments: fn*, {}", arguments.size())); + return nullptr; + } + + auto bindings_it = bindings.begin(); + auto arguments_it = arguments.begin(); + for (; bindings_it != bindings.end(); ++bindings_it, ++arguments_it) { + env->m_values.emplace(*bindings_it, *arguments_it); + } + + return env; } // ----------------------------------------- diff --git a/src/environment.h b/src/environment.h index daf87fd..26e9091 100644 --- a/src/environment.h +++ b/src/environment.h @@ -6,27 +6,32 @@ #pragma once +#include #include #include +#include -#include "ast.h" +#include "badge.h" +#include "forward.h" namespace blaze { -class Environment; -typedef std::shared_ptr EnvironmentPtr; - class Environment { public: - Environment() = default; - Environment(EnvironmentPtr outer); virtual ~Environment() = default; + // Factory functions instead of constructors because it can fail in the bindings/arguments case + static EnvironmentPtr create(); + static EnvironmentPtr create(EnvironmentPtr outer); + static EnvironmentPtr create(EnvironmentPtr outer, std::vector bindings, std::list arguments); + bool exists(const std::string& symbol); ASTNodePtr set(const std::string& symbol, ASTNodePtr value); ASTNodePtr get(const std::string& symbol); protected: + Environment() {} + std::string m_current_key; std::unordered_map m_values; EnvironmentPtr m_outer { nullptr }; diff --git a/src/eval.cpp b/src/eval.cpp index 89e501e..d07b1b2 100644 --- a/src/eval.cpp +++ b/src/eval.cpp @@ -14,7 +14,7 @@ #include "environment.h" #include "error.h" #include "eval.h" -#include "ruc/meta/assert.h" +#include "forward.h" #include "types.h" namespace blaze { @@ -63,6 +63,9 @@ ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env) if (symbol == "if") { return evalIf(nodes, env); } + if (symbol == "fn*") { + return evalFn(nodes, env); + } } // Function call @@ -136,7 +139,7 @@ ASTNodePtr Eval::evalDef(const std::list& nodes, EnvironmentPtr env) // First element needs to be a Symbol if (!is(first_argument.get())) { - Error::the().addError(format("wrong type argument: symbol, {}", first_argument)); + Error::the().addError(format("wrong argument type: symbol, {}", first_argument)); return nullptr; } @@ -187,7 +190,7 @@ ASTNodePtr Eval::evalLet(const std::list& nodes, EnvironmentPtr env) } // Create new environment - auto let_env = makePtr(env); + auto let_env = Environment::create(env); for (auto it = binding_nodes.begin(); it != binding_nodes.end(); std::advance(it, 2)) { // First element needs to be a Symbol @@ -243,6 +246,42 @@ ASTNodePtr Eval::evalIf(const std::list& nodes, EnvironmentPtr env) } } +#define ARG_COUNT_CHECK(name, size, comparison) \ + if (size comparison) { \ + Error::the().addError(format("wrong number of arguments: {}, {}", name, size)); \ + return nullptr; \ + } + +#define AST_CHECK(type, value) \ + if (!is(value.get())) { \ + Error::the().addError(format("wrong argument type: {}, {}", #type, value)); \ + return nullptr; \ + } + +#define AST_CAST(type, value, variable) \ + AST_CHECK(type, value) \ + auto variable = std::static_pointer_cast(value); + +ASTNodePtr Eval::evalFn(const std::list& nodes, EnvironmentPtr env) +{ + ARG_COUNT_CHECK("fn*", nodes.size(), != 2); + + auto first_argument = *nodes.begin(); + auto second_argument = *std::next(nodes.begin()); + + // First element needs to be a List + AST_CAST(List, first_argument, list); + + std::vector bindings; + for (auto node : list->nodes()) { + // All nodes need to be a Symbol + AST_CAST(Symbol, node, symbol); + bindings.push_back(symbol->symbol()); + } + + return makePtr(bindings, second_argument, env); +} + ASTNodePtr Eval::apply(std::shared_ptr evaluated_list) { if (evaluated_list == nullptr) { @@ -251,17 +290,32 @@ ASTNodePtr Eval::apply(std::shared_ptr evaluated_list) auto nodes = evaluated_list->nodes(); - if (!is(nodes.front().get())) { + if (!is(nodes.front().get()) && !is(nodes.front().get())) { Error::the().addError(format("invalid function: {}", nodes.front())); return nullptr; } + // Function + + if (is(nodes.front().get())) { + // car + auto function = std::static_pointer_cast(nodes.front())->function(); + // cdr + nodes.pop_front(); + + return function(nodes); + } + + // Lambda + // car - auto lambda = std::static_pointer_cast(nodes.front())->lambda(); + auto lambda = std::static_pointer_cast(nodes.front()); // cdr nodes.pop_front(); - return lambda(nodes); + auto lambda_env = Environment::create(lambda->env(), lambda->bindings(), nodes); + + return evalImpl(lambda->body(), lambda_env); } } // namespace blaze diff --git a/src/eval.h b/src/eval.h index 9fd85cf..1036546 100644 --- a/src/eval.h +++ b/src/eval.h @@ -8,8 +8,8 @@ #include -#include "ast.h" #include "environment.h" +#include "forward.h" namespace blaze { @@ -29,6 +29,7 @@ private: ASTNodePtr evalLet(const std::list& nodes, EnvironmentPtr env); ASTNodePtr evalDo(const std::list& nodes, EnvironmentPtr env); ASTNodePtr evalIf(const std::list& nodes, EnvironmentPtr env); + ASTNodePtr evalFn(const std::list& nodes, EnvironmentPtr env); ASTNodePtr apply(std::shared_ptr evaluated_list); ASTNodePtr m_ast; diff --git a/src/forward.h b/src/forward.h new file mode 100644 index 0000000..c6045b4 --- /dev/null +++ b/src/forward.h @@ -0,0 +1,19 @@ +/* + * Copyright (C) 2023 Riyyi + * + * SPDX-License-Identifier: MIT + */ + +#pragma once + +#include // std::shared_ptr + +namespace blaze { + +class ASTNode; +typedef std::shared_ptr ASTNodePtr; + +class Environment; +typedef std::shared_ptr EnvironmentPtr; + +} // namespace blaze diff --git a/src/functions.cpp b/src/functions.cpp index ec8fb55..52364c9 100644 --- a/src/functions.cpp +++ b/src/functions.cpp @@ -13,6 +13,7 @@ #include "ast.h" #include "environment.h" #include "error.h" +#include "forward.h" #include "printer.h" #include "types.h" #include "util.h" @@ -97,7 +98,7 @@ void GlobalEnvironment::div() for (auto node : nodes) { if (!is(node.get())) { - Error::the().addError(format("wrong type argument: number-or-marker-p, '{}'", node)); + Error::the().addError(format("wrong argument type: number, '{}'", node)); return nullptr; } } diff --git a/src/printer.cpp b/src/printer.cpp index 8e7c195..edba587 100644 --- a/src/printer.cpp +++ b/src/printer.cpp @@ -145,6 +145,14 @@ void Printer::printImpl(ASTNodePtr node, bool print_readably) printSpacing(); m_print += format("{}", std::static_pointer_cast(node)->symbol()); } + else if (is(node_raw_ptr)) { + printSpacing(); + m_print += format("#"); + } + else if (is(node_raw_ptr)) { + printSpacing(); + m_print += format("#"); + } } void Printer::printError() diff --git a/src/step4_if_fn_do.cpp b/src/step4_if_fn_do.cpp index b1319ee..d9c6419 100644 --- a/src/step4_if_fn_do.cpp +++ b/src/step4_if_fn_do.cpp @@ -10,6 +10,7 @@ #include "environment.h" #include "error.h" #include "eval.h" +#include "forward.h" #include "lexer.h" #include "printer.h" #include "reader.h"