/* * Copyright (C) 2023 Riyyi * * SPDX-License-Identifier: MIT */ #include // std::copy, std::reverse_copy #include // size_t #include // std::static_pointer_cast #include "ast.h" #include "env/environment.h" #include "env/macro.h" #include "forward.h" #include "util.h" namespace blaze { void Environment::loadCollectionModify() { // (apply + 1 2 (list 3 4)) -> (+ 1 2 3 4) -> 10 ADD_FUNCTION( "apply", { CHECK_ARG_COUNT_AT_LEAST("apply", SIZE(), 2); auto callable = *begin; IS_VALUE(Callable, callable); VALUE_CAST(collection, Collection, (*std::prev(end))); auto arguments = ValueVector(begin + 1, end - 1); arguments.reserve(arguments.size() + collection->size()); // Append list nodes to the argument leftovers auto nodes = collection->nodesRead(); for (const auto& node : nodes) { arguments.push_back(node); } ValuePtr value = nullptr; if (is(callable.get())) { auto function = std::static_pointer_cast(callable)->function(); value = function(arguments.begin(), arguments.end()); } else { auto lambda = std::static_pointer_cast(callable); value = eval(lambda->body(), Environment::create(lambda, std::move(arguments))); } return value; }); // (cons 1 (list 2 3)) ADD_FUNCTION( "cons", { CHECK_ARG_COUNT_IS("cons", SIZE(), 2); ValuePtr first = *begin; begin++; VALUE_CAST(collection, Collection, (*begin)); const auto& collection_nodes = collection->nodesRead(); auto result_nodes = ValueVector(collection_nodes.size() + 1); result_nodes.at(0) = first; std::copy(collection_nodes.begin(), collection_nodes.end(), result_nodes.begin() + 1); return makePtr(result_nodes); }); // (concat (list 1) (list 2 3)) -> (1 2 3) ADD_FUNCTION( "concat", { size_t count = 0; for (auto it = begin; it != end; ++it) { VALUE_CAST(collection, Collection, (*it)); count += collection->size(); } auto result_nodes = ValueVector(count); size_t offset = 0; for (auto it = begin; it != end; ++it) { const auto& collection_nodes = std::static_pointer_cast(*it)->nodesRead(); std::copy(collection_nodes.begin(), collection_nodes.end(), result_nodes.begin() + offset); offset += collection_nodes.size(); } return makePtr(result_nodes); }); // (conj '(1 2 3) 4 5 6) -> (6 5 4 1 2 3) // (conj [1 2 3] 4 5 6) -> [1 2 3 4 5 6] ADD_FUNCTION( "conj", { CHECK_ARG_COUNT_AT_LEAST("conj", SIZE(), 1); VALUE_CAST(collection, Collection, (*begin)); begin++; const auto& collection_nodes = collection->nodesRead(); size_t collection_count = collection_nodes.size(); size_t argument_count = SIZE(); auto nodes = ValueVector(argument_count + collection_count); if (is(collection.get())) { std::reverse_copy(begin, end, nodes.begin()); std::copy(collection_nodes.begin(), collection_nodes.end(), nodes.begin() + argument_count); return makePtr(nodes); } std::copy(collection_nodes.begin(), collection_nodes.end(), nodes.begin()); std::copy(begin, end, nodes.begin() + collection_count); return makePtr(nodes); }); // (map (fn* (x) (* x 2)) (list 1 2 3)) -> (2 4 6) ADD_FUNCTION( "map", { CHECK_ARG_COUNT_IS("map", SIZE(), 2); VALUE_CAST(callable, Callable, (*begin)); VALUE_CAST(collection, Collection, (*(begin + 1))); size_t count = collection->size(); auto nodes = ValueVector(count); if (is(callable.get())) { auto function = std::static_pointer_cast(callable)->function(); for (size_t i = 0; i < count; ++i) { nodes.at(i) = function(collection->begin() + i, collection->begin() + i + 1); } } else { auto lambda = std::static_pointer_cast(callable); auto collection_nodes = collection->nodesRead(); for (size_t i = 0; i < count; ++i) { nodes.at(i) = (eval(lambda->body(), Environment::create(lambda, { collection_nodes[i] }))); } } return makePtr(nodes); }); // (set-nth (list 1 2 3) 1 "foo") -> (1 "foo" 3) ADD_FUNCTION( "set-nth", { CHECK_ARG_COUNT_IS("set-nth-element", SIZE(), 3); VALUE_CAST(collection, Collection, (*begin)); VALUE_CAST(number_node, Number, (*(begin + 1))); auto index = static_cast(number_node->number() < 0 ? 0 : number_node->number()); auto value = *(begin + 2); auto collection_nodes = collection->nodesCopy(); if (index >= collection->size()) { // Enlarge list if index out of bounds collection_nodes.resize(index + 1, makePtr()); } collection_nodes[index] = value; if (is(begin->get())) { return makePtr(collection_nodes); } return makePtr(collection_nodes); }); // (seq '(1 2 3)) -> (1 2 3) // (seq [1 2 3]) -> (1 2 3) // (seq "foo") -> ("f" "o" "o") ADD_FUNCTION( "seq", { CHECK_ARG_COUNT_IS("seq", SIZE(), 1); auto front = *begin; Value* front_raw_ptr = front.get(); if (is(front_raw_ptr) && std::static_pointer_cast(front)->state() == Constant::Nil) { return makePtr(); } if (is(front_raw_ptr)) { auto collection = std::static_pointer_cast(front); if (collection->empty()) { return makePtr(); } if (is(front_raw_ptr)) { return front; } return makePtr(collection->nodesCopy()); } if (is(front_raw_ptr)) { auto string = std::static_pointer_cast(front); if (string->empty()) { return makePtr(); } size_t count = string->size(); auto nodes = ValueVector(count); const auto& data = string->data(); for (size_t i = 0; i < count; ++i) { nodes.at(i) = makePtr(data[i]); } return makePtr(nodes); } Error::the().add(::format("wrong argument type: Collection or String, {}", front)); return nullptr; }); // ----------------------------------------- // (assoc {:a 1 :b 2} :a 3 :c 1) -> {:a 3 :b 2 :c 1} ADD_FUNCTION( "assoc", { CHECK_ARG_COUNT_AT_LEAST("assoc", SIZE(), 1); VALUE_CAST(hash_map, HashMap, (*begin)); begin++; CHECK_ARG_COUNT_EVEN("assoc", SIZE()); Elements elements(hash_map->elements()); for (auto it = begin; it != end; std::advance(it, 2)) { const ValuePtr& value = *(std::next(it)); // temporary instance to get around const elements.insert_or_assign(HashMap::getKeyString(*it), value); } return makePtr(elements); }); // (dissoc {:a 1 :b 2 :c 3} :a :c :d) -> {:b 2} ADD_FUNCTION( "dissoc", { CHECK_ARG_COUNT_AT_LEAST("dissoc", SIZE(), 1); VALUE_CAST(hash_map, HashMap, (*begin)); begin++; Elements elements(hash_map->elements()); for (auto it = begin; it != end; ++it) { elements.erase(HashMap::getKeyString(*it)); } return makePtr(elements); }); } } // namespace blaze