From 81d304e4faf84de1e45bee6e490ebaf2a5843181 Mon Sep 17 00:00:00 2001 From: David Gonzalez Martin Date: Sun, 7 Jul 2024 08:15:05 +0200 Subject: [PATCH] Implement ifs --- bootstrap/main.cpp | 1207 ++++++++++++++++++++++++++------------------ tests/if/main.nat | 130 +++++ 2 files changed, 839 insertions(+), 498 deletions(-) create mode 100644 tests/if/main.nat diff --git a/bootstrap/main.cpp b/bootstrap/main.cpp index 5ceb0eb..00589ed 100644 --- a/bootstrap/main.cpp +++ b/bootstrap/main.cpp @@ -21,10 +21,12 @@ typedef double f64; typedef u32 Hash; #define fn static +#define method __attribute__((visibility("internal"))) #define global static #define assert(x) if (__builtin_expect(!(x), 0)) { trap(); } #define forceinline __attribute__((always_inline)) #define expect(x, b) __builtin_expect(x, b) +#define breakpoint() __builtin_debugtrap() #define trap() __builtin_trap() #define array_length(arr) sizeof(arr) / sizeof((arr)[0]) #define page_size (0x1000) @@ -101,7 +103,7 @@ struct Slice T* pointer; u64 length; - T& operator[](u64 index) + method T& operator[](u64 index) { assert(index < length); return pointer[index]; @@ -116,7 +118,7 @@ struct Slice }; } - Slice slice(u64 start, u64 end) + method Slice slice(u64 start, u64 end) { return { .pointer = pointer + start, @@ -124,7 +126,7 @@ struct Slice }; } - forceinline u8 equal(Slice other) + method forceinline u8 equal(Slice other) { if (length == other.length) { @@ -136,23 +138,23 @@ struct Slice } } - forceinline T* begin() + method forceinline T* begin() { return pointer; } - forceinline T* end() + method forceinline T* end() { return pointer + length; } - forceinline void copy_in(Slice other) + method forceinline void copy_in(Slice other) { assert(length == other.length); memcpy(pointer, other.pointer, sizeof(T) * other.length); } - T* find(T item) + method T* find(T item) { T* result = 0; @@ -168,12 +170,12 @@ struct Slice return result; } - u32 index(T* item) + method u32 index(T* item) { return item - pointer; } 
- s32 find_index(T item) + method s32 find_index(T item) { if (auto* result = find(item)) { @@ -187,7 +189,7 @@ struct Slice } // Gotta implement this just because C++ - u8 operator==(Slice other) + method u8 operator==(Slice other) { u8 result = 0; if (other.length == length) @@ -231,6 +233,7 @@ forceinline fn T max(T a, T b) using String = Slice; #define strlit(s) String{ .pointer = (u8*)s, .length = sizeof(s) - 1, } #define ch_to_str(ch) String{ .pointer = &ch, .length = 1 } +#define array_to_slice(arr) { .pointer = arr, .length = array_length(arr) } fn u64 parse_decimal(String string) { @@ -956,12 +959,20 @@ may_be_unused fn void print(const char* format, ...) u8 reverse_buffer[64]; u8 reverse_index = 0; u64 value = original_value; - while (value) + if (value) { - u8 decimal_value = (value % 10); - u8 ascii_ch = decimal_value + '0'; - value /= 10; - reverse_buffer[reverse_index] = ascii_ch; + while (value) + { + u8 decimal_value = (value % 10); + u8 ascii_ch = decimal_value + '0'; + value /= 10; + reverse_buffer[reverse_index] = ascii_ch; + reverse_index += 1; + } + } + else + { + reverse_buffer[0] = '0'; reverse_index += 1; } @@ -1018,12 +1029,12 @@ struct Arena return arena; } - fn Arena* init_default(u64 initial_size) + method fn Arena* init_default(u64 initial_size) { return init(default_size, minimum_granularity, initial_size); } - void* allocate_bytes(u64 size, u64 alignment) + method void* allocate_bytes(u64 size, u64 alignment) { u64 aligned_offset = align_forward(commit_position, alignment); u64 aligned_size_after = aligned_offset + size; @@ -1040,19 +1051,19 @@ struct Arena } template - T* allocate_many(u64 count) + method T* allocate_many(u64 count) { return (T*)allocate_bytes(sizeof(T) * count, alignof(T)); } template - T* allocate_one() + method T* allocate_one() { return allocate_many(1); } template - T* allocate_slice(u64 count) + method Slice allocate_slice(u64 count) { return { .pointer = allocate_many(count), @@ -1125,24 +1136,24 @@ struct 
PinnedArray // static_assert(sizeof(T) % granularity == 0); - forceinline T& operator[](u32 index) + method forceinline T& operator[](u32 index) { assert(index < length); return pointer[index]; } - forceinline void ensure_capacity(u32 additional) + method forceinline void ensure_capacity(u32 additional) { auto generic_array = (PinnedArray*)(this); generic_pinned_array_ensure_capacity(generic_array, additional, sizeof(T)); } - forceinline void clear() + method forceinline void clear() { length = 0; } - forceinline Slice add_with_capacity(u32 additional) + method forceinline Slice add_with_capacity(u32 additional) { auto generic_array = (PinnedArray*)(this); auto pointer = generic_pinned_array_add_with_capacity(generic_array, additional, sizeof(T)); @@ -1152,14 +1163,14 @@ struct PinnedArray }; } - forceinline Slice add(u32 additional) + method forceinline Slice add(u32 additional) { ensure_capacity(additional); auto slice = add_with_capacity(additional); return slice; } - forceinline Slice append(Slice items) + method forceinline Slice append(Slice items) { assert(items.length <= 0xffffffff); auto slice = add(items.length); @@ -1167,26 +1178,26 @@ struct PinnedArray return slice; } - forceinline T* add_one() + method forceinline T* add_one() { return add(1).pointer; } - forceinline T* append_one(T item) + method forceinline T* append_one(T item) { T* new_item = add_one(); *new_item = item; return new_item; } - forceinline T pop() + method forceinline T pop() { assert(length); length -= 1; return pointer[length]; } - forceinline Slice slice() + method forceinline Slice slice() { return { .pointer = pointer, @@ -1194,7 +1205,7 @@ struct PinnedArray }; } - T remove_swap(u32 index) + method T remove_swap(u32 index) { if (index >= 0 & index < length) { @@ -1282,7 +1293,7 @@ struct PinnedHashmap static_assert(granularity % sizeof(K) == 0, ""); static_assert(granularity % sizeof(V) == 0, ""); - Slice key_slice() + method forceinline Slice key_slice() { return { .pointer = 
keys, @@ -1290,7 +1301,15 @@ struct PinnedHashmap }; } - V* get(K key) + method forceinline Slice value_slice() + { + return { + .pointer = values, + .length = length, + }; + } + + method V* get(K key) { V* result = 0; @@ -1307,19 +1326,24 @@ struct PinnedHashmap return result; } - forceinline PinnedHashmap* generic() + method forceinline PinnedHashmap* generic() { auto* generic_hashmap = (PinnedHashmap*)(this); return generic_hashmap; } - forceinline GetOrPut get_or_put(K key, V value) + method forceinline void ensure_capacity(u32 additional) + { + generic_pinned_hashmap_ensure_capacity(generic(), sizeof(K), sizeof(K), additional); + } + + method forceinline GetOrPut get_or_put(K key, V value) { auto generic_get_or_put = generic_pinned_hashmap_get_or_put(generic(), (u8*)&key, sizeof(K), (u8*)&value, sizeof(V)); return *(GetOrPut*)&generic_get_or_put; } - forceinline V* put_assume_not_existing(K key, V value) + method forceinline V* put_assume_not_existing(K key, V value) { auto result = generic_pinned_hashmap_put_assume_not_existing(generic(), (u8*)&key, sizeof(K), (u8*)&value, sizeof(V)); return (V*)(result.value); @@ -1480,7 +1504,7 @@ struct NodeType } multi; } payload = {}; - u8 is_simple() + method u8 is_simple() { switch (id) { @@ -1495,7 +1519,7 @@ struct NodeType } } - u8 equal(NodeType other) + method u8 equal(NodeType other) { if (id != other.id) { @@ -1511,7 +1535,7 @@ struct NodeType } } - u8 is_constant() + method u8 is_constant() { switch (id) { @@ -1531,7 +1555,7 @@ struct NodeType } } - NodeType meet(NodeType other) + method NodeType meet(NodeType other) { unused(other); @@ -1586,13 +1610,13 @@ struct NodeType } } - u8 is_bot() + method u8 is_bot() { assert(id == Id::INTEGER); return !payload.integer.is_constant & (payload.integer.constant == 1); } - u8 is_top() + method u8 is_top() { assert(id == Id::INTEGER); return !payload.integer.is_constant & (payload.integer.constant == 0); @@ -1600,7 +1624,7 @@ struct NodeType }; may_be_unused global auto 
constexpr integer_top = NodeType{ - .id = NodeType::Id::TOP, + .id = NodeType::Id::INTEGER, .payload = { .integer = { .constant = 0, @@ -1610,7 +1634,7 @@ may_be_unused global auto constexpr integer_top = NodeType{ }; may_be_unused global auto constexpr integer_bot = NodeType{ - .id = NodeType::Id::TOP, + .id = NodeType::Id::INTEGER, .payload = { .integer = { .constant = 1, @@ -1620,7 +1644,7 @@ may_be_unused global auto constexpr integer_bot = NodeType{ }; may_be_unused global auto constexpr integer_zero = NodeType{ - .id = NodeType::Id::TOP, + .id = NodeType::Id::INTEGER, .payload = { .integer = { .constant = 0, @@ -1629,6 +1653,20 @@ may_be_unused global auto constexpr integer_zero = NodeType{ }, }; +global NodeType type_if_types[2] = { + { .id = NodeType::Id::CONTROL }, + { .id = NodeType::Id::CONTROL }, +}; + +global auto constexpr type_if = NodeType{ + .id = NodeType::Id::MULTIVALUE, + .payload = { + .multi = { + .types = array_to_slice(type_if_types), + }, + }, +}; + struct SemaType { u64 size; @@ -1639,7 +1677,7 @@ struct SemaType u32 reserved = 0; String name; - u8 get_bit_count() + method u8 get_bit_count() { assert(id == SemaTypeId::INTEGER); u32 bit_count_mask = (1 << (type_flags_bit_count - 1)) - 1; @@ -1649,7 +1687,7 @@ struct SemaType return bit_count; } - NodeType lower() + method NodeType lower() { switch (id) { @@ -1660,16 +1698,7 @@ struct SemaType case SemaTypeId::POINTER: trap(); case SemaTypeId::INTEGER: - return NodeType{ - .id = NodeType::Id::INTEGER, - .payload = { - .integer = { - .constant = 0, - // .bit_count = get_bit_count(), - .is_constant = 0, - }, - }, - }; + return integer_bot; case SemaTypeId::ARRAY: trap(); case SemaTypeId::STRUCT: @@ -1750,7 +1779,7 @@ struct Unit SemaType* builtin_types; u64 generate_debug_information : 1; - SemaType* get_integer_type(u8 bit_count, u8 signedness) + method SemaType* get_integer_type(u8 bit_count, u8 signedness) { auto index = integer_type_offset + signedness * 64 + bit_count - 1; return 
&builtin_types[index]; @@ -1810,6 +1839,7 @@ struct Function Symbol symbol; Node* root_node; + Node* stop_node; Node** parameters; Function::Prototype prototype; // u32 node_count; @@ -1832,9 +1862,17 @@ struct Node enum class Id: u8 { ROOT, + STOP, PROJECTION, RETURN, + IF, CONSTANT_INT, + SCOPE, + SYMBOL_FUNCTION, + CALL, + REGION, + PHI, + INTEGER_ADD, INTEGER_SUB, @@ -1844,9 +1882,6 @@ struct Node INTEGER_COMPARE_LESS_EQUAL, INTEGER_COMPARE_GREATER, INTEGER_COMPARE_GREATER_EQUAL, - SCOPE, - SYMBOL_FUNCTION, - CALL, }; using Type = NodeType; @@ -1873,11 +1908,15 @@ struct Node Type args; } root; Symbol* symbol; + struct + { + String label; + } phi; } payload; u8 padding[40] = {}; - forceinline Slice get_inputs() + method forceinline Slice get_inputs() { return { .pointer = inputs.pointer, @@ -1885,7 +1924,7 @@ struct Node }; } - forceinline Slice get_outputs() + method forceinline Slice get_outputs() { return { .pointer = outputs.pointer, @@ -1902,9 +1941,10 @@ struct Node [[nodiscard]] fn Node* add(Thread* thread, NodeData data) { - auto* node = thread->arena->allocate_one(); auto gvn = thread->node_count; thread->node_count += 1; + + auto* node = thread->arena->allocate_one(); *node = { .type = data.type, .inputs = {}, @@ -1927,7 +1967,7 @@ struct Node return node; } - u8 remove_output(Node* output) + method u8 remove_output(Node* output) { s32 index = outputs.slice().find_index(output); assert(index != -1); @@ -1935,13 +1975,13 @@ struct Node return outputs.length == 0; } - Node* add_output(Node* output) + method Node* add_output(Node* output) { outputs.append_one(output); return this; } - Node* add_input(Node* input) + method Node* add_input(Node* input) { inputs.append_one(input); if (input) @@ -1951,7 +1991,7 @@ struct Node return input; } - Node* set_input(Arena* arena, s32 index, Node* input) + method Node* set_input(Arena* arena, s32 index, Node* input) { Node* old_input = inputs[index]; if (old_input == input) @@ -1974,74 +2014,7 @@ struct Node return 
input; } - u8 is_pinned() - { - u8 is_good_id = 0; - switch (id) - { - case Id::ROOT: - case Id::RETURN: - is_good_id = 1; - break; - case Id::PROJECTION: - case Id::CONSTANT_INT: - break; - case Id::INTEGER_ADD: - case Id::INTEGER_SUB: - trap(); - case Id::SCOPE: - trap(); - case Id::SYMBOL_FUNCTION: - case Id::CALL: - trap(); - case Id::INTEGER_COMPARE_EQUAL: - case Id::INTEGER_COMPARE_NOT_EQUAL: - case Id::INTEGER_COMPARE_LESS: - case Id::INTEGER_COMPARE_LESS_EQUAL: - case Id::INTEGER_COMPARE_GREATER: - case Id::INTEGER_COMPARE_GREATER_EQUAL: - trap(); - } - - return is_good_id | is_projection() | cfg_is_control_projection(); - } - - u8 is_projection() - { - switch (id) - { - case Id::PROJECTION: - return 1; - default: - return 0; - } - } - - u8 cfg_is_control_projection() - { - return is_projection() & (type.id == Node::Type::Id::CONTROL); - } - - u8 is_cfg_control() - { - switch (type.id) - { - case Node::Type::Id::CONTROL: - return 1; - case Node::Type::Id::MULTIVALUE: - for (Node* output : get_outputs()) - { - if (output->cfg_is_control_projection()) - { - return 1; - } - } - default: - return 0; - } - } - - Node* idealize() + method Node* idealize(Thread* thread, Function* function) { switch (id) { @@ -2055,14 +2028,17 @@ struct Node return 0; } case Id::ROOT: + case Id::STOP: case Id::PROJECTION: + case Id::IF: case Id::RETURN: case Id::CONSTANT_INT: case Id::INTEGER_ADD: + case Id::REGION: return 0; case Id::SCOPE: trap(); - // TODO: + // TODO: case Id::SYMBOL_FUNCTION: case Id::CALL: return 0; @@ -2072,21 +2048,146 @@ struct Node case Id::INTEGER_COMPARE_LESS_EQUAL: case Id::INTEGER_COMPARE_GREATER: case Id::INTEGER_COMPARE_GREATER_EQUAL: - trap(); + if (inputs[1] == inputs[2]) + { + trap(); + } + else + { + return 0; + } + case Id::PHI: + { + if (phi_same_inputs()) + { + return inputs[1]; + } + else + { + Node* operand = inputs[1]; + + if (operand->inputs.length == 3 && !operand->inputs[0] && !operand->is_cfg() && phi_same_operand()) + { + auto lefts = 
thread->arena->allocate_slice(inputs.length); + auto rights = thread->arena->allocate_slice(inputs.length); + lefts[0] = rights[0] = inputs[0]; + + for (u32 i = 1; i < inputs.length; i += 1) + { + lefts[i] = inputs[i]->inputs[1]; + rights[i] = inputs[i]->inputs[2]; + } + + auto* left_phi = Node::add(thread, { + .type = {}, + .inputs = lefts, + .id = Node::Id::PHI, + }); + left_phi->payload.phi.label = payload.phi.label; + left_phi = left_phi->peephole(thread, function); + + auto* right_phi = Node::add(thread, { + .type = {}, + .inputs = rights, + .id = Node::Id::PHI, + }); + right_phi->payload.phi.label = payload.phi.label; + right_phi = right_phi->peephole(thread, function); + return operand->copy(thread, left_phi, right_phi); + } + else + { + return 0; + } + } + } } } - u8 is_unused() + method Node* copy(Thread* thread, Node* left, Node* right) + { + switch (id) + { + case Id::INTEGER_COMPARE_EQUAL: + case Id::INTEGER_COMPARE_NOT_EQUAL: + case Id::INTEGER_COMPARE_LESS: + case Id::INTEGER_COMPARE_LESS_EQUAL: + case Id::INTEGER_COMPARE_GREATER: + case Id::INTEGER_COMPARE_GREATER_EQUAL: + case Id::INTEGER_ADD: + { + Node* inputs[] = { 0, left, right }; + auto* result = Node::add(thread, { + .type = {}, + .inputs = array_to_slice(inputs), + .id = id, + }); + return result; + } break; + default: + trap(); + } + } + + method u8 is_cfg() + { + switch (id) + { + case Id::ROOT: + case Id::STOP: + case Id::RETURN: + case Id::REGION: + case Id::IF: + return 1; + case Id::PROJECTION: + trap(); + default: + return 0; + } + } + + method u8 phi_same_operand() + { + assert(id == Id::PHI); + auto input_class = inputs[1]->id; + for (u32 i = 2; i < inputs.length; i += 1) + { + auto other_input_class = inputs[i]->id; + if (input_class != other_input_class) + { + return 0; + } + } + + return 1; + } + + method u8 phi_same_inputs() + { + assert(id == Id::PHI); + auto* input = inputs[1]; + for (u32 i = 2; i < inputs.length; i += 1) + { + if (input != inputs[i]) + { + return 0; + } + } + 
+ return 1; + } + + method u8 is_unused() { return outputs.length == 0; } - u8 is_dead() + method u8 is_dead() { return is_unused() & (inputs.length == 0) & (type.id == Node::Type::Id::INVALID); } - void pop_inputs(Arena* arena, u32 count) + method void pop_inputs(Arena* arena, u32 count) { for (u32 i = 0; i < count; i += 1) { @@ -2101,7 +2202,7 @@ struct Node } } - void kill(Arena* arena) + method void kill(Arena* arena) { assert(is_unused()); @@ -2111,8 +2212,9 @@ struct Node assert(is_dead()); } - static auto constexpr enable_peephole = 1; - Node* peephole(Thread* thread, Function* function) + global auto constexpr enable_peephole = 1; + + method Node* peephole(Thread* thread, Function* function) { Node::Type type = this->type = compute(); @@ -2132,10 +2234,10 @@ struct Node return dead_code_elimination(thread->arena, result); } - Node* n = idealize(); - if (n) + Node* idealized = idealize(thread, function); + if (idealized) { - trap(); + return dead_code_elimination(thread->arena, idealized->peephole(thread, function)); } else { @@ -2143,18 +2245,18 @@ struct Node } } - Node* keep() + method Node* keep() { return add_output(0); } - Node* unkeep() + method Node* unkeep() { remove_output(0); return this; } - u8 is_constant() + method u8 is_constant() { switch (id) { @@ -2165,12 +2267,16 @@ struct Node } } - Node::Type compute() + method Node::Type compute() { switch (id) { case Node::Id::ROOT: return payload.root.args; + case Node::Id::STOP: + return { .id = Type::Id::BOTTOM }; + case Node::Id::IF: + return type_if; case Node::Id::INTEGER_ADD: case Node::Id::INTEGER_SUB: case Node::Id::INTEGER_COMPARE_EQUAL: @@ -2189,14 +2295,8 @@ struct Node u64 result; switch (id) { - case Id::ROOT: - case Id::PROJECTION: - case Id::RETURN: - case Id::CONSTANT_INT: - case Id::SCOPE: - case Id::SYMBOL_FUNCTION: - case Id::CALL: - trap(); + default: + trap(); case Id::INTEGER_ADD: result = left_type.payload.integer.constant + right_type.payload.integer.constant; break; @@ 
-2280,17 +2380,21 @@ struct Node }, }; } + case Node::Id::REGION: + return { .id = Type::Id::CONTROL }; + case Node::Id::PHI: + return { .id = Type::Id::BOTTOM }; default: trap(); } } - Node* project(Thread* thread, Function* function, s32 index, String label) + method Node* project(Thread* thread, Node* control, s32 index, String label) { assert(type.id == Node::Type::Id::MULTIVALUE); auto* projection = Node::add(thread, { .type = {}, - .inputs = { .pointer = &function->root_node, .length = 1 }, + .inputs = { .pointer = &control, .length = 1 }, .id = Node::Id::PROJECTION, }); projection->payload.projection.index = index; @@ -2298,7 +2402,7 @@ struct Node return projection; } - Node* dead_code_elimination(Arena* arena, Node* new_node) + method Node* dead_code_elimination(Arena* arena, Node* new_node) { if (new_node != this && is_unused()) { @@ -2310,9 +2414,44 @@ struct Node return new_node; } - Node* control(Arena* arena, Node* node) + method SemaType* get_debug_type(Unit* unit) { - return set_input(arena, 0, node); + switch (type.id) + { + case NodeType::Id::INVALID: + trap(); + case NodeType::Id::BOTTOM: + trap(); + case NodeType::Id::TOP: + trap(); + case NodeType::Id::CONTROL: + trap(); + case NodeType::Id::INTEGER: + return unit->get_integer_type(32, 1); + case NodeType::Id::VOID: + trap(); + case NodeType::Id::MULTIVALUE: + trap(); + case NodeType::Id::MEMORY: + trap(); + case NodeType::Id::POINTER: + trap(); + case NodeType::Id::FUNCTION: + trap(); + case NodeType::Id::CALL: + trap(); + } + } + + method Node* get_control() + { + switch (id) + { + case Node::Id::SCOPE: + return inputs[0]; + default: + trap(); + } } }; @@ -2339,55 +2478,6 @@ static_assert(page_size % sizeof(Node) == 0); return constant_int; } -struct WorkList -{ - using BitsetBackingType = u32; - PinnedArray nodes; - PinnedArray bitset; - - global constexpr auto bit_count = sizeof(BitsetBackingType) * 8; - - void push(Node* node) - { - if (!test_and_set(node)) - { - nodes.append_one(node); - 
} - } - - u8 test_and_set(Node* node) - { - BitsetBackingType gvn_word = node->gvn / bit_count; - if (gvn_word >= bitset.capacity) - { - trap(); - } - BitsetBackingType gvn_mask = 1 << (node->gvn % bit_count); - if (bitset[gvn_word] & gvn_mask) - { - return 1; - } - else - { - bitset[gvn_word] |= gvn_mask; - return 0; - } - } - - void ensure_capacity(u32 capacity) - { - u32 aligned_capacity = align_forward(capacity, bit_count); - nodes.ensure_capacity(aligned_capacity); - auto bitset_length = aligned_capacity / bit_count; - unused(bitset.add(bitset_length)); - } - - void clear() - { - nodes.clear(); - } -}; - fn u64 round_up_to_next_power_of_2(u64 n) { n -= 1; @@ -2401,13 +2491,6 @@ fn u64 round_up_to_next_power_of_2(u64 n) return n; } -// fn Hash intern_identifier(Unit* unit, String identifier) -// { -// Hash hash = hash_bytes(identifier); -// (void)unit->identifiers.get_or_put(hash, identifier); -// return hash; -// } - global String integer_names[] = { strlit("u1"), @@ -2626,7 +2709,6 @@ struct Instance { Arena* arena; }; -typedef struct Instance Instance; fn Unit* instance_add_unit(Instance* instance) { @@ -2652,7 +2734,7 @@ struct Parser u32 line; u32 column; - void skip_space(String src) + method void skip_space(String src) { u64 original_i = i; @@ -2699,7 +2781,8 @@ struct Parser } } } - void expect_character(String src, u8 expected_ch) + + method void expect_character(String src, u8 expected_ch) { u64 index = i; if (expect(index < src.length, 1)) @@ -2727,7 +2810,7 @@ struct Parser } } - String parse_raw_identifier(String src) + method String parse_raw_identifier(String src) { u64 identifier_start_index = i; u64 is_string_literal = src.pointer[identifier_start_index] == '"'; @@ -2765,27 +2848,9 @@ struct Parser } } - typedef enum Keyword : u32 - { - KEYWORD_COUNT, - KEYWORD_INVALID = ~0u, - } Keyword; - - // TODO: - // fn Keyword parse_keyword(String identifier) - // { - // Keyword result = KEYWORD_INVALID; - // return result; - // } - - String 
parse_and_check_identifier(String src) + method String parse_and_check_identifier(String src) { String identifier = parse_raw_identifier(src); - // Keyword keyword_index = parse_keyword(identifier); - // if (expect(keyword_index != KEYWORD_INVALID, 0)) - // { - // fail(); - // } if (expect(identifier.equal(strlit("_")), 0)) { @@ -2831,7 +2896,6 @@ fn void compiler_file_read(Arena* arena, File* file) file->status = FILE_STATUS_READ; } - global constexpr auto pointer_sign = '*'; global constexpr auto end_of_statement = ';'; global constexpr auto end_of_argument = ','; @@ -2902,20 +2966,155 @@ struct GlobalSymbolAttributes u8 exported: 1; u8 external: 1; }; -typedef struct GlobalSymbolAttributes GlobalSymbolAttributes; static_assert(array_length(global_symbol_attributes) == GLOBAL_SYMBOL_ATTRIBUTE_COUNT, ""); +Node* create_scope(Thread* thread) +{ + auto* scope = Node::add(thread, { + .type = { .id = Node::Type::Id::BOTTOM }, + .inputs = {}, + .id = Node::Id::SCOPE, + }); + scope->payload.scope.stack = {}; + + return scope; +} + +Slice scope_reverse_names(Arena* arena, Node* node) +{ + assert(node->id == Node::Id::SCOPE); + Slice names = arena->allocate_slice(node->inputs.length); + + for (auto& hashmap : node->payload.scope.stack.slice()) + { + for (String name : hashmap.key_slice()) + { + auto index = *hashmap.get(name); + names[index] = name; + } + } + + return names; +} + struct Analyzer { Function* function; Node* scope; File* file; - void kill_control(Arena* arena) + + method Node* set_control(Arena* arena, Node* node) { - scope->control(arena, 0); - // scope->scope + return scope->set_input(arena, 0, node); + } + + method void kill_control(Arena* arena) + { + set_control(arena, 0); + } + + method Node* add_return(Thread* thread, Node* return_value) + { + Node* inputs[] = { get_control(), return_value }; + + auto* return_node = Node::add(thread, { + .type = {}, + .inputs = array_to_slice(inputs), + .id = Node::Id::RETURN, + })->peephole(thread, function); + + 
auto* node = function->stop_node->add_input(return_node); + + kill_control(thread->arena); + + return node; + } + + method Node* get_control() + { + return scope->get_control(); + } + + method Node* duplicate_scope(Thread* thread) + { + auto* duplicate = create_scope(thread); + + duplicate->payload.scope.stack.ensure_capacity(scope->payload.scope.stack.capacity); + + // TODO: make this more efficient + for (auto& hashmap: scope->payload.scope.stack.slice()) + { + Hashmap duplicate_hashmap = {}; + duplicate_hashmap.ensure_capacity(hashmap.length); + auto keys = hashmap.key_slice(); + auto values = hashmap.value_slice(); + + for (u32 i = 0; i < hashmap.length; i += 1) + { + duplicate_hashmap.put_assume_not_existing(keys[i], values[i]); + } + + duplicate->payload.scope.stack.append_one(duplicate_hashmap); + } + + duplicate->add_input(get_control()); + + for (u32 i = 1; i < scope->inputs.length; i += 1) + { + duplicate->add_input(scope->inputs[i]); + } + + return duplicate; + } + + method Node* merge_scopes(Thread* thread, Node* scope_a, Node* scope_b) + { + assert(scope_a->id == Node::Id::SCOPE); + assert(scope_b->id == Node::Id::SCOPE); + + Node* inputs[] = { + 0, + scope_a->get_control(), + scope_b->get_control(), + }; + + auto* region_node = set_control(thread->arena, Node::add(thread, { + .type = {}, + .inputs = array_to_slice(inputs), + .id = Node::Id::REGION, + })->peephole(thread, function)); + auto names = scope_reverse_names(thread->arena, scope_a); + + // Skip input[0] ($ctrl) + for (u32 i = 1; i < scope_a->inputs.length; i += 1) + { + Node* input_a = scope_a->inputs[i]; + Node* input_b = scope_b->inputs[i]; + if (input_a != input_b) + { + Node* inputs[] = { + region_node, + input_a, + input_b, + }; + String label = names[i]; + + auto* phi_node = Node::add(thread, { + .type = {}, + .inputs = array_to_slice(inputs), + .id = Node::Id::PHI, + }); + phi_node->payload.phi.label = label; + phi_node = phi_node->peephole(thread, function); + + 
scope->set_input(thread->arena, i, phi_node); + } + } + + scope_b->kill(thread->arena); + return region_node; } }; @@ -3134,7 +3333,7 @@ fn u64 parse_hex(String string) return result; } -fn Node* scope_update_extended(Node* scope, String name, Node* node, s32 nesting_level) +fn Node* scope_update_extended(Node* scope, Arena* arena, String name, Node* node, s32 nesting_level) { if (nesting_level < 0) { @@ -3148,7 +3347,7 @@ fn Node* scope_update_extended(Node* scope, String name, Node* node, s32 nesting auto* old = scope->get_inputs()[*index]; if (node) { - trap(); + return scope->set_input(arena, *index, node); } else { @@ -3157,18 +3356,18 @@ fn Node* scope_update_extended(Node* scope, String name, Node* node, s32 nesting } else { - return scope_update_extended(scope, name, node, nesting_level - 1); + return scope_update_extended(scope, arena, name, node, nesting_level - 1); } } -// fn Node* scope_update(Node* scope, String name, Node* node) -// { -// trap(); -// } - -fn Node* scope_lookup(Analyzer* analyzer, String name) +fn Node* scope_update(Analyzer* analyzer, Arena* arena, String name, Node* node) { - if (auto* node = scope_update_extended(analyzer->scope, name, nullptr, analyzer->scope->payload.scope.stack.length - 1)) + return scope_update_extended(analyzer->scope, arena, name, node, analyzer->scope->payload.scope.stack.length - 1); +} + +fn Node* scope_lookup(Analyzer* analyzer, Arena* arena, String name) +{ + if (auto* node = scope_update_extended(analyzer->scope, arena, name, nullptr, analyzer->scope->payload.scope.stack.length - 1)) { return node; } @@ -3253,7 +3452,7 @@ fn Node* scope_lookup(Analyzer* analyzer, String name) else if (is_identifier) { String identifier = parser->parse_and_check_identifier(src); - auto* node = scope_lookup(analyzer, identifier); + auto* node = scope_lookup(analyzer, thread->arena, identifier); if (!node) { fail(); @@ -3344,9 +3543,9 @@ fn Node* scope_lookup(Analyzer* analyzer, String name) while (1) { - if ((iterations == 
0) & !iteration_type) + if ((iterations == 1) & !iteration_type) { - trap(); + iteration_type = previous_node->get_debug_type(unit); } // u32 line = get_line(parser); @@ -3404,7 +3603,7 @@ fn Node* scope_lookup(Analyzer* analyzer, String name) auto* binary = Node::add(thread, { .type = current_node->type, - .inputs = { .pointer = inputs, .length = array_length(inputs), }, + .inputs = array_to_slice(inputs), .id = id, }); @@ -3460,7 +3659,7 @@ fn Node* scope_lookup(Analyzer* analyzer, String name) case CurrentOperation::INTEGER_ADD_ASSIGN: case CurrentOperation::INTEGER_SUB_ASSIGN: trap(); - } + } previous_node = previous_node->peephole(thread, analyzer->function); @@ -3547,7 +3746,9 @@ fn Node* define_variable(Analyzer* analyzer, String name, Node* node) assert(stack->length); auto* last = &stack->pointer[stack->length - 1]; - if (last->get_or_put(name, analyzer->scope->inputs.length).existing) + auto input_index = analyzer->scope->inputs.length; + + if (last->get_or_put(name, input_index).existing) { trap(); return 0; @@ -3556,14 +3757,232 @@ fn Node* define_variable(Analyzer* analyzer, String name, Node* node) return analyzer->scope->add_input(node); } +fn Node* analyze_local_block(Analyzer* analyzer, Parser* parser, Unit* unit, Thread* thread, String src); + +fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Thread* thread, String src) +{ + auto statement_start_index = parser->i; + u8 statement_start_ch = src[statement_start_index]; + Function* function = analyzer->function; + + if (is_identifier_start(statement_start_ch)) + { + Node* statement_node = 0; + String identifier = parser->parse_raw_identifier(src); + + if (identifier.equal(strlit("return"))) + { + parser->skip_space(src); + + auto* return_value = analyze_expression(analyzer, parser, unit, thread, src, analyzer->function->prototype.original_return_type, Side::right)->peephole(thread, function); + parser->expect_character(src, ';'); + + auto* return_node = 
analyzer->add_return(thread, return_value); + statement_node = return_node; + } + else if (identifier.equal(strlit("if"))) + { + parser->skip_space(src); + + parser->expect_character(src, parenthesis_open); + + parser->skip_space(src); + + auto* predicate_node = analyze_expression(analyzer, parser, unit, thread, src, 0, Side::right); + + parser->skip_space(src); + + parser->expect_character(src, parenthesis_close); + + parser->skip_space(src); + + Node* if_inputs[] = { + analyzer->get_control(), + predicate_node, + }; + + auto* if_node = Node::add(thread, { + .type = {}, + .inputs = array_to_slice(if_inputs), + .id = Node::Id::IF, + })->keep()->peephole(thread, function); + + Node* if_true = if_node->project(thread, if_node, 0, strlit("True"))->peephole(thread, function); + Node* if_false = if_node->project(thread, if_node, 1, strlit("False"))->peephole(thread, function); + + u32 original_input_count = analyzer->scope->inputs.length; + auto* false_scope = analyzer->duplicate_scope(thread); + + analyzer->set_control(thread->arena, if_true); + assert(analyzer->scope->get_control()); + + analyze_statement(analyzer, parser, unit, thread, src); + + auto* true_scope = analyzer->scope; + + analyzer->scope = false_scope; + analyzer->set_control(thread->arena, if_false); + assert(analyzer->scope->get_control()); + + parser->skip_space(src); + + if (is_identifier_start(src[parser->i])) + { + auto before_else = parser->i; + String identifier = parser->parse_raw_identifier(src); + if (identifier.equal(strlit("else"))) + { + parser->skip_space(src); + + analyze_statement(analyzer, parser, unit, thread, src); + } + else + { + parser->i = before_else; + } + } + + if ((true_scope->inputs.length != original_input_count) | (false_scope->inputs.length != original_input_count)) + { + fail(); + } + + analyzer->scope = true_scope; + + auto* merged_scope = analyzer->merge_scopes(thread, true_scope, false_scope); + statement_node = analyzer->set_control(thread->arena, merged_scope); + 
assert(statement_node); + } + + if (statement_node) + { + return statement_node; + } + else + { + if (auto* left_node = scope_lookup(analyzer, thread->arena, identifier)) + { + parser->skip_space(src); + + enum class StatementOperation : u8 + { + ASSIGN, + }; + StatementOperation operation; + switch (src[parser->i]) + { + case '=': + operation = StatementOperation::ASSIGN; + parser->i += 1; + break; + default: + trap(); + } + + parser->skip_space(src); + + Node* right_expression = analyze_expression(analyzer, parser, unit, thread, src, 0, Side::right); + + parser->skip_space(src); + parser->expect_character(src, ';'); + + switch (operation) + { + case StatementOperation::ASSIGN: + if (!scope_update(analyzer, thread->arena, identifier, right_expression)) + { + fail(); + } + break; + default: + trap(); + } + + return 0; + } + else + { + fail(); + } + } + } + else + { + switch (statement_start_ch) + { + case local_symbol_declaration_start: + { + parser->i += 1; + + parser->skip_space(src); + + String name = parser->parse_and_check_identifier(src); + + u8 has_local_attributes = src[parser->i] == symbol_attribute_start; + parser->i += has_local_attributes; + + if (has_local_attributes) + { + // TODO: local attributes + fail(); + } + + parser->skip_space(src); + + struct LocalResult + { + Node* node; + SemaType* type; + }; + + LocalResult local_result = {}; + switch (src[parser->i]) + { + case ':': + { + parser->i += 1; + parser->skip_space(src); + + SemaType* type = analyze_type(parser, unit, src); + + parser->skip_space(src); + parser->expect_character(src, '='); + parser->skip_space(src); + + auto* initial_node = analyze_expression(analyzer, parser, unit, thread, src, type, Side::right); + if (!define_variable(analyzer, name, initial_node)) + { + fail(); + } + local_result = { + .node = initial_node, + .type = type, + }; + } break; + case '=': trap(); + default: fail(); + } + + parser->skip_space(src); + parser->expect_character(src, ';'); + + return 
local_result.node; + } break; + case block_start: + { + return analyze_local_block(analyzer, parser, unit, thread, src); + } break; + default: + trap(); + } + } +} fn Node* analyze_local_block(Analyzer* analyzer, Parser* parser, Unit* unit, Thread* thread, String src) { push_scope(analyzer); parser->expect_character(src, block_start); - Function* function = analyzer->function; - Node* node = 0; while (1) { parser->skip_space(src); @@ -3573,138 +3992,14 @@ fn Node* analyze_local_block(Analyzer* analyzer, Parser* parser, Unit* unit, Thr break; } - auto statement_start_index = parser->i; - u8 statement_start_ch = src[statement_start_index]; - - Node* statement_node = 0; - - if (is_identifier_start(statement_start_ch)) - { - String identifier = parser->parse_raw_identifier(src); - if (identifier.equal(strlit("return"))) - { - parser->skip_space(src); - - auto* return_value = analyze_expression(analyzer, parser, unit, thread, src, analyzer->function->prototype.original_return_type, Side::right)->peephole(thread, function); - parser->expect_character(src, ';'); - - Node* inputs[] = - { - function->root_node, - return_value, - }; - - Node* ret_node = Node::add(thread, { - .type = { .id = Node::Type::Id::CONTROL }, - .inputs = { .pointer = inputs, .length = array_length(inputs) }, - .id = Node::Id::RETURN, - })->peephole(thread, function); - analyzer->kill_control(thread->arena); - statement_node = ret_node; - } - - if (!statement_node) - { - auto& list = analyzer->scope->payload.scope.stack; - u32 i = list.length; - u8 found = 0; - while (i > 0) - { - i -= 1; - - auto& map = list[i]; - if (auto* foo = map.get(identifier)) - { - found = 1; - break; - } - } - assert(found); - trap(); - } - } - else - { - switch (statement_start_ch) - { - case local_symbol_declaration_start: - { - parser->i += 1; - - parser->skip_space(src); - - String name = parser->parse_and_check_identifier(src); - - u8 has_local_attributes = src[parser->i] == symbol_attribute_start; - parser->i += 
has_local_attributes; - - if (has_local_attributes) - { - // TODO: local attributes - fail(); - } - - parser->skip_space(src); - - struct LocalResult - { - Node* node; - SemaType* type; - }; - - LocalResult local_result = {}; - switch (src[parser->i]) - { - case ':': - { - parser->i += 1; - parser->skip_space(src); - - SemaType* type = analyze_type(parser, unit, src); - - parser->skip_space(src); - parser->expect_character(src, '='); - parser->skip_space(src); - - auto* initial_node = analyze_expression(analyzer, parser, unit, thread, src, type, Side::right); - if (!define_variable(analyzer, name, initial_node)) - { - fail(); - } - local_result = { - .node = initial_node, - .type = type, - }; - } break; - case '=': trap(); - default: fail(); - } - - parser->skip_space(src); - parser->expect_character(src, ';'); - - statement_node = local_result.node; - } break; - case block_start: - { - statement_node = analyze_local_block(analyzer, parser, unit, thread, src); - } break; - default: - trap(); - } - } - - if (statement_node) - { - node = statement_node; - } + analyze_statement(analyzer, parser, unit, thread, src); } parser->expect_character(src, block_end); pop_scope(analyzer); - return node; + return 0; } typedef enum SystemVClass @@ -3720,14 +4015,12 @@ struct SystemVClassification { SystemVClass v[2]; }; -typedef struct SystemVClassification SystemVClassification; struct SystemVRegisterCount { u32 gp_registers; u32 sse_registers; }; -typedef struct SystemVRegisterCount SystemVRegisterCount; fn SystemVClassification systemv_classify(SemaType* type, u64 base_offset) { @@ -3837,7 +4130,7 @@ fn SemaType* systemv_get_int_type_at_offset(SemaType* type, u64 offset, SemaType } } -fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file) +fn Node* analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file) { String src = file->source_code; parser->expect_character(src, 'f'); @@ -4241,6 +4534,7 @@ fn void analyze_function(Parser* parser, 
Thread* thread, Unit* unit, File* file) .linkage = symbol_attributes.external ? Symbol::Linkage::external : Symbol::Linkage::internal, }, .root_node = 0, + .stop_node = 0, .parameters = thread->arena->allocate_many(argument_type_abis.length), .prototype = { .argument_type_abis = argument_type_abis.pointer, @@ -4312,26 +4606,29 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file) function->root_node->payload.root.args = root_type; function->root_node->peephole(thread, function); - auto* scope_node = Node::add(thread, { - .type = { .id = Node::Type::Id::BOTTOM }, - .inputs = { .pointer = &function->root_node, .length = 1 }, - .id = Node::Id::SCOPE, + function->stop_node = Node::add(thread, { + .type = {}, + .inputs = {}, + .id = Node::Id::STOP, }); - scope_node->payload.scope.stack = {}; + Analyzer analyzer = { .function = function, - .scope = scope_node, + .scope = create_scope(thread), .file = file, }; push_scope(&analyzer); auto control_name = strlit("$control"); s32 next_index = 0; - Node* control_node = function->root_node->project(thread, function, next_index, control_name)->peephole(thread, function); + Node* control_node = function->root_node->project(thread, function->root_node, next_index, control_name)->peephole(thread, function); next_index += 1; define_variable(&analyzer, control_name, control_node); // assert(abi_argument_type_count == 0); // TODO: reserve memory for them + // argument_names and argument_type_abis must have equal lengths. NOTE(review): confirm whether another length was meant to be checked here as well. + assert(argument_names.length == argument_type_abis.length); + for (u32 i = 0; i < argument_type_abis.length; i += 1) { auto* abi_info = &argument_type_abis[i]; @@ -4344,7 +4641,8 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file) trap(); case ABI_INFO_DIRECT: { - auto* argument_node = function->root_node->project(thread, function, next_index, argument_name)->peephole(thread, function); + auto* argument_node = function->root_node->project(thread, 
function->root_node, next_index, argument_name)->peephole(thread, function); + assert(argument_node->type.id != Node::Type::Id::CONTROL); define_variable(&analyzer, argument_name, argument_node); next_index += 1; } break; @@ -4368,9 +4666,15 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file) analyze_local_block(&analyzer, parser, unit, thread, src); pop_scope(&analyzer); + + function->stop_node->peephole(thread, function); + + return function->stop_node; } break; case 1: trap(); + default: + trap(); } } @@ -4418,106 +4722,13 @@ fn void unit_file_analyze(Thread* thread, Unit* unit, File* file) global Instance instance; -// fn Node* instruction_selection(Node* node) -// { -// switch (node->id) -// { -// case Node::Id::PROJECTION: -// return node; -// case Node::Id::ROOT: -// { -// return node; -// } -// case Node::Id::RETURN: -// trap(); -// case Node::Id::CONSTANT_INT: -// trap(); -// break; -// } -// trap(); -// } - -// fn void function_codegen(Function* function) -// { -// WorkList helper = {}; -// helper.ensure_capacity(function->node_count); -// -// helper.push(function->root_node); -// PinnedArray pins = {}; -// -// u64 i = 0; -// while (i < helper.nodes.length) -// { -// Node* node = helper.nodes[i]; -// i += 1; -// -// if (node->is_pinned() & !node->is_projection()) -// { -// pins.append_one(node); -// } -// -// for (Output& output : node->get_outputs()) -// { -// helper.push(output.node); -// } -// } -// -// helper.clear(); -// -// WorkList walker = {}; -// walker.ensure_capacity(function->node_count); -// -// for (Node* pin_node : pins.slice()) -// { -// walker.push(pin_node); -// -// while (walker.nodes.length > 0) -// { -// Node* node = walker.nodes.pop(); -// -// if (!node->is_projection() & (node->output_count == 0)) -// { -// helper.push(node); -// continue; -// } -// -// if (node->data_type.id == Node::Type::Id::MEMORY) -// { -// trap(); -// } -// -// Node* new_node = instruction_selection(node); -// if (new_node && 
new_node != node) -// { -// trap(); -// } -// -// u16 input_i = node->input_count; -// while (input_i > 0) -// { -// input_i -= 1; -// -// if (node->inputs[input_i]) -// { -// trap(); -// } -// } -// -// // TODO: region -// } -// } -// -// -// -// trap(); -// } - -String test_file_paths[] = { +global String test_file_paths[] = { strlit("tests/first/main.nat"), strlit("tests/constant_prop/main.nat"), strlit("tests/simple_variable_declaration/main.nat"), strlit("tests/function_call_args/main.nat"), strlit("tests/comparison/main.nat"), + strlit("tests/if/main.nat"), }; #ifdef __linux__ diff --git a/tests/if/main.nat b/tests/if/main.nat new file mode 100644 index 0000000..981df95 --- /dev/null +++ b/tests/if/main.nat @@ -0,0 +1,130 @@ +fn if0(arg: s32) s32 +{ + >a: s32 = 1; + if (arg == 1) + { + a = arg + 2; + } + else + { + a = arg - 3; + } + + return a; +} + +fn if1(arg: s32) s32 +{ + >c: s32 = 3; + >b: s32 = 2; + + if (arg == 1) + { + b = 3; + c = 4; + } + + return c; +} + +fn if2(arg: s32) s32 +{ + if (arg == 1) + { + return 3; + } + else + { + return 4; + } +} + +fn if3(arg: s32) s32 +{ + >a: s32 = arg + 1; + >b: s32 = 0; + if (arg == 1) + { + b = a; + } + else + { + b = a + 1; + } + + return a + b; +} + +fn if4(arg: s32) s32 +{ + >a: s32 = arg + 1; + >b: s32 = arg + 2; + if (arg == 1) + { + b = b + a; + } + else + { + a = b + 1; + } + + return a + b; +} + +fn if5(arg: s32) s32 +{ + >a: s32 = 1; + if (arg == 1) + { + if (arg == 2) + { + a = 2; + } + else + { + a = 3; + } + } + else if (arg == 3) + { + a = 4; + } + else + { + a = 5; + } + + return a; +} + +fn if6(arg: s32) s32 +{ + >a: s32 = 0; + >b: s32 = 0; + if (arg) + { + a = 1; + } + if (arg == 0) + { + b = 2; + } + + return arg + a + b; +} + +fn if7(arg: s32) s32 +{ + >a: s32 = arg == 2; + if (arg == 1) + { + a = arg == 3; + } + + return a; +} + +fn[cc(.c)] main[export] () s32 +{ + return if0(3) + if1(1) - 4 + if2(1) - 3 + if3(1) - 4 + if4(0) - 5 + if5(4) - 5 + if6(0) - 2 + if7(0); +}