Browse Source

Eval+Env: Add support for lambdas

master
Riyyi 1 year ago
parent
commit
a60859acc4
  1. 16
      src/ast.cpp
  2. 37
      src/ast.h
  3. 33
      src/environment.cpp
  4. 17
      src/environment.h
  5. 66
      src/eval.cpp
  6. 3
      src/eval.h
  7. 19
      src/forward.h
  8. 3
      src/functions.cpp
  9. 8
      src/printer.cpp
  10. 1
      src/step4_if_fn_do.cpp

16
src/ast.cpp

@ -9,6 +9,8 @@
#include <string> #include <string>
#include "ast.h" #include "ast.h"
#include "environment.h"
#include "forward.h"
#include "printer.h" #include "printer.h"
#include "types.h" #include "types.h"
@ -63,8 +65,17 @@ Value::Value(State state)
// ----------------------------------------- // -----------------------------------------
Function::Function(Lambda lambda) Function::Function(FunctionType function)
: m_lambda(lambda) : m_function(function)
{
}
// -----------------------------------------
Lambda::Lambda(std::vector<std::string> bindings, ASTNodePtr body, EnvironmentPtr env)
: m_bindings(bindings)
, m_body(body)
, m_env(env)
{ {
} }
@ -74,7 +85,6 @@ Function::Function(Lambda lambda)
void Formatter<blaze::ASTNodePtr>::format(Builder& builder, blaze::ASTNodePtr value) const void Formatter<blaze::ASTNodePtr>::format(Builder& builder, blaze::ASTNodePtr value) const
{ {
// printf("ASDJASJKDASJKDNAJK\n");
blaze::Printer printer; blaze::Printer printer;
return Formatter<std::string>::format(builder, printer.printNoErrorCheck(value)); return Formatter<std::string>::format(builder, printer.printNoErrorCheck(value));
} }

37
src/ast.h

