From ff258e3df32354ba76f5b57d9ae0e9c520e556b6 Mon Sep 17 00:00:00 2001 From: David Gonzalez Martin Date: Sat, 6 Jul 2024 17:41:42 +0200 Subject: [PATCH] Implement comparisons --- bootstrap/main.cpp | 321 ++++++++++++++++++++++++++++---------- tests/comparison/main.nat | 10 ++ 2 files changed, 248 insertions(+), 83 deletions(-) create mode 100644 tests/comparison/main.nat diff --git a/bootstrap/main.cpp b/bootstrap/main.cpp index a70fbf0..5ceb0eb 100644 --- a/bootstrap/main.cpp +++ b/bootstrap/main.cpp @@ -997,7 +997,7 @@ struct Arena u64 commited; u64 commit_position; u64 granularity; - u8 reserved[4 * 8]; + u8 reserved[4 * 8] = {}; global auto constexpr minimum_granularity = KB(4); global auto constexpr middle_granularity = MB(2); @@ -1472,15 +1472,13 @@ struct NodeType struct { u64 constant; - u8 bit_count; u8 is_constant; } integer; struct { Slice types; } multi; - }; - + } payload = {}; u8 is_simple() { @@ -1506,6 +1504,8 @@ struct NodeType switch (id) { + case NodeType::Id::INTEGER: + return (payload.integer.is_constant == other.payload.integer.is_constant) & (payload.integer.constant == other.payload.integer.constant); default: trap(); } @@ -1518,7 +1518,7 @@ struct NodeType case Id::VOID: trap(); case Id::INTEGER: - return integer.is_constant; + return payload.integer.is_constant; case Id::CONTROL: case Id::MULTIVALUE: case Id::BOTTOM: @@ -1572,7 +1572,7 @@ struct NodeType } assert(is_constant() & other.is_constant()); - if (integer.constant == other.integer.constant) + if (payload.integer.constant == other.payload.integer.constant) { trap(); } @@ -1589,37 +1589,43 @@ struct NodeType u8 is_bot() { assert(id == Id::INTEGER); - return !integer.is_constant & (integer.constant == 1); + return !payload.integer.is_constant & (payload.integer.constant == 1); } u8 is_top() { assert(id == Id::INTEGER); - return !integer.is_constant & (integer.constant == 0); + return !payload.integer.is_constant & (payload.integer.constant == 0); } }; may_be_unused global auto constexpr integer_top = NodeType{ .id = NodeType::Id::TOP, - .integer = { - .constant = 0, - .is_constant = 0, + .payload = { + .integer = { + .constant = 0, + .is_constant = 0, + }, }, }; may_be_unused global auto constexpr integer_bot = NodeType{ .id = NodeType::Id::TOP, - .integer = { - .constant = 1, - .is_constant = 0, + .payload = { + .integer = { + .constant = 1, + .is_constant = 0, + }, }, }; may_be_unused global auto constexpr integer_zero = NodeType{ .id = NodeType::Id::TOP, - .integer = { - .constant = 0, - .is_constant = 1, + .payload = { + .integer = { + .constant = 0, + .is_constant = 1, + }, }, }; @@ -1630,7 +1636,7 @@ struct SemaType SemaTypeId id : type_id_bit_count; u32 resolved: 1; u32 flags: type_flags_bit_count; - u32 reserved; + u32 reserved = 0; String name; u8 get_bit_count() @@ -1656,9 +1662,12 @@ struct SemaType case SemaTypeId::INTEGER: return NodeType{ .id = NodeType::Id::INTEGER, - .integer = { - .bit_count = get_bit_count(), - .is_constant = 0, + .payload = { + .integer = { + .constant = 0, + // .bit_count = get_bit_count(), + .is_constant = 0, + }, }, }; case SemaTypeId::ARRAY: @@ -1726,8 +1735,8 @@ struct Function; struct Thread { Arena* arena; - PinnedArray functions; - u32 node_count; + PinnedArray functions = {}; + u32 node_count = 0; }; struct Unit @@ -1776,8 +1785,8 @@ typedef struct AbiInfoAttributes AbiInfoAttributes; struct AbiInfo { AbiInfoPayload payload; - u16 indices[2]; - AbiInfoAttributes attributes; + u16 indices[2] = {}; + AbiInfoAttributes attributes = {}; AbiInfoKind kind; }; @@ -1811,7 +1820,6 @@ struct ConstantIntData { u64 value; Node* input; - u32 gvn; u8 bit_count; }; @@ -1827,8 +1835,15 @@ struct Node PROJECTION, RETURN, CONSTANT_INT, - INT_ADD, - INT_SUB, + INTEGER_ADD, + INTEGER_SUB, + + INTEGER_COMPARE_EQUAL, + INTEGER_COMPARE_NOT_EQUAL, + INTEGER_COMPARE_LESS, + INTEGER_COMPARE_LESS_EQUAL, + INTEGER_COMPARE_GREATER, + INTEGER_COMPARE_GREATER_EQUAL, SCOPE, SYMBOL_FUNCTION, CALL, @@ -1858,9 +1873,9 @@ struct Node Type args; } root; Symbol* symbol; - }; + } payload; - u8 padding[40]; + u8 padding[40] = {}; forceinline Slice get_inputs() { @@ -1896,6 +1911,7 @@ struct Node .outputs = {}, .gvn = gvn, .id = data.id, + .payload = {}, }; node->inputs.append(data.inputs); @@ -1970,14 +1986,21 @@ struct Node case Id::PROJECTION: case Id::CONSTANT_INT: break; - case Id::INT_ADD: - case Id::INT_SUB: + case Id::INTEGER_ADD: + case Id::INTEGER_SUB: trap(); case Id::SCOPE: trap(); case Id::SYMBOL_FUNCTION: case Id::CALL: trap(); + case Id::INTEGER_COMPARE_EQUAL: + case Id::INTEGER_COMPARE_NOT_EQUAL: + case Id::INTEGER_COMPARE_LESS: + case Id::INTEGER_COMPARE_LESS_EQUAL: + case Id::INTEGER_COMPARE_GREATER: + case Id::INTEGER_COMPARE_GREATER_EQUAL: + trap(); } return is_good_id | is_projection() | cfg_is_control_projection(); @@ -2022,7 +2045,7 @@ struct Node { switch (id) { - case Id::INT_SUB: + case Id::INTEGER_SUB: if (inputs[1] == inputs[2]) { trap(); @@ -2035,7 +2058,7 @@ struct Node case Id::PROJECTION: case Id::RETURN: case Id::CONSTANT_INT: - case Id::INT_ADD: + case Id::INTEGER_ADD: return 0; case Id::SCOPE: trap(); @@ -2043,6 +2066,13 @@ struct Node case Id::SYMBOL_FUNCTION: case Id::CALL: return 0; + case Id::INTEGER_COMPARE_EQUAL: + case Id::INTEGER_COMPARE_NOT_EQUAL: + case Id::INTEGER_COMPARE_LESS: + case Id::INTEGER_COMPARE_LESS_EQUAL: + case Id::INTEGER_COMPARE_GREATER: + case Id::INTEGER_COMPARE_GREATER_EQUAL: + trap(); } } @@ -2140,9 +2170,15 @@ struct Node switch (id) { case Node::Id::ROOT: - return root.args; - case Node::Id::INT_ADD: - case Node::Id::INT_SUB: + return payload.root.args; + case Node::Id::INTEGER_ADD: + case Node::Id::INTEGER_SUB: + case Node::Id::INTEGER_COMPARE_EQUAL: + case Node::Id::INTEGER_COMPARE_NOT_EQUAL: + case Node::Id::INTEGER_COMPARE_LESS: + case Node::Id::INTEGER_COMPARE_LESS_EQUAL: + case Node::Id::INTEGER_COMPARE_GREATER: + case Node::Id::INTEGER_COMPARE_GREATER_EQUAL: { auto left_type = inputs[1]->type; auto right_type = inputs[2]->type; @@ -2161,20 +2197,40 @@ struct Node case Id::SYMBOL_FUNCTION: case Id::CALL: trap(); - case Id::INT_ADD: - result = left_type.integer.constant + right_type.integer.constant; + case Id::INTEGER_ADD: + result = left_type.payload.integer.constant + right_type.payload.integer.constant; break; - case Id::INT_SUB: - result = left_type.integer.constant - right_type.integer.constant; + case Id::INTEGER_SUB: + result = left_type.payload.integer.constant - right_type.payload.integer.constant; + break; + case Id::INTEGER_COMPARE_EQUAL: + result = left_type.payload.integer.constant == right_type.payload.integer.constant; + break; + case Id::INTEGER_COMPARE_NOT_EQUAL: + result = left_type.payload.integer.constant != right_type.payload.integer.constant; + break; + case Id::INTEGER_COMPARE_LESS: + result = left_type.payload.integer.constant < right_type.payload.integer.constant; + break; + case Id::INTEGER_COMPARE_LESS_EQUAL: + result = left_type.payload.integer.constant <= right_type.payload.integer.constant; + break; + case Id::INTEGER_COMPARE_GREATER: + result = left_type.payload.integer.constant > right_type.payload.integer.constant; + break; + case Id::INTEGER_COMPARE_GREATER_EQUAL: + result = left_type.payload.integer.constant >= right_type.payload.integer.constant; break; } return Node::Type{ .id = Node::Type::Id::INTEGER, - .integer = { - .constant = result, - .bit_count = left_type.integer.bit_count, - .is_constant = 1, + .payload = { + .integer = { + .constant = result, + // .bit_count = left_type.payload.integer.bit_count, + .is_constant = 1, + }, }, }; } @@ -2195,7 +2251,7 @@ struct Node auto* control_node = inputs[0]; if (control_node->type.id == NodeType::Id::MULTIVALUE) { - auto type = control_node->type.multi.types[this->projection.index]; + auto type = control_node->type.payload.multi.types[this->payload.projection.index]; return type; } else @@ -2217,9 +2273,11 @@ struct Node types.append_one(inputs[1]->type); return Type{ .id = Node::Type::Id::MULTIVALUE, + .payload = { .multi = { .types = types.slice(), }, + }, }; } default: @@ -2235,8 +2293,8 @@ struct Node .inputs = { .pointer = &function->root_node, .length = 1 }, .id = Node::Id::PROJECTION, }); - projection->projection.index = index; - projection->projection.name = label; + projection->payload.projection.index = index; + projection->payload.projection.name = label; return projection; } @@ -2267,11 +2325,12 @@ static_assert(page_size % sizeof(Node) == 0); .type = { .id = Node::Type::Id::INTEGER, - .integer = - { - .constant = data.value, - .bit_count = data.bit_count, - .is_constant = 1, + .payload = { + .integer = { + .constant = data.value, + // .bit_count = data.bit_count, + .is_constant = 1, + }, }, }, .inputs = { .pointer = &data.input, .length = 1 }, @@ -2491,6 +2550,7 @@ fn void unit_initialize(Unit* unit) // .node_arena = Arena::init(Arena::default_size, Arena::minimum_granularity, KB(64)), // .type_arena = type_arena, .builtin_types = builtin_types, + .generate_debug_information = 1, }; builtin_types[void_type_index] = { @@ -2498,6 +2558,7 @@ fn void unit_initialize(Unit* unit) .alignment = 1, .id = SemaTypeId::VOID, .resolved = 1, + .flags = 0, .name = strlit("void"), }; builtin_types[noreturn_type_index] = { @@ -2505,6 +2566,7 @@ fn void unit_initialize(Unit* unit) .alignment = 1, .id = SemaTypeId::NORETURN, .resolved = 1, + .flags = 0, .name = strlit("noreturn"), }; builtin_types[opaque_pointer_type_index] = { @@ -2512,6 +2574,7 @@ fn void unit_initialize(Unit* unit) .alignment = 8, .id = SemaTypeId::POINTER, .resolved = 1, + .flags = 0, .name = strlit("*any"), }; // TODO: float types @@ -2747,7 +2810,7 @@ struct File String path; String source_code; FileStatus status; - Hashmap symbols; + Hashmap symbols = {}; }; fn File* add_file(Arena* arena, String file_path) @@ -2755,6 +2818,8 @@ fn File* add_file(Arena* arena, String file_path) auto* file = arena->allocate_one(); *file = { .path = file_path, + .source_code = {}, + .status = FILE_STATUS_ADDED, }; return file; } @@ -3077,7 +3142,7 @@ fn Node* scope_update_extended(Node* scope, String name, Node* node, s32 nesting } // TODO: avoid recursion - auto& map = scope->scope.stack[nesting_level]; + auto& map = scope->payload.scope.stack[nesting_level]; if (auto index = map.get(name)) { auto* old = scope->get_inputs()[*index]; @@ -3103,7 +3168,7 @@ fn Node* scope_update_extended(Node* scope, String name, Node* node, s32 nesting fn Node* scope_lookup(Analyzer* analyzer, String name) { - if (auto* node = scope_update_extended(analyzer->scope, name, nullptr, analyzer->scope->scope.stack.length - 1)) + if (auto* node = scope_update_extended(analyzer->scope, name, nullptr, analyzer->scope->payload.scope.stack.length - 1)) { return node; } @@ -3238,6 +3303,7 @@ fn Node* scope_lookup(Analyzer* analyzer, String name) argument_nodes.append_one(node); Node* call_node = Node::add(thread, { + .type = {}, .inputs = argument_nodes.slice(), .id = Node::Id::CALL, })->peephole(thread, function); @@ -3258,10 +3324,17 @@ fn Node* scope_lookup(Analyzer* analyzer, String name) enum class CurrentOperation { NONE, - ADD, - ADD_ASSIGN, - SUB, - SUB_ASSIGN, + ASSIGN, + INTEGER_ADD, + INTEGER_ADD_ASSIGN, + INTEGER_SUB, + INTEGER_SUB_ASSIGN, + INTEGER_COMPARE_EQUAL, + INTEGER_COMPARE_NOT_EQUAL, + INTEGER_COMPARE_LESS, + INTEGER_COMPARE_LESS_EQUAL, + INTEGER_COMPARE_GREATER, + INTEGER_COMPARE_GREATER_EQUAL, }; u64 iterations = 0; @@ -3295,22 +3368,31 @@ fn Node* scope_lookup(Analyzer* analyzer, String name) case CurrentOperation::NONE: previous_node = current_node; break; - case CurrentOperation::ADD: - case CurrentOperation::SUB: + case CurrentOperation::INTEGER_ADD: + case CurrentOperation::INTEGER_SUB: { Node::Id id; switch (current_operation) { case CurrentOperation::NONE: trap(); - case CurrentOperation::ADD: - id = Node::Id::INT_ADD; + case CurrentOperation::INTEGER_ADD: + id = Node::Id::INTEGER_ADD; break; - case CurrentOperation::SUB: - id = Node::Id::INT_SUB; + case CurrentOperation::INTEGER_SUB: + id = Node::Id::INTEGER_SUB; break; - case CurrentOperation::ADD_ASSIGN: - case CurrentOperation::SUB_ASSIGN: + case CurrentOperation::INTEGER_ADD_ASSIGN: + case CurrentOperation::INTEGER_SUB_ASSIGN: + trap(); + case CurrentOperation::INTEGER_COMPARE_EQUAL: + trap(); + case CurrentOperation::ASSIGN: + case CurrentOperation::INTEGER_COMPARE_NOT_EQUAL: + case CurrentOperation::INTEGER_COMPARE_LESS: + case CurrentOperation::INTEGER_COMPARE_LESS_EQUAL: + case CurrentOperation::INTEGER_COMPARE_GREATER: + case CurrentOperation::INTEGER_COMPARE_GREATER_EQUAL: trap(); } @@ -3328,9 +3410,57 @@ fn Node* scope_lookup(Analyzer* analyzer, String name) previous_node = binary; } break; - default: - trap(); - } + case CurrentOperation::INTEGER_COMPARE_EQUAL: + case CurrentOperation::INTEGER_COMPARE_NOT_EQUAL: + case CurrentOperation::INTEGER_COMPARE_LESS: + case CurrentOperation::INTEGER_COMPARE_LESS_EQUAL: + case CurrentOperation::INTEGER_COMPARE_GREATER: + case CurrentOperation::INTEGER_COMPARE_GREATER_EQUAL: + { + Node::Id id; + switch (current_operation) + { + case CurrentOperation::INTEGER_COMPARE_EQUAL: + id = Node::Id::INTEGER_COMPARE_EQUAL; + break; + case CurrentOperation::INTEGER_COMPARE_NOT_EQUAL: + id = Node::Id::INTEGER_COMPARE_NOT_EQUAL; + break; + case CurrentOperation::INTEGER_COMPARE_LESS: + id = Node::Id::INTEGER_COMPARE_LESS; + break; + case CurrentOperation::INTEGER_COMPARE_LESS_EQUAL: + id = Node::Id::INTEGER_COMPARE_LESS_EQUAL; + break; + case CurrentOperation::INTEGER_COMPARE_GREATER: + id = Node::Id::INTEGER_COMPARE_GREATER; + break; + case CurrentOperation::INTEGER_COMPARE_GREATER_EQUAL: + id = Node::Id::INTEGER_COMPARE_GREATER_EQUAL; + break; + default: + trap(); + } + + Node* inputs[] = { + 0, + previous_node, + current_node, + }; + + auto* binary = Node::add(thread, { + .type = current_node->type, + .inputs = { .pointer = inputs, .length = array_length(inputs), }, + .id = id, + }); + + previous_node = binary; + } break; + case CurrentOperation::ASSIGN: + case CurrentOperation::INTEGER_ADD_ASSIGN: + case CurrentOperation::INTEGER_SUB_ASSIGN: + trap(); + } previous_node = previous_node->peephole(thread, analyzer->function); @@ -3345,13 +3475,13 @@ fn Node* scope_lookup(Analyzer* analyzer, String name) case bracket_close: return previous_node; case '+': - current_operation = CurrentOperation::ADD; + current_operation = CurrentOperation::INTEGER_ADD; parser->i += 1; switch (src[parser->i]) { case '=': - current_operation = CurrentOperation::ADD_ASSIGN; + current_operation = CurrentOperation::INTEGER_ADD_ASSIGN; parser->i += 1; break; default: @@ -3359,13 +3489,27 @@ fn Node* scope_lookup(Analyzer* analyzer, String name) } break; case '-': - current_operation = CurrentOperation::SUB; + current_operation = CurrentOperation::INTEGER_SUB; parser->i += 1; switch (src[parser->i]) { case '=': - current_operation = CurrentOperation::SUB_ASSIGN; + current_operation = CurrentOperation::INTEGER_SUB_ASSIGN; + parser->i += 1; + break; + default: + break; + } + break; + case '=': + current_operation = CurrentOperation::ASSIGN; + parser->i += 1; + + switch (src[parser->i]) + { + case '=': + current_operation = CurrentOperation::INTEGER_COMPARE_EQUAL; parser->i += 1; break; default: @@ -3389,17 +3533,17 @@ fn Node* scope_lookup(Analyzer* analyzer, String name) fn void push_scope(Analyzer* analyzer) { - analyzer->scope->scope.stack.append_one({}); + analyzer->scope->payload.scope.stack.append_one({}); } fn void pop_scope(Analyzer* analyzer) { - analyzer->scope->scope.stack.pop(); + analyzer->scope->payload.scope.stack.pop(); } fn Node* define_variable(Analyzer* analyzer, String name, Node* node) { - auto* stack = &analyzer->scope->scope.stack; + auto* stack = &analyzer->scope->payload.scope.stack; assert(stack->length); auto* last = &stack->pointer[stack->length - 1]; @@ -3461,7 +3605,7 @@ fn Node* analyze_local_block(Analyzer* analyzer, Parser* parser, Unit* unit, Thr if (!statement_node) { - auto& list = analyzer->scope->scope.stack; + auto& list = analyzer->scope->payload.scope.stack; u32 i = list.length; u8 found = 0; while (i > 0) @@ -3812,7 +3956,9 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file) .outputs = {}, .gvn = function_gvn, .id = Node::Id::SYMBOL_FUNCTION, - .symbol = &function->symbol, + .payload = { + .symbol = &function->symbol, + }, }); parser->skip_space(src); @@ -4150,12 +4296,20 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file) root_arg_types.append(abi_argument_types.slice()); - Node::Type root_type = { .id = Node::Type::Id::MULTIVALUE, .multi = { .types = root_arg_types.slice(), }, }; + Node::Type root_type = { + .id = Node::Type::Id::MULTIVALUE, + .payload = { + .multi = { + .types = root_arg_types.slice(), + }, + }, + }; function->root_node = Node::add(thread, { .type = root_type, + .inputs = {}, .id = Node::Id::ROOT, }); - function->root_node->root.args = root_type; + function->root_node->payload.root.args = root_type; function->root_node->peephole(thread, function); auto* scope_node = Node::add(thread, { @@ -4163,7 +4317,7 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file) .inputs = { .pointer = &function->root_node, .length = 1 }, .id = Node::Id::SCOPE, }); - scope_node->scope.stack = {}; + scope_node->payload.scope.stack = {}; Analyzer analyzer = { .function = function, .scope = scope_node, @@ -4363,6 +4517,7 @@ String test_file_paths[] = { strlit("tests/constant_prop/main.nat"), strlit("tests/simple_variable_declaration/main.nat"), strlit("tests/function_call_args/main.nat"), + strlit("tests/comparison/main.nat"), }; #ifdef __linux__ diff --git a/tests/comparison/main.nat b/tests/comparison/main.nat new file mode 100644 index 0000000..f1b1793 --- /dev/null +++ b/tests/comparison/main.nat @@ -0,0 +1,10 @@ +fn foo(arg: s32) s32 +{ + return arg == 0; +} + +fn[cc(.c)] main [export] () s32 +{ + >arg: s32 = 0; + return foo(arg); +}