Browse Source

Env: Add equal function

master
Riyyi 2 years ago
parent
commit
9c1c5114a9
  1. 2
      src/ast.h
  2. 11
      src/environment.cpp
  3. 11
      src/environment.h
  4. 151
      src/functions.cpp

2
src/ast.h

@ -33,9 +33,9 @@ public:
bool fastIs() const = delete; bool fastIs() const = delete;
virtual bool isCollection() const { return false; } virtual bool isCollection() const { return false; }
virtual bool isList() const { return false; }
virtual bool isVector() const { return false; } virtual bool isVector() const { return false; }
virtual bool isHashMap() const { return false; } virtual bool isHashMap() const { return false; }
virtual bool isList() const { return false; }
virtual bool isString() const { return false; } virtual bool isString() const { return false; }
virtual bool isKeyword() const { return false; } virtual bool isKeyword() const { return false; }
virtual bool isNumber() const { return false; } virtual bool isNumber() const { return false; }

11
src/environment.cpp

@ -64,9 +64,16 @@ GlobalEnvironment::GlobalEnvironment()
gte(); gte();
list(); list();
is_list(); isList();
is_empty(); isEmpty();
count(); count();
str();
prStr();
prn();
println();
equal();
} }
} // namespace blaze } // namespace blaze

11
src/environment.h

@ -49,11 +49,16 @@ private:
void gte(); // >= void gte(); // >=
void list(); // list void list(); // list
void is_list(); // list? void isList(); // list?
void is_empty(); // empty? void isEmpty(); // empty?
void count(); // count void count(); // count
// void equal(); // = void str(); // str
void prStr(); // pr-str
void prn(); // prn
void println(); // println
void equal(); // =
}; };
} // namespace blaze } // namespace blaze

151
src/functions.cpp