@ -15,13 +15,13 @@
#include <string_view> #include <string_view>
#include <typeinfo> // typeid #include <typeinfo> // typeid
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "ruc/format/formatter.h" #include "ruc/format/formatter.h"
namespace blaze { #include "forward.h"
class ASTNode; namespace blaze {
typedef std::shared_ptr<ASTNode> ASTNodePtr;
class ASTNode { class ASTNode {
public: public:
@ -42,6 +42,7 @@ public:
virtual bool isValue() const { return false; } virtual bool isValue() const { return false; }
virtual bool isSymbol() const { return false; } virtual bool isSymbol() const { return false; }
virtual bool isFunction() const { return false; } virtual bool isFunction() const { return false; }
virtual bool isLambda() const { return false; }
protected: protected:
ASTNode() {} ASTNode() {}
@ -199,19 +200,38 @@ private:
// ----------------------------------------- // -----------------------------------------
using Lambda = std::function<ASTNodePtr(std::list<ASTNodePtr>)>; using FunctionType = std::function<ASTNodePtr(std::list<ASTNodePtr>)>;
class Function final : public ASTNode { class Function final : public ASTNode {
public: public:
explicit Function(Lambda lambda); explicit Function(FunctionType function);
virtual ~Function() = default; virtual ~Function() = default;
virtual bool isFunction() const override { return true; } virtual bool isFunction() const override { return true; }
Lambda lambda() const { return m_lambda; } FunctionType function() const { return m_function; }
private: private:
Lambda m_lambda; FunctionType m_function;
};
// -----------------------------------------
class Lambda final : public ASTNode {
public:
Lambda(std::vector<std::string> bindings, ASTNodePtr body, EnvironmentPtr env);
virtual ~Lambda() = default;
virtual bool isLambda() const override { return true; }
std::vector<std::string> bindings() const { return m_bindings; }
ASTNodePtr body() const { return m_body; }
EnvironmentPtr env() const { return m_env; }
private:
std::vector<std::string> m_bindings;
ASTNodePtr m_body;
EnvironmentPtr m_env;
}; };
// ----------------------------------------- // -----------------------------------------
@ -254,6 +274,9 @@ inline bool ASTNode::fastIs<Symbol>() const { return isSymbol(); }
template<> template<>
inline bool ASTNode::fastIs<Function>() const { return isFunction(); } inline bool ASTNode::fastIs<Function>() const { return isFunction(); }
template<>
inline bool ASTNode::fastIs<Lambda>() const { return isLambda(); }
// clang-format on // clang-format on
} // namespace blaze } // namespace blaze

33
src/environment.cpp

@ -8,12 +8,41 @@
#include "ast.h" #include "ast.h"
#include "environment.h" #include "environment.h"
#include "error.h"
#include "forward.h"
namespace blaze { namespace blaze {
Environment::Environment(EnvironmentPtr outer) EnvironmentPtr Environment::create()
: m_outer(outer)
{ {
return std::shared_ptr<Environment>(new Environment);
}
EnvironmentPtr Environment::create(EnvironmentPtr outer)
{
auto env = create();
env->m_outer = outer;
return env;
}
EnvironmentPtr Environment::create(EnvironmentPtr outer, std::vector<std::string> bindings, std::list<ASTNodePtr> arguments)
{
auto env = create(outer);
if (bindings.size() != arguments.size()) {
Error::the().addError(format("wrong number of arguments: fn*, {}", arguments.size()));
return nullptr;
}
auto bindings_it = bindings.begin();
auto arguments_it = arguments.begin();
for (; bindings_it != bindings.end(); ++bindings_it, ++arguments_it) {
env->m_values.emplace(*bindings_it, *arguments_it);
}
return env;
} }
// ----------------------------------------- // -----------------------------------------

17
src/environment.h

@ -6,27 +6,32 @@
#pragma once #pragma once
#include <list>
#include <string> #include <string>
#include <unordered_map> #include <unordered_map>
#include <vector>
#include "ast.h" #include "badge.h"
#include "forward.h"
namespace blaze { namespace blaze {
class Environment;
typedef std::shared_ptr<Environment> EnvironmentPtr;
class Environment { class Environment {
public: public:
Environment() = default;
Environment(EnvironmentPtr outer);
virtual ~Environment() = default; virtual ~Environment() = default;
// Factory functions instead of constructors because it can fail in the bindings/arguments case
static EnvironmentPtr create();
static EnvironmentPtr create(EnvironmentPtr outer);
static EnvironmentPtr create(EnvironmentPtr outer, std::vector<std::string> bindings, std::list<ASTNodePtr> arguments);
bool exists(const std::string& symbol); bool exists(const std::string& symbol);
ASTNodePtr set(const std::string& symbol, ASTNodePtr value); ASTNodePtr set(const std::string& symbol, ASTNodePtr value);
ASTNodePtr get(const std::string& symbol); ASTNodePtr get(const std::string& symbol);
protected: protected:
Environment() {}
std::string m_current_key; std::string m_current_key;
std::unordered_map<std::string, ASTNodePtr> m_values; std::unordered_map<std::string, ASTNodePtr> m_values;
EnvironmentPtr m_outer { nullptr }; EnvironmentPtr m_outer { nullptr };

66
src/eval.cpp

@ -14,7 +14,7 @@
#include "environment.h" #include "environment.h"
#include "error.h" #include "error.h"
#include "eval.h" #include "eval.h"
#include "ruc/meta/assert.h" #include "forward.h"
#include "types.h" #include "types.h"
namespace blaze { namespace blaze {
@ -63,6 +63,9 @@ ASTNodePtr Eval::evalImpl(ASTNodePtr ast, EnvironmentPtr env)
if (symbol == "if") { if (symbol == "if") {
return evalIf(nodes, env); return evalIf(nodes, env);
} }
if (symbol == "fn*") {
return evalFn(nodes, env);
}
} }
// Function call // Function call
@ -136,7 +139,7 @@ ASTNodePtr Eval::evalDef(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env)
// First element needs to be a Symbol // First element needs to be a Symbol
if (!is<Symbol>(first_argument.get())) { if (!is<Symbol>(first_argument.get())) {
Error::the().addError(format("wrong type argument: symbol, {}", first_argument)); Error::the().addError(format("wrong argument type: symbol, {}", first_argument));
return nullptr; return nullptr;
} }
@ -187,7 +190,7 @@ ASTNodePtr Eval::evalLet(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env)
} }
// Create new environment // Create new environment
auto let_env = makePtr<Environment>(env); auto let_env = Environment::create(env);
for (auto it = binding_nodes.begin(); it != binding_nodes.end(); std::advance(it, 2)) { for (auto it = binding_nodes.begin(); it != binding_nodes.end(); std::advance(it, 2)) {
// First element needs to be a Symbol // First element needs to be a Symbol
@ -243,6 +246,42 @@ ASTNodePtr Eval::evalIf(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env)
} }
} }
#define ARG_COUNT_CHECK(name, size, comparison) \
if (size comparison) { \
Error::the().addError(format("wrong number of arguments: {}, {}", name, size)); \
return nullptr; \
}
#define AST_CHECK(type, value) \
if (!is<type>(value.get())) { \
Error::the().addError(format("wrong argument type: {}, {}", #type, value)); \
return nullptr; \
}
#define AST_CAST(type, value, variable) \
AST_CHECK(type, value) \
auto variable = std::static_pointer_cast<type>(value);
ASTNodePtr Eval::evalFn(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env)
{
ARG_COUNT_CHECK("fn*", nodes.size(), != 2);
auto first_argument = *nodes.begin();
auto second_argument = *std::next(nodes.begin());
// First element needs to be a List
AST_CAST(List, first_argument, list);
std::vector<std::string> bindings;
for (auto node : list->nodes()) {
// All nodes need to be a Symbol
AST_CAST(Symbol, node, symbol);
bindings.push_back(symbol->symbol());
}
return makePtr<Lambda>(bindings, second_argument, env);
}
ASTNodePtr Eval::apply(std::shared_ptr<List> evaluated_list) ASTNodePtr Eval::apply(std::shared_ptr<List> evaluated_list)
{ {
if (evaluated_list == nullptr) { if (evaluated_list == nullptr) {
@ -251,17 +290,32 @@ ASTNodePtr Eval::apply(std::shared_ptr<List> evaluated_list)
auto nodes = evaluated_list->nodes(); auto nodes = evaluated_list->nodes();
if (!is<Function>(nodes.front().get())) { if (!is<Function>(nodes.front().get()) && !is<Lambda>(nodes.front().get())) {
Error::the().addError(format("invalid function: {}", nodes.front())); Error::the().addError(format("invalid function: {}", nodes.front()));
return nullptr; return nullptr;
} }
// Function
if (is<Function>(nodes.front().get())) {
// car
auto function = std::static_pointer_cast<Function>(nodes.front())->function();
// cdr
nodes.pop_front();
return function(nodes);
}
// Lambda
// car // car
auto lambda = std::static_pointer_cast<Function>(nodes.front())->lambda(); auto lambda = std::static_pointer_cast<Lambda>(nodes.front());
// cdr // cdr
nodes.pop_front(); nodes.pop_front();
return lambda(nodes); auto lambda_env = Environment::create(lambda->env(), lambda->bindings(), nodes);
return evalImpl(lambda->body(), lambda_env);
} }
} // namespace blaze } // namespace blaze

3
src/eval.h

@ -8,8 +8,8 @@
#include <list> #include <list>
#include "ast.h"
#include "environment.h" #include "environment.h"
#include "forward.h"
namespace blaze { namespace blaze {
@ -29,6 +29,7 @@ private:
ASTNodePtr evalLet(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env); ASTNodePtr evalLet(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
ASTNodePtr evalDo(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env); ASTNodePtr evalDo(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
ASTNodePtr evalIf(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env); ASTNodePtr evalIf(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
ASTNodePtr evalFn(const std::list<ASTNodePtr>& nodes, EnvironmentPtr env);
ASTNodePtr apply(std::shared_ptr<List> evaluated_list); ASTNodePtr apply(std::shared_ptr<List> evaluated_list);
ASTNodePtr m_ast; ASTNodePtr m_ast;

19
src/forward.h

@ -0,0 +1,19 @@
/*
* Copyright (C) 2023 Riyyi
*
* SPDX-License-Identifier: MIT
*/
#pragma once
#include <memory> // std::shared_ptr
namespace blaze {
class ASTNode;
typedef std::shared_ptr<ASTNode> ASTNodePtr;
class Environment;
typedef std::shared_ptr<Environment> EnvironmentPtr;
} // namespace blaze

3
src/functions.cpp

@ -13,6 +13,7 @@
#include "ast.h" #include "ast.h"
#include "environment.h" #include "environment.h"
#include "error.h" #include "error.h"
#include "forward.h"
#include "printer.h" #include "printer.h"
#include "types.h" #include "types.h"
#include "util.h" #include "util.h"
@ -97,7 +98,7 @@ void GlobalEnvironment::div()
for (auto node : nodes) { for (auto node : nodes) {
if (!is<Number>(node.get())) { if (!is<Number>(node.get())) {
Error::the().addError(format("wrong type argument: number-or-marker-p, '{}'", node)); Error::the().addError(format("wrong argument type: number, '{}'", node));
return nullptr; return nullptr;
} }
} }

8
src/printer.cpp

@ -145,6 +145,14 @@ void Printer::printImpl(ASTNodePtr node, bool print_readably)
printSpacing(); printSpacing();
m_print += format("{}", std::static_pointer_cast<Symbol>(node)->symbol()); m_print += format("{}", std::static_pointer_cast<Symbol>(node)->symbol());
} }
else if (is<Function>(node_raw_ptr)) {
printSpacing();
m_print += format("#<builtin-function>");
}
else if (is<Lambda>(node_raw_ptr)) {
printSpacing();
m_print += format("#<user-function>");
}
} }
void Printer::printError() void Printer::printError()

1
src/step4_if_fn_do.cpp

@ -10,6 +10,7 @@
#include "environment.h" #include "environment.h"
#include "error.h" #include "error.h"
#include "eval.h" #include "eval.h"
#include "forward.h"
#include "lexer.h" #include "lexer.h"
#include "printer.h" #include "printer.h"
#include "reader.h" #include "reader.h"

Loading…
Cancel
Save