From 11f0553b5ac0864322dcea9fd9a46cd8143309de Mon Sep 17 00:00:00 2001 From: Riyyi Date: Sun, 26 Nov 2023 10:45:45 +0100 Subject: [PATCH] AST+Env+Printer+Reader: Implement floating point numbers --- src/ast.cpp | 9 +++- src/ast.h | 37 ++++++++++++- src/env/functions/compare.cpp | 81 +++++++++++++++++++---------- src/env/functions/convert.cpp | 29 +++++++---- src/env/functions/operators.cpp | 92 ++++++++++++++++++++++++++------- src/printer.cpp | 4 ++ src/reader.cpp | 43 ++++++++++----- vendor/ruc | 2 +- 8 files changed, 226 insertions(+), 71 deletions(-) diff --git a/src/ast.cpp b/src/ast.cpp index 94e46d3..28db938 100644 --- a/src/ast.cpp +++ b/src/ast.cpp @@ -193,7 +193,14 @@ Keyword::Keyword(int64_t number) // ----------------------------------------- Number::Number(int64_t number) - : m_number(number) + : Numeric() + , m_number(number) +{ +} + +Decimal::Decimal(double decimal) + : Numeric() + , m_decimal(decimal) { } diff --git a/src/ast.h b/src/ast.h index 4a93c96..f861896 100644 --- a/src/ast.h +++ b/src/ast.h @@ -52,7 +52,9 @@ public: virtual bool isHashMap() const { return false; } virtual bool isString() const { return false; } virtual bool isKeyword() const { return false; } + virtual bool isNumeric() const { return false; } virtual bool isNumber() const { return false; } + virtual bool isDecimal() const { return false; } virtual bool isConstant() const { return false; } virtual bool isSymbol() const { return false; } virtual bool isCallable() const { return false; } @@ -252,8 +254,19 @@ private: }; // ----------------------------------------- + +class Numeric : public Value { +public: + virtual ~Numeric() = default; + +protected: + Numeric() = default; + + virtual bool isNumeric() const override { return true; } +}; + // 123 -class Number final : public Value { +class Number final : public Numeric { public: Number(int64_t number); virtual ~Number() = default; @@ -268,6 +281,22 @@ private: const int64_t m_number { 0 }; }; +// 123.456 +class Decimal final : public Numeric { +public: + Decimal(double decimal); + virtual ~Decimal() = default; + + double decimal() const { return m_decimal; } + + WITH_NO_META(); + +private: + virtual bool isDecimal() const override { return true; } + + const double m_decimal { 0 }; +}; + // ----------------------------------------- // true, false, nil @@ -428,9 +457,15 @@ inline bool Value::fastIs() const { return isString(); } template<> inline bool Value::fastIs() const { return isKeyword(); } +template<> +inline bool Value::fastIs() const { return isNumeric(); } + template<> inline bool Value::fastIs() const { return isNumber(); } +template<> +inline bool Value::fastIs() const { return isDecimal(); } + template<> inline bool Value::fastIs() const { return isConstant(); } diff --git a/src/env/functions/compare.cpp b/src/env/functions/compare.cpp index de3e0b5..5a87563 100644 --- a/src/env/functions/compare.cpp +++ b/src/env/functions/compare.cpp @@ -16,28 +16,51 @@ namespace blaze { void Environment::loadCompare() { -#define NUMBER_COMPARE(operator) \ - { \ - bool result = true; \ - \ - CHECK_ARG_COUNT_AT_LEAST(#operator, SIZE(), 2); \ - \ - /* Start with the first number */ \ - VALUE_CAST(number_node, Number, (*begin)); \ - int64_t number = number_node->number(); \ - \ - /* Skip the first node */ \ - for (auto it = begin + 1; it != end; ++it) { \ - VALUE_CAST(current_number_node, Number, (*it)); \ - int64_t current_number = current_number_node->number(); \ - if (!(number operator current_number)) { \ - result = false; \ - break; \ - } \ - number = current_number; \ - } \ - \ - return makePtr((result) ? Constant::True : Constant::False); \ +#define NUMBER_COMPARE(operator) \ + { \ + CHECK_ARG_COUNT_AT_LEAST(#operator, SIZE(), 2); \ + \ + bool result = true; \ + \ + int64_t number = 0; \ + double decimal = 0; \ + bool current_numeric_is_number = false; \ + \ + /* Start with the first number */ \ + IS_VALUE(Numeric, (*begin)); \ + if (is(begin->get())) { \ + number = std::static_pointer_cast(*begin)->number(); \ + current_numeric_is_number = true; \ + } \ + else { \ + decimal = std::static_pointer_cast(*begin)->decimal(); \ + current_numeric_is_number = false; \ + } \ + \ + /* Skip the first node */ \ + for (auto it = begin + 1; it != end; ++it) { \ + IS_VALUE(Numeric, (*it)); \ + if (is(*it->get())) { \ + int64_t it_number = std::static_pointer_cast(*it)->number(); \ + if (!((current_numeric_is_number ? number : decimal) operator it_number)) { \ + result = false; \ + break; \ + } \ + number = it_number; \ + current_numeric_is_number = true; \ + } \ + else { \ + double it_decimal = std::static_pointer_cast(*it)->decimal(); \ + if (!((current_numeric_is_number ? number : decimal) operator it_decimal)) { \ + result = false; \ + break; \ + } \ + decimal = it_decimal; \ + current_numeric_is_number = false; \ + } \ + } \ + \ + return makePtr((result) ? Constant::True : Constant::False); \ } ADD_FUNCTION("<", "", "", NUMBER_COMPARE(<)); @@ -58,8 +81,7 @@ void Environment::loadCompare() std::function equal = [&equal](ValuePtr lhs, ValuePtr rhs) -> bool { - if ((is(lhs.get()) || is(lhs.get())) - && (is(rhs.get()) || is(rhs.get()))) { + if (is(lhs.get()) && is(rhs.get())) { auto lhs_collection = std::static_pointer_cast(lhs); auto rhs_collection = std::static_pointer_cast(rhs); @@ -104,10 +126,17 @@ void Environment::loadCompare() && std::static_pointer_cast(lhs)->keyword() == std::static_pointer_cast(rhs)->keyword()) { return true; } - if (is(lhs.get()) && is(rhs.get()) - && std::static_pointer_cast(lhs)->number() == std::static_pointer_cast(rhs)->number()) { + // clang-format off + if (is(lhs.get()) && is(rhs.get()) + && (is(lhs.get()) + ? std::static_pointer_cast(lhs)->number() + : std::static_pointer_cast(lhs)->decimal()) + == (is(rhs.get()) + ? std::static_pointer_cast(rhs)->number() + : std::static_pointer_cast(rhs)->decimal())) { return true; } + // clang-format on if (is(lhs.get()) && is(rhs.get()) && std::static_pointer_cast(lhs)->state() == std::static_pointer_cast(rhs)->state()) { return true; diff --git a/src/env/functions/convert.cpp b/src/env/functions/convert.cpp index 76f07f6..0d9939e 100644 --- a/src/env/functions/convert.cpp +++ b/src/env/functions/convert.cpp @@ -4,7 +4,8 @@ * SPDX-License-Identifier: MIT */ -#include // std::from_chars, std::to_chars +#include // std::from_chars, std::to_chars +#include #include // std::errc #include "ast.h" @@ -23,12 +24,14 @@ void Environment::loadConvert() { CHECK_ARG_COUNT_IS("number-to-string", SIZE(), 1); - VALUE_CAST(number, Number, (*begin)); + IS_VALUE(Numeric, (*begin)); char result[32]; auto conversion = std::to_chars(result, result + sizeof(result), - number->number()); + is(begin->get()) + ? std::static_pointer_cast(*begin)->number() + : std::static_pointer_cast(*begin)->decimal()); if (conversion.ec != std::errc()) { return makePtr(Constant::Nil); } @@ -61,15 +64,23 @@ void Environment::loadConvert() VALUE_CAST(string_value, String, (*begin)); std::string data = string_value->data(); - int64_t result; - auto conversion = std::from_chars(data.c_str(), - data.c_str() + data.size(), - result); - if (conversion.ec != std::errc()) { + if (data.find('.') == std::string::npos) { + int64_t number; + auto conversion_number = std::from_chars(data.c_str(), data.c_str() + data.size(), number); + if (conversion_number.ec != std::errc()) { + return makePtr(Constant::Nil); + } + + return makePtr(number); + } + + double decimal; + auto conversion_decimal = std::from_chars(data.c_str(), data.c_str() + data.size(), decimal); + if (conversion_decimal.ec != std::errc()) { return makePtr(Constant::Nil); } - return makePtr(result); + return makePtr(decimal); }); #define STRING_TO_COLLECTION(name, type) \ diff --git a/src/env/functions/operators.cpp b/src/env/functions/operators.cpp index 7cb3164..86f5e8e 100644 --- a/src/env/functions/operators.cpp +++ b/src/env/functions/operators.cpp @@ -5,6 +5,7 @@ */ #include // int64_t +#include // std::static_pointer_cast #include "ast.h" #include "env/macro.h" @@ -14,19 +15,46 @@ namespace blaze { void Environment::loadOperators() { +#define APPLY_NUMBER_OR_DECIMAL(it, apply) \ + IS_VALUE(Numeric, (*it)); \ + if (is(it->get())) { \ + auto it_numeric = std::static_pointer_cast(*it)->number(); \ + do { \ + apply \ + } while (0); \ + } \ + else { \ + return_decimal = true; \ + auto it_numeric = std::static_pointer_cast(*it)->decimal(); \ + do { \ + apply \ + } while (0); \ + } + +#define RETURN_NUMBER_OR_DECIMAL() \ + if (!return_decimal) { \ + return makePtr(number); \ + } \ + return makePtr(decimal); + ADD_FUNCTION( "+", "number...", "Return the sum of any amount of arguments, where NUMBER is of type number.", { - int64_t result = 0; + bool return_decimal = false; + + int64_t number = 0; + double decimal = 0; for (auto it = begin; it != end; ++it) { - VALUE_CAST(number, Number, (*it)); - result += number->number(); + APPLY_NUMBER_OR_DECIMAL(it, { + number += it_numeric; + decimal += it_numeric; + }); } - return makePtr(result); + RETURN_NUMBER_OR_DECIMAL(); }); ADD_FUNCTION( @@ -42,21 +70,33 @@ subtracts all but the first from the first.)", return makePtr(0); } + bool return_decimal = false; + + int64_t number = 0; + double decimal = 0; + // Start with the first number - VALUE_CAST(number, Number, (*begin)); - int64_t result = number->number(); + APPLY_NUMBER_OR_DECIMAL(begin, { + number = it_numeric; + decimal = it_numeric; + }); + // Return negative on single argument if (length == 1) { - return makePtr(-result); + number = -number; + decimal = -decimal; + RETURN_NUMBER_OR_DECIMAL(); } // Skip the first node for (auto it = begin + 1; it != end; ++it) { - VALUE_CAST(number, Number, (*it)); - result -= number->number(); + APPLY_NUMBER_OR_DECIMAL(it, { + number -= it_numeric; + decimal -= it_numeric; + }); } - return makePtr(result); + RETURN_NUMBER_OR_DECIMAL(); }); ADD_FUNCTION( @@ -64,14 +104,19 @@ subtracts all but the first from the first.)", "", "", { - int64_t result = 1; + bool return_decimal = false; + + int64_t number = 1; + double decimal = 1; for (auto it = begin; it != end; ++it) { - VALUE_CAST(number, Number, (*it)); - result *= number->number(); + APPLY_NUMBER_OR_DECIMAL(it, { + number *= it_numeric; + decimal *= it_numeric; + }); } - return makePtr(result); + RETURN_NUMBER_OR_DECIMAL(); }); ADD_FUNCTION( @@ -81,17 +126,26 @@ subtracts all but the first from the first.)", { CHECK_ARG_COUNT_AT_LEAST("/", SIZE(), 1); + bool return_decimal = false; + + int64_t number = 0; + double decimal = 0; + // Start with the first number - VALUE_CAST(number, Number, (*begin)); - double result = number->number(); + APPLY_NUMBER_OR_DECIMAL(begin, { + number = it_numeric; + decimal = it_numeric; + }); // Skip the first node for (auto it = begin + 1; it != end; ++it) { - VALUE_CAST(number, Number, (*it)); - result /= number->number(); + APPLY_NUMBER_OR_DECIMAL(it, { + number /= it_numeric; + decimal /= it_numeric; + }); } - return makePtr((int64_t)result); + RETURN_NUMBER_OR_DECIMAL(); }); // (% 5 2) -> 1 diff --git a/src/printer.cpp b/src/printer.cpp index 130a395..e37c597 100644 --- a/src/printer.cpp +++ b/src/printer.cpp @@ -128,6 +128,10 @@ void Printer::printImpl(ValuePtr value, bool print_readably) printSpacing(); m_print += ::format("{}", std::static_pointer_cast(value)->number()); } + else if (is(value_raw_ptr)) { + printSpacing(); + m_print += ::format("{:.15}", std::static_pointer_cast(value)->decimal()); + } else if (is(value_raw_ptr)) { printSpacing(); std::string constant; diff --git a/src/reader.cpp b/src/reader.cpp index fdc2bec..caf465f 100644 --- a/src/reader.cpp +++ b/src/reader.cpp @@ -4,11 +4,13 @@ * SPDX-License-Identifier: MIT */ -#include // size_t -#include // uint64_t -#include // std::strtoll -#include // std::static_pointer_cast -#include // std::move +#include // std::from_chars +#include // size_t +#include // uint64_t +#include // std::strtoll +#include // std::static_pointer_cast +#include // std::errc +#include // std::move #include "error.h" #include "ruc/format/color.h" @@ -306,24 +308,33 @@ ValuePtr Reader::readKeyword() ValuePtr Reader::readValue() { - Token token = consume(); - char* end_ptr = nullptr; - int64_t result = std::strtoll(token.symbol.c_str(), &end_ptr, 10); - if (end_ptr == token.symbol.c_str() + token.symbol.size()) { - return makePtr(result); + auto symbol = consume().symbol; + + int64_t number; + auto [_, error] = std::from_chars(symbol.data(), symbol.data() + symbol.size(), number); + if (error == std::errc() && symbol.find('.') == std::string::npos) { + return makePtr(number); + } + + double decimal; + { + auto [_, error] = std::from_chars(symbol.data(), symbol.data() + symbol.size(), decimal); + if (error == std::errc()) { + return makePtr(decimal); + } } - if (token.symbol == "nil") { + if (symbol == "nil") { return makePtr(Constant::Nil); } - else if (token.symbol == "true") { + else if (symbol == "true") { return makePtr(Constant::True); } - else if (token.symbol == "false") { + else if (symbol == "false") { return makePtr(Constant::False); } - return makePtr(token.symbol); + return makePtr(symbol); } // ----------------------------------------- @@ -431,6 +442,10 @@ void Reader::dumpImpl(ValuePtr node) pretty_print ? print(yellow, "NumberNode") : print("NumberNode"); print(" <{}>", node); } + else if (is(node_raw_ptr)) { + pretty_print ? print(yellow, "DecimalNode") : print("DecimalNode"); + print(" <{}>", node); + } else if (is(node_raw_ptr)) { pretty_print ? print(yellow, "ValueNode") : print("ValueNode"); print(" <{}>", node); diff --git a/vendor/ruc b/vendor/ruc index 07c9f99..c8e4ae8 160000 --- a/vendor/ruc +++ b/vendor/ruc @@ -1 +1 @@ -Subproject commit 07c9f9959d3ce46da8bc7b0777d803524f9d1ec0 +Subproject commit c8e4ae884eacc963dc0fe6a5254679aad050cb74