diff --git a/bootstrap/main.cpp b/bootstrap/main.cpp index 8bd53d8..24c52d7 100644 --- a/bootstrap/main.cpp +++ b/bootstrap/main.cpp @@ -1953,6 +1953,7 @@ struct ConstantIntData [[nodiscard]] fn Node* add_constant_integer(Thread* thread, ConstantIntData data); +struct File; // This is a node in the "sea of nodes" sense: // https://en.wikipedia.org/wiki/Sea_of_nodes struct Node @@ -2225,7 +2226,35 @@ struct Node { if (inputs[i]->type.id == NodeType::Id::DEAD_CONTROL) { - trap(); + for (u32 output_index = 0; output_index < outputs.length; output_index += 1) + { + Node* output = outputs[output_index]; + if (output->id == Id::PHI) + { + output->remove_input(thread->arena, i); + } + } + + remove_input(thread->arena, i); + + if (inputs.length == 2) + { + for (u32 output_index = 0; output_index < outputs.length; output_index += 1) + { + Node* output = outputs[output_index]; + if (output->id == Id::PHI) + { + // TODO: + trap(); + } + } + + return inputs[1]; + } + else + { + trap(); + } } } } @@ -2521,15 +2550,24 @@ struct Node return { .id = Type::Id::BOTTOM }; case Node::Id::IF: { - if (get_control()->type.id != NodeType::Id::LIVE_CONTROL) + auto* control_node = get_control(); + if (control_node->type.id != NodeType::Id::LIVE_CONTROL && control_node->type.id != Node::Type::Id::BOTTOM) { - trap(); + return type_if_neither; } auto* this_predicate = predicate(); if ((this_predicate->type.id == Node::Type::Id::INTEGER) & this_predicate->type.is_constant()) { - trap(); + auto value = this_predicate->type.payload.constant.constant; + if (value) + { + return type_if_true; + } + else + { + return type_if_false; + } } for (Node* dom = get_immediate_dominator(), *prior = this; dom; prior = dom, dom = dom->get_immediate_dominator()) @@ -2649,6 +2687,15 @@ struct Node }; } case Node::Id::REGION_LOOP: + if (region_in_progress()) + { + return { .id = Type::Id::LIVE_CONTROL }; + } + else + { + auto* entry = loop_entry(); + return entry->type; + } case Node::Id::REGION: if (region_in_progress()) { @@ -2923,35 +2970,6 @@ struct Node return inputs[0]; } - method void scope_end_loop(Thread* thread, Function* function, Node* back, Node* exit) - { - assert(id == Id::SCOPE); - assert(back->id == Id::SCOPE); - assert(exit->id == Id::SCOPE); - - Node* control_node = get_control(); - assert(control_node->id == Id::REGION_LOOP); - assert(control_node->region_in_progress()); - control_node->set_input(thread->arena, 2, back->get_control()); - - for (u32 i = 1; i < inputs.length; i += 1) - { - auto* phi_node = inputs[i]; - assert(phi_node->id == Id::PHI); - assert(phi_node->phi_get_region() == control_node); - assert(!(phi_node->inputs[2])); - phi_node->set_input(thread->arena, 2, back->inputs[2]); - Node* input = phi_node->peephole(thread, function); - - if (input != phi_node) - { - phi_node->subsume(thread->arena, input); - } - } - - back->kill(thread->arena); - } - method void subsume(Arena* arena, Node* node) { assert(node != this); @@ -2967,8 +2985,204 @@ struct Node kill(arena); } + + method Slice scope_reverse_names(Arena* arena) + { + assert(id == Node::Id::SCOPE); + Slice names = arena->allocate_slice(inputs.length); + + for (auto& hashmap : payload.scope.stack.slice()) + { + for (String name : hashmap.key_slice()) + { + auto index = *hashmap.get(name); + names[index] = name; + } + } + + return names; + } + + method Node* scope_update_extended(Thread* thread, Function* function, String name, Node* node, s32 nesting_level) + { + assert(id == Id::SCOPE); + if (nesting_level < 0) + { + return 0; + } + + // TODO: avoid recursion + auto& map = payload.scope.stack[nesting_level]; + if (auto* index_ptr = map.get(name)) + { + auto index = *index_ptr; + auto* old = get_inputs()[index]; + + if (old->id == Node::Id::SCOPE) + { + auto* loop = old; + if (loop->inputs[index]->id == Id::PHI && loop->get_control() == loop->inputs[index]->phi_get_region()) + { + old = loop->inputs[index]; + } + else + { + + Node* phi_inputs[] = { + loop->get_control(), + loop->scope_update_extended(thread, function, name, 0, nesting_level), + 0, + }; + auto* phi_node = Node::add(thread, { + .type = {}, + .inputs = array_to_slice(phi_inputs), + .id = Node::Id::PHI, + }); + phi_node->payload.phi.label = name; + phi_node = phi_node->peephole(thread, function); + old = loop->set_input(thread->arena, index, phi_node); + } + + set_input(thread->arena, index, old); + } + + if (node) + { + return set_input(thread->arena, index, node); + } + else + { + return old; + } + } + else + { + return scope_update_extended(thread, function, name, node, nesting_level - 1); + } + } + + method Node* scope_update(Thread* thread, Function* function, String name, Node* node) + { + assert(id == Id::SCOPE); + return scope_update_extended(thread, function, name, node, payload.scope.stack.length - 1); + } + + method void scope_end_loop(Thread* thread, Function* function, Node* back, Node* exit) + { + assert(id == Id::SCOPE); + assert(back->id == Id::SCOPE); + assert(exit->id == Id::SCOPE); + + Node* control_node = get_control(); + assert(control_node->id == Id::REGION_LOOP); + assert(control_node->region_in_progress()); + control_node->set_input(thread->arena, 2, back->get_control()); + for (u32 i = 1; i < inputs.length; i += 1) + { + if (back->inputs[i] != this) + { + auto* phi = inputs[i]; + assert(phi->id == Id::PHI); + assert(phi->phi_get_region() == get_control()); + assert(!phi->inputs[2]); + phi->set_input(thread->arena, 2, back->inputs[i]); + } + + if (exit->inputs[i] == this) + { + exit->set_input(thread->arena, i, inputs[i]); + } + } + + back->kill(thread->arena); + + for (u32 i = 1; i < inputs.length; i += 1) + { + auto* node = inputs[i]; + if (node->id == Id::PHI) + { + Node* input = node->peephole(thread, function); + if (input != node) + { + node->subsume(thread->arena, input); + set_input(thread->arena, i, input); + } + } + } + } + + method Node* scope_lookup(Thread* thread, Function* function, File* file, String name); + method Node* merge_scopes(Thread* thread, File* file, Function* function, Node* other); }; +struct File +{ + String path; + String source_code; + FileStatus status; + Hashmap symbols = {}; +}; + +method Node* Node::scope_lookup(Thread* thread, Function* function, File* file, String name) +{ + auto* result = scope_update_extended(thread, function, name, nullptr, payload.scope.stack.length - 1); + if (file && !result) + { + result = file->symbols.get(name); + } + + return result; +} + +method Node* Node::merge_scopes(Thread* thread, File* file, Function* function, Node* other_scope) +{ + assert(id == Node::Id::SCOPE); + assert(other_scope->id == Node::Id::SCOPE); + + Node* region_inputs[] = { + 0, + get_control(), + other_scope->get_control(), + }; + + auto* region_node = set_control(thread->arena, Node::add(thread, { + .type = {}, + .inputs = array_to_slice(region_inputs), + .id = Node::Id::REGION, + })->keep()); + auto names = scope_reverse_names(thread->arena); + + // Skip input[0] ($ctrl) + for (u32 i = 1; i < inputs.length; i += 1) + { + if (inputs[i] != other_scope->inputs[i]) + { + String label = names[i]; + Node* input_a = scope_lookup(thread, function, file, label); + Node* input_b = other_scope->scope_lookup(thread, function, file, label); + + Node* inputs[] = { + region_node, + input_a, + input_b, + }; + + auto* phi_node = Node::add(thread, { + .type = {}, + .inputs = array_to_slice(inputs), + .id = Node::Id::PHI, + }); + phi_node->payload.phi.label = label; + phi_node = phi_node->peephole(thread, function); + + set_input(thread->arena, i, phi_node); + } + } + + other_scope->kill(thread->arena); + return region_node->unkeep()->peephole(thread, function); +} + static_assert(sizeof(Node) == 128); static_assert(page_size % sizeof(Node) == 0); @@ -3384,14 +3598,6 @@ struct Parser // { // return parser->i - parser->column + 1; // } -struct File -{ - String path; - String source_code; - FileStatus status; - Hashmap symbols = {}; -}; - fn File* add_file(Arena* arena, String file_path) { auto* file = arena->allocate_one(); @@ -3495,27 +3701,12 @@ Node* create_scope(Thread* thread) return scope; } -Slice scope_reverse_names(Arena* arena, Node* node) -{ - assert(node->id == Node::Id::SCOPE); - Slice names = arena->allocate_slice(node->inputs.length); - - for (auto& hashmap : node->payload.scope.stack.slice()) - { - for (String name : hashmap.key_slice()) - { - auto index = *hashmap.get(name); - names[index] = name; - } - } - - return names; -} - struct Analyzer { Function* function; Node* scope; + Node* break_scope = 0; + Node* continue_scope = 0; File* file; @@ -3559,11 +3750,12 @@ struct Analyzer method Node* duplicate_scope(Thread* thread, u8 loop) { - auto original_input_count = scope->inputs.length; - auto* duplicate = create_scope(thread); + auto* original_scope = scope; + auto original_input_count = original_scope->inputs.length; + auto* duplicate_scope = create_scope(thread); // // TODO: make this more efficient - for (auto& hashmap: scope->payload.scope.stack.slice()) + for (auto& hashmap: original_scope->payload.scope.stack.slice()) { Hashmap duplicate_hashmap = {}; duplicate_hashmap.ensure_capacity(hashmap.length); @@ -3575,88 +3767,19 @@ struct Analyzer duplicate_hashmap.put_assume_not_existing(keys[i], values[i]); } - duplicate->payload.scope.stack.append_one(duplicate_hashmap); + duplicate_scope->payload.scope.stack.append_one(duplicate_hashmap); } - duplicate->add_input(get_control()); + duplicate_scope->add_input(get_control()); - for (u32 i = 1; i < scope->inputs.length; i += 1) + for (u32 i = 1; i < original_scope->inputs.length; i += 1) { - if (loop) - { - auto names = scope_reverse_names(thread->arena, scope); - Node* inputs[] = { - get_control(), - scope->inputs[i], - 0, - }; - String label = names[i]; - auto* phi_node = Node::add(thread, { - .type = {}, - .inputs = array_to_slice(inputs), - .id = Node::Id::PHI, - }); - phi_node->payload.phi.label = label; - phi_node = phi_node->peephole(thread, function); - duplicate->add_input(phi_node); - scope->set_input(thread->arena, i, duplicate->inputs[i]); - } - else - { - duplicate->add_input(scope->inputs[i]); - } + duplicate_scope->add_input(loop ? original_scope : original_scope->inputs[i]); } - assert(duplicate->inputs.length == original_input_count); - return duplicate; + assert(duplicate_scope->inputs.length == original_input_count); + return duplicate_scope; } - method Node* merge_scopes(Thread* thread, Node* scope_a, Node* scope_b) - { - assert(scope_a->id == Node::Id::SCOPE); - assert(scope_b->id == Node::Id::SCOPE); - - Node* inputs[] = { - 0, - scope_a->get_control(), - scope_b->get_control(), - }; - - auto* region_node = set_control(thread->arena, Node::add(thread, { - .type = {}, - .inputs = array_to_slice(inputs), - .id = Node::Id::REGION, - })->keep()); - auto names = scope_reverse_names(thread->arena, scope_a); - - // Skip input[0] ($ctrl) - for (u32 i = 1; i < scope_a->inputs.length; i += 1) - { - Node* input_a = scope_a->inputs[i]; - Node* input_b = scope_b->inputs[i]; - if (input_a != input_b) - { - Node* inputs[] = { - region_node, - input_a, - input_b, - }; - String label = names[i]; - - auto* phi_node = Node::add(thread, { - .type = {}, - .inputs = array_to_slice(inputs), - .id = Node::Id::PHI, - }); - phi_node->payload.phi.label = label; - phi_node = phi_node->peephole(thread, function); - - scope->set_input(thread->arena, i, phi_node); - } - } - - scope_b->kill(thread->arena); - return region_node->unkeep()->peephole(thread, function); - } }; fn SemaType* analyze_type(Parser* parser, Unit* unit, String src) @@ -3874,47 +3997,6 @@ fn u64 parse_hex(String string) return result; } -fn Node* scope_update_extended(Node* scope, Arena* arena, String name, Node* node, s32 nesting_level) -{ - if (nesting_level < 0) - { - return 0; - } - - // TODO: avoid recursion - auto& map = scope->payload.scope.stack[nesting_level]; - if (auto index = map.get(name)) - { - auto* old = scope->get_inputs()[*index]; - if (node) - { - return scope->set_input(arena, *index, node); - } - else - { - return old; - } - } - else - { - return scope_update_extended(scope, arena, name, node, nesting_level - 1); - } -} - -fn Node* scope_update(Analyzer* analyzer, Arena* arena, String name, Node* node) -{ - return scope_update_extended(analyzer->scope, arena, name, node, analyzer->scope->payload.scope.stack.length - 1); -} - -fn Node* scope_lookup(Analyzer* analyzer, Arena* arena, String name) -{ - if (auto* node = scope_update_extended(analyzer->scope, arena, name, nullptr, analyzer->scope->payload.scope.stack.length - 1)) - { - return node; - } - - return analyzer->file->symbols.get(name); -} [[nodiscard]] fn Node* analyze_single_expression(Analyzer* analyzer, Parser* parser, Unit* unit, Thread* thread, String src, SemaType* type, Side side) { @@ -3993,7 +4075,7 @@ fn Node* scope_lookup(Analyzer* analyzer, Arena* arena, String name) else if (is_identifier) { String identifier = parser->parse_and_check_identifier(src); - auto* node = scope_lookup(analyzer, thread->arena, identifier); + auto* node = analyzer->scope->scope_lookup(thread, function, analyzer->file, identifier); if (!node) { fail(); @@ -4370,6 +4452,37 @@ fn Node* define_variable(Analyzer* analyzer, String name, Node* node) fn Node* analyze_local_block(Analyzer* analyzer, Parser* parser, Unit* unit, Thread* thread, String src); +fn Node* jump_to(Analyzer* analyzer, Thread* thread, Node* target_scope) +{ + auto* current_scope = analyzer->duplicate_scope(thread, 0); + // Kill current scope + auto* dead_control = Node::add(thread, { + .type = { .id = Node::Type::Id::DEAD_CONTROL, }, + .inputs = { .pointer = &analyzer->function->root_node, .length = 1 }, + .id = Node::Id::CONSTANT_CONTROL, + })->peephole(thread, analyzer->function); + analyzer->set_control(thread->arena, dead_control); + + while (current_scope->payload.scope.stack.length > analyzer->break_scope->payload.scope.stack.length) + { + current_scope->payload.scope.stack.pop(); + } + + if (target_scope) + { + assert(target_scope->payload.scope.stack.length <= analyzer->break_scope->payload.scope.stack.length); + auto* result = target_scope->merge_scopes(thread, analyzer->file, analyzer->function, current_scope); + unused(result); + // TODO: is this right? + // assert(result == target_scope); + return target_scope; + } + else + { + return current_scope; + } +} + fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Thread* thread, String src) { auto statement_start_index = parser->i; @@ -4463,7 +4576,7 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa analyzer->scope = true_scope; - auto* merged_scope = analyzer->merge_scopes(thread, true_scope, false_scope); + auto* merged_scope = true_scope->merge_scopes(thread, analyzer->file, analyzer->function, false_scope); statement_node = analyzer->set_control(thread->arena, merged_scope); assert(statement_node); } @@ -4473,6 +4586,9 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa parser->expect_character(src, parenthesis_open); + auto* old_break_scope = analyzer->break_scope; + auto* old_continue_scope = analyzer->continue_scope; + Node* loop_inputs[] = { 0, analyzer->get_control(), @@ -4487,8 +4603,7 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa analyzer->set_control(thread->arena, loop_node); Node* head = analyzer->scope->keep(); - auto is_loop = 1; - analyzer->scope = analyzer->duplicate_scope(thread, is_loop); + analyzer->scope = analyzer->duplicate_scope(thread, 1); parser->skip_space(src); @@ -4505,28 +4620,65 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa }; auto* if_node = Node::add(thread, { - .type = {}, - .inputs = array_to_slice(if_inputs), - .id = Node::Id::IF, - })->keep()->peephole(thread, function); + .type = {}, + .inputs = array_to_slice(if_inputs), + .id = Node::Id::IF, + })->keep()->peephole(thread, function); Node* if_true = if_node->project(thread, if_node, 0, strlit("True"))->peephole(thread, function); if_node->unkeep(); Node* if_false = if_node->project(thread, if_node, 1, strlit("False"))->peephole(thread, function); - auto* exit_scope = analyzer->duplicate_scope(thread, 0); - exit_scope->set_control(thread->arena, if_false); + analyzer->set_control(thread->arena, if_false); + analyzer->break_scope = analyzer->duplicate_scope(thread, 0); + analyzer->continue_scope = 0; analyzer->set_control(thread->arena, if_true); - analyze_statement(analyzer, parser, unit, thread, src); + if (analyzer->continue_scope) + { + analyzer->continue_scope = jump_to(analyzer, thread, analyzer->continue_scope); + analyzer->scope->kill(thread->arena); + analyzer->scope = analyzer->continue_scope; + } + + auto* exit_scope = analyzer->break_scope; head->scope_end_loop(thread, function, analyzer->scope, exit_scope); head->unkeep()->kill(thread->arena); + analyzer->break_scope = old_break_scope; + analyzer->continue_scope = old_continue_scope; + analyzer->scope = exit_scope; + statement_node = exit_scope; - assert(statement_node); + } + else if (identifier.equal(strlit("break"))) + { + parser->skip_space(src); + parser->expect_character(src, end_of_statement); + + if (!analyzer->break_scope) + { + fail(); + } + + analyzer->break_scope = jump_to(analyzer, thread, analyzer->break_scope); + statement_node = analyzer->break_scope; + } + else if (identifier.equal(strlit("continue"))) + { + parser->skip_space(src); + parser->expect_character(src, end_of_statement); + + if (!analyzer->break_scope) + { + fail(); + } + + analyzer->continue_scope = jump_to(analyzer, thread, analyzer->continue_scope); + statement_node = analyzer->continue_scope; } if (statement_node) @@ -4535,7 +4687,7 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa } else { - if (auto* left_node = scope_lookup(analyzer, thread->arena, identifier)) + if (auto* left_node = analyzer->scope->scope_lookup(thread, analyzer->function, analyzer->file, identifier)) { parser->skip_space(src); @@ -4564,7 +4716,7 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa switch (operation) { case StatementOperation::ASSIGN: - if (!scope_update(analyzer, thread->arena, identifier, right_expression)) + if (!analyzer->scope->scope_update(thread, function, identifier, right_expression)) { fail(); } @@ -4634,7 +4786,21 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa .type = type, }; } break; - case '=': trap(); + case '=': + { + parser->i += 1; + parser->skip_space(src); + + auto* initial_node = analyze_expression(analyzer, parser, unit, thread, src, 0, Side::right); + if (!define_variable(analyzer, name, initial_node)) + { + fail(); + } + local_result = { + .node = initial_node, + .type = initial_node->get_debug_type(unit), + }; + } break; default: fail(); } @@ -5404,6 +5570,7 @@ global String test_file_paths[] = { strlit("tests/comparison/main.nat"), strlit("tests/if/main.nat"), strlit("tests/while/main.nat"), + strlit("tests/break_continue/main.nat"), }; #ifdef __linux__ diff --git a/tests/break_continue/main.nat b/tests/break_continue/main.nat new file mode 100644 index 0000000..5ce6822 --- /dev/null +++ b/tests/break_continue/main.nat @@ -0,0 +1,166 @@ +fn fn0(arg: s32) s32 +{ + >a = arg; + while (a < 10) + { + a = a + 1; + if (a == 5) + { + break; + } + + if (a == 6) + { + break; + } + } + + return a; +} + +fn fn1(arg: s32) s32 +{ + >a: s32 = 1; + >i = arg; + while (i < 10) + { + i = i + 1; + if (i == 5) + { + continue; + } + + if (i == 7) + { + continue; + } + + a = a + 1; + } + + return a; +} + +fn fn2(arg: s32) s32 +{ + >i = arg; + while (i < 10) + { + i = i + 1; + if (i == 5) + { + continue; + } + + if (i == 6) + { + break; + } + } + + return i; +} + +fn fn3(arg: s32) s32 +{ + >i = arg; + while (i < 10) + { + i = i + 1; + if (i == 6) + { + break; + } + } + + return i; +} + +fn fn4(arg: s32) s32 +{ + >i = arg; + while (i < 10) + { + i = i + 1; + if (i == 5) + { + continue; + } + if (i == 6) + { + continue; + } + } + + return i; +} + +fn fn5(arg: s32) s32 +{ + >i = arg; + while (i < 10) + { + i = i + 1; + if (i == 5) + { + continue; + } + } + + return i; +} + +fn fn6(arg: s32) s32 +{ + >i = arg; + while (i < 10) + { + >a = i + 2; + if (a > 4) + { + break; + } + } + + return i; +} + +fn fn7(arg: s32) s32 +{ + >i = arg; + while (i < 10) + { + break; + } + + return i; +} + +fn fn8(arg: s32) s32 +{ + >a: s32 = 1; + while (1) + { + a = a + 1; + if (a < 10) + { + continue; + } + break; + } + + return a; +} + +fn[cc(.c)] main[export]() s32 +{ + return fn0(0) + + fn1(1) + + fn2(2) + + fn3(3) + + fn4(4) + + fn5(5) + + fn6(6) + + fn7(7) + + fn8(8); +}