@ -5,13 +5,17 @@
*/ */
#include <memory> // std::static_pointer_cast #include <memory> // std::static_pointer_cast
#include <string>
#include "ruc/format/color.h"
#include "ruc/format/format.h" #include "ruc/format/format.h"
#include "ast.h" #include "ast.h"
#include "environment.h" #include "environment.h"
#include "error.h" #include "error.h"
#include "printer.h"
#include "types.h" #include "types.h"
#include "util.h"
namespace blaze { namespace blaze {
@ -185,7 +189,7 @@ void GlobalEnvironment::list()
m_values.emplace("list", makePtr<Function>(list)); m_values.emplace("list", makePtr<Function>(list));
} }
void GlobalEnvironment::is_list() void GlobalEnvironment::isList()
{ {
auto is_list = [](std::span<ASTNodePtr> nodes) -> ASTNodePtr { auto is_list = [](std::span<ASTNodePtr> nodes) -> ASTNodePtr {
bool result = true; bool result = true;
@ -203,7 +207,7 @@ void GlobalEnvironment::is_list()
m_values.emplace("list?", makePtr<Function>(is_list)); m_values.emplace("list?", makePtr<Function>(is_list));
} }
void GlobalEnvironment::is_empty() void GlobalEnvironment::isEmpty()
{ {
auto is_empty = [](std::span<ASTNodePtr> nodes) -> ASTNodePtr { auto is_empty = [](std::span<ASTNodePtr> nodes) -> ASTNodePtr {
bool result = true; bool result = true;
@ -252,4 +256,147 @@ void GlobalEnvironment::count()
m_values.emplace("count", makePtr<Function>(count)); m_values.emplace("count", makePtr<Function>(count));
} }
// -----------------------------------------
#define PRINTER_STRING(symbol, concatenation, print_readably) \
auto lambda = [](std::span<ASTNodePtr> nodes) -> ASTNodePtr { \
std::string result; \
\
Printer printer; \
for (auto it = nodes.begin(); it != nodes.end(); ++it) { \
result += format("{}", printer.printNoErrorCheck(*it, print_readably)); \
\
if (!isLast(it, nodes)) { \
result += concatenation; \
} \
} \
\
return makePtr<String>(result); \
}; \
\
m_values.emplace(symbol, makePtr<Function>(lambda));
void GlobalEnvironment::str()
{
PRINTER_STRING("str", "", false);
}
void GlobalEnvironment::prStr()
{
PRINTER_STRING("pr-str", " ", true);
}
#define PRINTER_PRINT(symbol, print_readably) \
auto lambda = [](std::span<ASTNodePtr> nodes) -> ASTNodePtr { \
Printer printer; \
for (auto it = nodes.begin(); it != nodes.end(); ++it) { \
print("{}", printer.printNoErrorCheck(*it, print_readably)); \
\
if (!isLast(it, nodes)) { \
print(" "); \
} \
} \
print("\n"); \
\
return makePtr<Value>(Value::Nil); \
}; \
\
m_values.emplace(symbol, makePtr<Function>(lambda));
void GlobalEnvironment::prn()
{
PRINTER_PRINT("prn", true);
}
void GlobalEnvironment::println()
{
PRINTER_PRINT("println", false);
}
// -----------------------------------------
void GlobalEnvironment::equal()
{
auto lambda = [this](std::span<ASTNodePtr> nodes) -> ASTNodePtr {
if (nodes.size() < 2) {
Error::the().addError(format("wrong number of arguments: {}, {}", m_current_key, nodes.size() - 1));
return nullptr;
}
std::function<bool(ASTNodePtr, ASTNodePtr)> equal =
[&equal](ASTNodePtr lhs, ASTNodePtr rhs) -> bool {
if ((is<List>(lhs.get()) && is<List>(rhs.get()))
|| (is<Vector>(lhs.get()) && is<Vector>(rhs.get()))) {
auto lhs_nodes = std::static_pointer_cast<List>(lhs)->nodes();
auto rhs_nodes = std::static_pointer_cast<List>(rhs)->nodes();
if (lhs_nodes.size() != rhs_nodes.size()) {
return false;
}
for (size_t i = 0; i < lhs_nodes.size(); ++i) {
if (!equal(lhs_nodes[i], rhs_nodes[i])) {
return false;
}
}
return true;
}
if (is<HashMap>(lhs.get()) && is<HashMap>(rhs.get())) {
auto lhs_nodes = std::static_pointer_cast<HashMap>(lhs)->elements();
auto rhs_nodes = std::static_pointer_cast<HashMap>(rhs)->elements();
if (lhs_nodes.size() != rhs_nodes.size()) {
return false;
}
for (const auto& [key, value] : lhs_nodes) {
auto it = rhs_nodes.find(key);
if (it == rhs_nodes.end() || !equal(value, it->second)) {
return false;
}
}
return true;
}
if (is<String>(lhs.get()) && is<String>(rhs.get())
&& std::static_pointer_cast<String>(lhs)->data() == std::static_pointer_cast<String>(rhs)->data()) {
return true;
}
if (is<Keyword>(lhs.get()) && is<Keyword>(rhs.get())
&& std::static_pointer_cast<Keyword>(lhs)->keyword() == std::static_pointer_cast<Keyword>(rhs)->keyword()) {
return true;
}
if (is<Number>(lhs.get()) && is<Number>(rhs.get())
&& std::static_pointer_cast<Number>(lhs)->number() == std::static_pointer_cast<Number>(rhs)->number()) {
return true;
}
if (is<Value>(lhs.get()) && is<Value>(rhs.get())
&& std::static_pointer_cast<Value>(lhs)->state() == std::static_pointer_cast<Value>(rhs)->state()) {
return true;
}
if (is<Symbol>(lhs.get()) && is<Symbol>(rhs.get())
&& std::static_pointer_cast<Symbol>(lhs)->symbol() == std::static_pointer_cast<Symbol>(rhs)->symbol()) {
return true;
}
return false;
};
bool result = true;
for (size_t i = 0; i < nodes.size() - 1; ++i) {
if (!equal(nodes[i], nodes[i + 1])) {
result = false;
break;
}
}
return makePtr<Value>((result) ? Value::True : Value::False);
};
m_values.emplace("=", makePtr<Function>(lambda));
}
} // namespace blaze } // namespace blaze

Loading…
Cancel
Save