/* * Copyright (C) 2023 Riyi * * SPDX-License-Identifier: MIT */ #include // std::static_pointer_cast #include #include "ruc/file.h" #include "ruc/format/format.h" #include "ast.h" #include "environment.h" #include "error.h" #include "forward.h" #include "printer.h" #include "types.h" #include "util.h" // At the top-level you cant invoke any function, but you can create variables. // Using a struct's constructor you can work around this limitation. // Also the line number in the file is used to make the struct names unique. #define FUNCTION_STRUCT_NAME(unique) __functionStruct##unique #define ADD_FUNCTION_IMPL(unique, symbol, lambda) \ struct FUNCTION_STRUCT_NAME(unique) { \ FUNCTION_STRUCT_NAME(unique) \ (std::string __symbol, FunctionType __lambda) \ { \ s_functions.emplace(__symbol, __lambda); \ } \ }; \ static struct FUNCTION_STRUCT_NAME(unique) \ FUNCTION_STRUCT_NAME(unique)( \ symbol, \ [](std::list nodes) -> ValuePtr lambda); #define ADD_FUNCTION(symbol, lambda) ADD_FUNCTION_IMPL(__LINE__, symbol, lambda); namespace blaze { static std::unordered_map s_functions; ADD_FUNCTION( "+", { int64_t result = 0; for (auto node : nodes) { if (!is(node.get())) { Error::the().add(format("wrong argument type: number, '{}'", node)); return nullptr; } result += std::static_pointer_cast(node)->number(); } return makePtr(result); }); ADD_FUNCTION( "-", { if (nodes.size() == 0) { return makePtr(0); } for (auto node : nodes) { if (!is(node.get())) { Error::the().add(format("wrong argument type: number, '{}'", node)); return nullptr; } } // Start with the first number int64_t result = std::static_pointer_cast(nodes.front())->number(); // Skip the first node for (auto it = std::next(nodes.begin()); it != nodes.end(); ++it) { result -= std::static_pointer_cast(*it)->number(); } return makePtr(result); }); ADD_FUNCTION( "*", { int64_t result = 1; for (auto node : nodes) { if (!is(node.get())) { Error::the().add(format("wrong argument type: number, '{}'", node)); return nullptr; } result *= std::static_pointer_cast(node)->number(); } return makePtr(result); }); ADD_FUNCTION( "/", { if (nodes.size() == 0) { Error::the().add(format("wrong number of arguments: /, 0")); return nullptr; } for (auto node : nodes) { if (!is(node.get())) { Error::the().add(format("wrong argument type: number, '{}'", node)); return nullptr; } } // Start with the first number double result = std::static_pointer_cast(nodes.front())->number(); // Skip the first node for (auto it = std::next(nodes.begin()); it != nodes.end(); ++it) { result /= std::static_pointer_cast(*it)->number(); } return makePtr((int64_t)result); }); // // ----------------------------------------- #define NUMBER_COMPARE(operator) \ { \ bool result = true; \ \ if (nodes.size() < 2) { \ Error::the().add(format("wrong number of arguments: {}, {}", #operator, nodes.size())); \ return nullptr; \ } \ \ for (auto node : nodes) { \ if (!is(node.get())) { \ Error::the().add(format("wrong argument type: number, '{}'", node)); \ return nullptr; \ } \ } \ \ /* Start with the first number */ \ int64_t number = std::static_pointer_cast(nodes.front())->number(); \ \ /* Skip the first node */ \ for (auto it = std::next(nodes.begin()); it != nodes.end(); ++it) { \ int64_t current_number = std::static_pointer_cast(*it)->number(); \ if (!(number operator current_number)) { \ result = false; \ break; \ } \ number = current_number; \ } \ \ return makePtr((result) ? Constant::True : Constant::False); \ } ADD_FUNCTION("<", NUMBER_COMPARE(<)); ADD_FUNCTION("<=", NUMBER_COMPARE(<=)); ADD_FUNCTION(">", NUMBER_COMPARE(>)); ADD_FUNCTION(">=", NUMBER_COMPARE(>=)); // ----------------------------------------- ADD_FUNCTION( "list", { return makePtr(nodes); }); ADD_FUNCTION( "list?", { bool result = true; if (nodes.size() == 0) { result = false; } for (auto node : nodes) { if (!is(node.get())) { result = false; break; } } return makePtr((result) ? Constant::True : Constant::False); }); ADD_FUNCTION( "empty?", { bool result = true; for (auto node : nodes) { if (!is(node.get())) { Error::the().add(format("wrong argument type: collection, '{}'", node)); return nullptr; } if (!std::static_pointer_cast(node)->empty()) { result = false; break; } } return makePtr((result) ? Constant::True : Constant::False); }); ADD_FUNCTION( "count", { if (nodes.size() != 1) { Error::the().add(format("wrong number of arguments: count, {}", nodes.size())); return nullptr; } auto first_argument = nodes.front(); size_t result = 0; if (is(first_argument.get()) && std::static_pointer_cast(nodes.front())->state() == Constant::Nil) { // result = 0 } else if (is(first_argument.get())) { result = std::static_pointer_cast(first_argument)->size(); } else { Error::the().add(format("wrong argument type: collection, '{}'", first_argument)); return nullptr; } // FIXME: Add numeric_limits check for implicit cast: size_t > int64_t return makePtr((int64_t)result); }); // ----------------------------------------- #define PRINTER_STRING(print_readably, concatenation) \ { \ std::string result; \ \ Printer printer; \ for (auto it = nodes.begin(); it != nodes.end(); ++it) { \ result += format("{}", printer.printNoErrorCheck(*it, print_readably)); \ \ if (!isLast(it, nodes)) { \ result += concatenation; \ } \ } \ \ return makePtr(result); \ } ADD_FUNCTION("str", PRINTER_STRING(false, "")); ADD_FUNCTION("pr-str", PRINTER_STRING(true, " ")); #define PRINTER_PRINT(print_readably) \ { \ Printer printer; \ for (auto it = nodes.begin(); it != nodes.end(); ++it) { \ print("{}", printer.printNoErrorCheck(*it, print_readably)); \ \ if (!isLast(it, nodes)) { \ print(" "); \ } \ } \ print("\n"); \ \ return makePtr(Constant::Nil); \ } ADD_FUNCTION("prn", PRINTER_PRINT(true)); ADD_FUNCTION("println", PRINTER_PRINT(false)); // ----------------------------------------- ADD_FUNCTION( "=", { if (nodes.size() < 2) { Error::the().add(format("wrong number of arguments: =, {}", nodes.size())); return nullptr; } std::function equal = [&equal](ValuePtr lhs, ValuePtr rhs) -> bool { if ((is(lhs.get()) || is(lhs.get())) && (is(rhs.get()) || is(rhs.get()))) { auto lhs_nodes = std::static_pointer_cast(lhs)->nodes(); auto rhs_nodes = std::static_pointer_cast(rhs)->nodes(); if (lhs_nodes.size() != rhs_nodes.size()) { return false; } auto lhs_it = lhs_nodes.begin(); auto rhs_it = rhs_nodes.begin(); for (; lhs_it != lhs_nodes.end(); ++lhs_it, ++rhs_it) { if (!equal(*lhs_it, *rhs_it)) { return false; } } return true; } if (is(lhs.get()) && is(rhs.get())) { auto lhs_nodes = std::static_pointer_cast(lhs)->elements(); auto rhs_nodes = std::static_pointer_cast(rhs)->elements(); if (lhs_nodes.size() != rhs_nodes.size()) { return false; } for (const auto& [key, value] : lhs_nodes) { auto it = rhs_nodes.find(key); if (it == rhs_nodes.end() || !equal(value, it->second)) { return false; } } return true; } if (is(lhs.get()) && is(rhs.get()) && std::static_pointer_cast(lhs)->data() == std::static_pointer_cast(rhs)->data()) { return true; } if (is(lhs.get()) && is(rhs.get()) && std::static_pointer_cast(lhs)->keyword() == std::static_pointer_cast(rhs)->keyword()) { return true; } if (is(lhs.get()) && is(rhs.get()) && std::static_pointer_cast(lhs)->number() == std::static_pointer_cast(rhs)->number()) { return true; } if (is(lhs.get()) && is(rhs.get()) && std::static_pointer_cast(lhs)->state() == std::static_pointer_cast(rhs)->state()) { return true; } if (is(lhs.get()) && is(rhs.get()) && std::static_pointer_cast(lhs)->symbol() == std::static_pointer_cast(rhs)->symbol()) { return true; } return false; }; bool result = true; auto it = nodes.begin(); auto it_next = std::next(nodes.begin()); for (; it_next != nodes.end(); ++it, ++it_next) { if (!equal(*it, *it_next)) { result = false; break; } } return makePtr((result) ? Constant::True : Constant::False); }); ADD_FUNCTION( "read-string", { if (nodes.size() != 1) { Error::the().add(format("wrong number of arguments: read-string, {}", nodes.size())); return nullptr; } if (!is(nodes.front().get())) { Error::the().add(format("wrong argument type: string, '{}'", nodes.front())); return nullptr; } std::string input = std::static_pointer_cast(nodes.front())->data(); return read(input); }); ADD_FUNCTION( "slurp", { if (nodes.size() != 1) { Error::the().add(format("wrong number of arguments: slurp, {}", nodes.size())); return nullptr; } if (!is(nodes.front().get())) { Error::the().add(format("wrong argument type: string, '{}'", nodes.front())); return nullptr; } std::string path = std::static_pointer_cast(nodes.front())->data(); auto file = ruc::File(path); return makePtr(file.data()); }); ADD_FUNCTION( "eval", { if (nodes.size() != 1) { Error::the().add(format("wrong number of arguments: eval, {}", nodes.size())); return nullptr; } return eval(nodes.front(), nullptr); }); // (atom 1) ADD_FUNCTION( "atom", { if (nodes.size() != 1) { Error::the().add(format("wrong number of arguments: atom, {}", nodes.size())); return nullptr; } return makePtr(nodes.front()); }); // (atom? myatom 2 "foo") ADD_FUNCTION( "atom?", { bool result = true; if (nodes.size() == 0) { result = false; } for (auto node : nodes) { if (!is(node.get())) { result = false; break; } } return makePtr((result) ? Constant::True : Constant::False); }); // (deref myatom) ADD_FUNCTION( "deref", { if (nodes.size() != 1) { Error::the().add(format("wrong number of arguments: deref, {}", nodes.size())); return nullptr; } if (!is(nodes.front().get())) { Error::the().add(format("wrong argument type: atom, '{}'", nodes.front())); return nullptr; } return std::static_pointer_cast(nodes.front())->deref(); }); // (reset! myatom 2) ADD_FUNCTION( "reset!", { if (nodes.size() != 2) { Error::the().add(format("wrong number of arguments: reset!, {}", nodes.size())); return nullptr; } if (!is(nodes.front().get())) { Error::the().add(format("wrong argument type: atom, '{}'", nodes.front())); return nullptr; } auto atom = std::static_pointer_cast(*nodes.begin()); auto value = *std::next(nodes.begin()); atom->reset(value); return value; }); // (swap! myatom (fn* [x] (+ 1 x))) ADD_FUNCTION( "swap!", { if (nodes.size() < 2) { Error::the().add(format("wrong number of arguments: swap!, {}", nodes.size())); return nullptr; } auto first_argument = *nodes.begin(); auto second_argument = *std::next(nodes.begin()); if (!is(first_argument.get())) { Error::the().add(format("wrong argument type: atom, '{}'", first_argument)); return nullptr; } if (!is(second_argument.get())) { Error::the().add(format("wrong argument type: function, '{}'", second_argument)); return nullptr; } auto atom = std::static_pointer_cast(first_argument); // Remove atom and function from the argument list, add atom value nodes.pop_front(); nodes.pop_front(); nodes.push_front(atom->deref()); ValuePtr value = nullptr; if (is(second_argument.get())) { auto function = std::static_pointer_cast(second_argument)->function(); value = function(nodes); } else { auto lambda = std::static_pointer_cast(second_argument); value = eval(lambda->body(), Environment::create(lambda, nodes)); } return atom->reset(value); }); // (cons 1 (list 2 3)) ADD_FUNCTION( "cons", { if (nodes.size() != 2) { Error::the().add(format("wrong number of arguments: cons, {}", nodes.size())); return nullptr; } auto first_argument = *nodes.begin(); auto second_argument = *std::next(nodes.begin()); if (!is(second_argument.get())) { Error::the().add(format("wrong argument type: list, '{}'", second_argument)); return nullptr; } auto result_nodes = std::static_pointer_cast(second_argument)->nodes(); result_nodes.push_front(first_argument); return makePtr(result_nodes); }); // (concat (list 1) (list 2 3)) ADD_FUNCTION( "concat", { std::list result_nodes; for (auto node : nodes) { if (!is(node.get())) { Error::the().add(format("wrong argument type: list, '{}'", node)); return nullptr; } auto argument_nodes = std::static_pointer_cast(node)->nodes(); result_nodes.splice(result_nodes.end(), argument_nodes); } return makePtr(result_nodes); }); // (vec (list 1 2 3)) ADD_FUNCTION( "vec", { if (nodes.size() != 1) { Error::the().add(format("wrong number of arguments: vec, {}", nodes.size())); return nullptr; } if (!is(nodes.front().get())) { Error::the().add(format("wrong argument type: list, '{}'", nodes.front())); return nullptr; } auto result_nodes = std::static_pointer_cast(nodes.front())->nodes(); return makePtr(result_nodes); }); // ----------------------------------------- void installFunctions(EnvironmentPtr env) { for (const auto& [name, lambda] : s_functions) { env->set(name, makePtr(name, lambda)); } } } // namespace blaze