Merge pull request #13 from birth-software/break-continue

Implement break and continue
This commit is contained in:
David 2024-07-10 13:42:36 +02:00 committed by GitHub
commit 04964bec81
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 525 additions and 192 deletions

View File

@ -1953,6 +1953,7 @@ struct ConstantIntData
[[nodiscard]] fn Node* add_constant_integer(Thread* thread, ConstantIntData data);
struct File;
// This is a node in the "sea of nodes" sense:
// https://en.wikipedia.org/wiki/Sea_of_nodes
struct Node
@ -2225,7 +2226,35 @@ struct Node
{
if (inputs[i]->type.id == NodeType::Id::DEAD_CONTROL)
{
trap();
for (u32 output_index = 0; output_index < outputs.length; output_index += 1)
{
Node* output = outputs[output_index];
if (output->id == Id::PHI)
{
output->remove_input(thread->arena, i);
}
}
remove_input(thread->arena, i);
if (inputs.length == 2)
{
for (u32 output_index = 0; output_index < outputs.length; output_index += 1)
{
Node* output = outputs[output_index];
if (output->id == Id::PHI)
{
// TODO:
trap();
}
}
return inputs[1];
}
else
{
trap();
}
}
}
}
@ -2521,15 +2550,24 @@ struct Node
return { .id = Type::Id::BOTTOM };
case Node::Id::IF:
{
if (get_control()->type.id != NodeType::Id::LIVE_CONTROL)
auto* control_node = get_control();
if (control_node->type.id != NodeType::Id::LIVE_CONTROL && control_node->type.id != Node::Type::Id::BOTTOM)
{
trap();
return type_if_neither;
}
auto* this_predicate = predicate();
if ((this_predicate->type.id == Node::Type::Id::INTEGER) & this_predicate->type.is_constant())
{
trap();
auto value = this_predicate->type.payload.constant.constant;
if (value)
{
return type_if_true;
}
else
{
return type_if_false;
}
}
for (Node* dom = get_immediate_dominator(), *prior = this; dom; prior = dom, dom = dom->get_immediate_dominator())
@ -2649,6 +2687,15 @@ struct Node
};
}
case Node::Id::REGION_LOOP:
if (region_in_progress())
{
return { .id = Type::Id::LIVE_CONTROL };
}
else
{
auto* entry = loop_entry();
return entry->type;
}
case Node::Id::REGION:
if (region_in_progress())
{
@ -2923,35 +2970,6 @@ struct Node
return inputs[0];
}
method void scope_end_loop(Thread* thread, Function* function, Node* back, Node* exit)
{
assert(id == Id::SCOPE);
assert(back->id == Id::SCOPE);
assert(exit->id == Id::SCOPE);
Node* control_node = get_control();
assert(control_node->id == Id::REGION_LOOP);
assert(control_node->region_in_progress());
control_node->set_input(thread->arena, 2, back->get_control());
for (u32 i = 1; i < inputs.length; i += 1)
{
auto* phi_node = inputs[i];
assert(phi_node->id == Id::PHI);
assert(phi_node->phi_get_region() == control_node);
assert(!(phi_node->inputs[2]));
phi_node->set_input(thread->arena, 2, back->inputs[2]);
Node* input = phi_node->peephole(thread, function);
if (input != phi_node)
{
phi_node->subsume(thread->arena, input);
}
}
back->kill(thread->arena);
}
method void subsume(Arena* arena, Node* node)
{
assert(node != this);
@ -2967,8 +2985,204 @@ struct Node
kill(arena);
}
method Slice<String> scope_reverse_names(Arena* arena)
{
assert(id == Node::Id::SCOPE);
Slice<String> names = arena->allocate_slice<String>(inputs.length);
for (auto& hashmap : payload.scope.stack.slice())
{
for (String name : hashmap.key_slice())
{
auto index = *hashmap.get(name);
names[index] = name;
}
}
return names;
}
method Node* scope_update_extended(Thread* thread, Function* function, String name, Node* node, s32 nesting_level)
{
assert(id == Id::SCOPE);
if (nesting_level < 0)
{
return 0;
}
// TODO: avoid recursion
auto& map = payload.scope.stack[nesting_level];
if (auto* index_ptr = map.get(name))
{
auto index = *index_ptr;
auto* old = get_inputs()[index];
if (old->id == Node::Id::SCOPE)
{
auto* loop = old;
if (loop->inputs[index]->id == Id::PHI && loop->get_control() == loop->inputs[index]->phi_get_region())
{
old = loop->inputs[index];
}
else
{
Node* phi_inputs[] = {
loop->get_control(),
loop->scope_update_extended(thread, function, name, 0, nesting_level),
0,
};
auto* phi_node = Node::add(thread, {
.type = {},
.inputs = array_to_slice(phi_inputs),
.id = Node::Id::PHI,
});
phi_node->payload.phi.label = name;
phi_node = phi_node->peephole(thread, function);
old = loop->set_input(thread->arena, index, phi_node);
}
set_input(thread->arena, index, old);
}
if (node)
{
return set_input(thread->arena, index, node);
}
else
{
return old;
}
}
else
{
return scope_update_extended(thread, function, name, node, nesting_level - 1);
}
}
method Node* scope_update(Thread* thread, Function* function, String name, Node* node)
{
assert(id == Id::SCOPE);
return scope_update_extended(thread, function, name, node, payload.scope.stack.length - 1);
}
method void scope_end_loop(Thread* thread, Function* function, Node* back, Node* exit)
{
assert(id == Id::SCOPE);
assert(back->id == Id::SCOPE);
assert(exit->id == Id::SCOPE);
Node* control_node = get_control();
assert(control_node->id == Id::REGION_LOOP);
assert(control_node->region_in_progress());
control_node->set_input(thread->arena, 2, back->get_control());
for (u32 i = 1; i < inputs.length; i += 1)
{
if (back->inputs[i] != this)
{
auto* phi = inputs[i];
assert(phi->id == Id::PHI);
assert(phi->phi_get_region() == get_control());
assert(!phi->inputs[2]);
phi->set_input(thread->arena, 2, back->inputs[i]);
}
if (exit->inputs[i] == this)
{
exit->set_input(thread->arena, i, inputs[i]);
}
}
back->kill(thread->arena);
for (u32 i = 1; i < inputs.length; i += 1)
{
auto* node = inputs[i];
if (node->id == Id::PHI)
{
Node* input = node->peephole(thread, function);
if (input != node)
{
node->subsume(thread->arena, input);
set_input(thread->arena, i, input);
}
}
}
}
method Node* scope_lookup(Thread* thread, Function* function, File* file, String name);
method Node* merge_scopes(Thread* thread, File* file, Function* function, Node* other);
};
struct File
{
String path;
String source_code;
FileStatus status;
Hashmap<String, Node> symbols = {};
};
method Node* Node::scope_lookup(Thread* thread, Function* function, File* file, String name)
{
auto* result = scope_update_extended(thread, function, name, nullptr, payload.scope.stack.length - 1);
if (file && !result)
{
result = file->symbols.get(name);
}
return result;
}
method Node* Node::merge_scopes(Thread* thread, File* file, Function* function, Node* other_scope)
{
assert(id == Node::Id::SCOPE);
assert(other_scope->id == Node::Id::SCOPE);
Node* region_inputs[] = {
0,
get_control(),
other_scope->get_control(),
};
auto* region_node = set_control(thread->arena, Node::add(thread, {
.type = {},
.inputs = array_to_slice(region_inputs),
.id = Node::Id::REGION,
})->keep());
auto names = scope_reverse_names(thread->arena);
// Skip input[0] ($ctrl)
for (u32 i = 1; i < inputs.length; i += 1)
{
if (inputs[i] != other_scope->inputs[i])
{
String label = names[i];
Node* input_a = scope_lookup(thread, function, file, label);
Node* input_b = other_scope->scope_lookup(thread, function, file, label);
Node* inputs[] = {
region_node,
input_a,
input_b,
};
auto* phi_node = Node::add(thread, {
.type = {},
.inputs = array_to_slice(inputs),
.id = Node::Id::PHI,
});
phi_node->payload.phi.label = label;
phi_node = phi_node->peephole(thread, function);
set_input(thread->arena, i, phi_node);
}
}
other_scope->kill(thread->arena);
return region_node->unkeep()->peephole(thread, function);
}
static_assert(sizeof(Node) == 128);
static_assert(page_size % sizeof(Node) == 0);
@ -3384,14 +3598,6 @@ struct Parser
// {
// return parser->i - parser->column + 1;
// }
struct File
{
String path;
String source_code;
FileStatus status;
Hashmap<String, Node> symbols = {};
};
fn File* add_file(Arena* arena, String file_path)
{
auto* file = arena->allocate_one<File>();
@ -3495,27 +3701,12 @@ Node* create_scope(Thread* thread)
return scope;
}
Slice<String> scope_reverse_names(Arena* arena, Node* node)
{
assert(node->id == Node::Id::SCOPE);
Slice<String> names = arena->allocate_slice<String>(node->inputs.length);
for (auto& hashmap : node->payload.scope.stack.slice())
{
for (String name : hashmap.key_slice())
{
auto index = *hashmap.get(name);
names[index] = name;
}
}
return names;
}
struct Analyzer
{
Function* function;
Node* scope;
Node* break_scope = 0;
Node* continue_scope = 0;
File* file;
@ -3559,11 +3750,12 @@ struct Analyzer
method Node* duplicate_scope(Thread* thread, u8 loop)
{
auto original_input_count = scope->inputs.length;
auto* duplicate = create_scope(thread);
auto* original_scope = scope;
auto original_input_count = original_scope->inputs.length;
auto* duplicate_scope = create_scope(thread);
// // TODO: make this more efficient
for (auto& hashmap: scope->payload.scope.stack.slice())
for (auto& hashmap: original_scope->payload.scope.stack.slice())
{
Hashmap<String, u16> duplicate_hashmap = {};
duplicate_hashmap.ensure_capacity(hashmap.length);
@ -3575,88 +3767,19 @@ struct Analyzer
duplicate_hashmap.put_assume_not_existing(keys[i], values[i]);
}
duplicate->payload.scope.stack.append_one(duplicate_hashmap);
duplicate_scope->payload.scope.stack.append_one(duplicate_hashmap);
}
duplicate->add_input(get_control());
duplicate_scope->add_input(get_control());
for (u32 i = 1; i < scope->inputs.length; i += 1)
for (u32 i = 1; i < original_scope->inputs.length; i += 1)
{
if (loop)
{
auto names = scope_reverse_names(thread->arena, scope);
Node* inputs[] = {
get_control(),
scope->inputs[i],
0,
};
String label = names[i];
auto* phi_node = Node::add(thread, {
.type = {},
.inputs = array_to_slice(inputs),
.id = Node::Id::PHI,
});
phi_node->payload.phi.label = label;
phi_node = phi_node->peephole(thread, function);
duplicate->add_input(phi_node);
scope->set_input(thread->arena, i, duplicate->inputs[i]);
}
else
{
duplicate->add_input(scope->inputs[i]);
}
duplicate_scope->add_input(loop ? original_scope : original_scope->inputs[i]);
}
assert(duplicate->inputs.length == original_input_count);
return duplicate;
assert(duplicate_scope->inputs.length == original_input_count);
return duplicate_scope;
}
method Node* merge_scopes(Thread* thread, Node* scope_a, Node* scope_b)
{
assert(scope_a->id == Node::Id::SCOPE);
assert(scope_b->id == Node::Id::SCOPE);
Node* inputs[] = {
0,
scope_a->get_control(),
scope_b->get_control(),
};
auto* region_node = set_control(thread->arena, Node::add(thread, {
.type = {},
.inputs = array_to_slice(inputs),
.id = Node::Id::REGION,
})->keep());
auto names = scope_reverse_names(thread->arena, scope_a);
// Skip input[0] ($ctrl)
for (u32 i = 1; i < scope_a->inputs.length; i += 1)
{
Node* input_a = scope_a->inputs[i];
Node* input_b = scope_b->inputs[i];
if (input_a != input_b)
{
Node* inputs[] = {
region_node,
input_a,
input_b,
};
String label = names[i];
auto* phi_node = Node::add(thread, {
.type = {},
.inputs = array_to_slice(inputs),
.id = Node::Id::PHI,
});
phi_node->payload.phi.label = label;
phi_node = phi_node->peephole(thread, function);
scope->set_input(thread->arena, i, phi_node);
}
}
scope_b->kill(thread->arena);
return region_node->unkeep()->peephole(thread, function);
}
};
fn SemaType* analyze_type(Parser* parser, Unit* unit, String src)
@ -3874,47 +3997,6 @@ fn u64 parse_hex(String string)
return result;
}
fn Node* scope_update_extended(Node* scope, Arena* arena, String name, Node* node, s32 nesting_level)
{
if (nesting_level < 0)
{
return 0;
}
// TODO: avoid recursion
auto& map = scope->payload.scope.stack[nesting_level];
if (auto index = map.get(name))
{
auto* old = scope->get_inputs()[*index];
if (node)
{
return scope->set_input(arena, *index, node);
}
else
{
return old;
}
}
else
{
return scope_update_extended(scope, arena, name, node, nesting_level - 1);
}
}
fn Node* scope_update(Analyzer* analyzer, Arena* arena, String name, Node* node)
{
return scope_update_extended(analyzer->scope, arena, name, node, analyzer->scope->payload.scope.stack.length - 1);
}
fn Node* scope_lookup(Analyzer* analyzer, Arena* arena, String name)
{
if (auto* node = scope_update_extended(analyzer->scope, arena, name, nullptr, analyzer->scope->payload.scope.stack.length - 1))
{
return node;
}
return analyzer->file->symbols.get(name);
}
[[nodiscard]] fn Node* analyze_single_expression(Analyzer* analyzer, Parser* parser, Unit* unit, Thread* thread, String src, SemaType* type, Side side)
{
@ -3993,7 +4075,7 @@ fn Node* scope_lookup(Analyzer* analyzer, Arena* arena, String name)
else if (is_identifier)
{
String identifier = parser->parse_and_check_identifier(src);
auto* node = scope_lookup(analyzer, thread->arena, identifier);
auto* node = analyzer->scope->scope_lookup(thread, function, analyzer->file, identifier);
if (!node)
{
fail();
@ -4370,6 +4452,37 @@ fn Node* define_variable(Analyzer* analyzer, String name, Node* node)
fn Node* analyze_local_block(Analyzer* analyzer, Parser* parser, Unit* unit, Thread* thread, String src);
fn Node* jump_to(Analyzer* analyzer, Thread* thread, Node* target_scope)
{
auto* current_scope = analyzer->duplicate_scope(thread, 0);
// Kill current scope
auto* dead_control = Node::add(thread, {
.type = { .id = Node::Type::Id::DEAD_CONTROL, },
.inputs = { .pointer = &analyzer->function->root_node, .length = 1 },
.id = Node::Id::CONSTANT_CONTROL,
})->peephole(thread, analyzer->function);
analyzer->set_control(thread->arena, dead_control);
while (current_scope->payload.scope.stack.length > analyzer->break_scope->payload.scope.stack.length)
{
current_scope->payload.scope.stack.pop();
}
if (target_scope)
{
assert(target_scope->payload.scope.stack.length <= analyzer->break_scope->payload.scope.stack.length);
auto* result = target_scope->merge_scopes(thread, analyzer->file, analyzer->function, current_scope);
unused(result);
// TODO: is this right?
// assert(result == target_scope);
return target_scope;
}
else
{
return current_scope;
}
}
fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Thread* thread, String src)
{
auto statement_start_index = parser->i;
@ -4463,7 +4576,7 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa
analyzer->scope = true_scope;
auto* merged_scope = analyzer->merge_scopes(thread, true_scope, false_scope);
auto* merged_scope = true_scope->merge_scopes(thread, analyzer->file, analyzer->function, false_scope);
statement_node = analyzer->set_control(thread->arena, merged_scope);
assert(statement_node);
}
@ -4473,6 +4586,9 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa
parser->expect_character(src, parenthesis_open);
auto* old_break_scope = analyzer->break_scope;
auto* old_continue_scope = analyzer->continue_scope;
Node* loop_inputs[] = {
0,
analyzer->get_control(),
@ -4487,8 +4603,7 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa
analyzer->set_control(thread->arena, loop_node);
Node* head = analyzer->scope->keep();
auto is_loop = 1;
analyzer->scope = analyzer->duplicate_scope(thread, is_loop);
analyzer->scope = analyzer->duplicate_scope(thread, 1);
parser->skip_space(src);
@ -4505,28 +4620,65 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa
};
auto* if_node = Node::add(thread, {
.type = {},
.inputs = array_to_slice(if_inputs),
.id = Node::Id::IF,
})->keep()->peephole(thread, function);
.type = {},
.inputs = array_to_slice(if_inputs),
.id = Node::Id::IF,
})->keep()->peephole(thread, function);
Node* if_true = if_node->project(thread, if_node, 0, strlit("True"))->peephole(thread, function);
if_node->unkeep();
Node* if_false = if_node->project(thread, if_node, 1, strlit("False"))->peephole(thread, function);
auto* exit_scope = analyzer->duplicate_scope(thread, 0);
exit_scope->set_control(thread->arena, if_false);
analyzer->set_control(thread->arena, if_false);
analyzer->break_scope = analyzer->duplicate_scope(thread, 0);
analyzer->continue_scope = 0;
analyzer->set_control(thread->arena, if_true);
analyze_statement(analyzer, parser, unit, thread, src);
if (analyzer->continue_scope)
{
analyzer->continue_scope = jump_to(analyzer, thread, analyzer->continue_scope);
analyzer->scope->kill(thread->arena);
analyzer->scope = analyzer->continue_scope;
}
auto* exit_scope = analyzer->break_scope;
head->scope_end_loop(thread, function, analyzer->scope, exit_scope);
head->unkeep()->kill(thread->arena);
analyzer->break_scope = old_break_scope;
analyzer->continue_scope = old_continue_scope;
analyzer->scope = exit_scope;
statement_node = exit_scope;
assert(statement_node);
}
else if (identifier.equal(strlit("break")))
{
parser->skip_space(src);
parser->expect_character(src, end_of_statement);
if (!analyzer->break_scope)
{
fail();
}
analyzer->break_scope = jump_to(analyzer, thread, analyzer->break_scope);
statement_node = analyzer->break_scope;
}
else if (identifier.equal(strlit("continue")))
{
parser->skip_space(src);
parser->expect_character(src, end_of_statement);
if (!analyzer->break_scope)
{
fail();
}
analyzer->continue_scope = jump_to(analyzer, thread, analyzer->continue_scope);
statement_node = analyzer->continue_scope;
}
if (statement_node)
@ -4535,7 +4687,7 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa
}
else
{
if (auto* left_node = scope_lookup(analyzer, thread->arena, identifier))
if (auto* left_node = analyzer->scope->scope_lookup(thread, analyzer->function, analyzer->file, identifier))
{
parser->skip_space(src);
@ -4564,7 +4716,7 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa
switch (operation)
{
case StatementOperation::ASSIGN:
if (!scope_update(analyzer, thread->arena, identifier, right_expression))
if (!analyzer->scope->scope_update(thread, function, identifier, right_expression))
{
fail();
}
@ -4634,7 +4786,21 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa
.type = type,
};
} break;
case '=': trap();
case '=':
{
parser->i += 1;
parser->skip_space(src);
auto* initial_node = analyze_expression(analyzer, parser, unit, thread, src, 0, Side::right);
if (!define_variable(analyzer, name, initial_node))
{
fail();
}
local_result = {
.node = initial_node,
.type = initial_node->get_debug_type(unit),
};
} break;
default: fail();
}
@ -5404,6 +5570,7 @@ global String test_file_paths[] = {
strlit("tests/comparison/main.nat"),
strlit("tests/if/main.nat"),
strlit("tests/while/main.nat"),
strlit("tests/break_continue/main.nat"),
};
#ifdef __linux__

View File

@ -0,0 +1,166 @@
fn fn0(arg: s32) s32
{
>a = arg;
while (a < 10)
{
a = a + 1;
if (a == 5)
{
break;
}
if (a == 6)
{
break;
}
}
return a;
}
fn fn1(arg: s32) s32
{
>a: s32 = 1;
>i = arg;
while (i < 10)
{
i = i + 1;
if (i == 5)
{
continue;
}
if (i == 7)
{
continue;
}
a = a + 1;
}
return a;
}
fn fn2(arg: s32) s32
{
>i = arg;
while (i < 10)
{
i = i + 1;
if (i == 5)
{
continue;
}
if (i == 6)
{
break;
}
}
return i;
}
fn fn3(arg: s32) s32
{
>i = arg;
while (i < 10)
{
i = i + 1;
if (i == 6)
{
break;
}
}
return i;
}
fn fn4(arg: s32) s32
{
>i = arg;
while (i < 10)
{
i = i + 1;
if (i == 5)
{
continue;
}
if (i == 6)
{
continue;
}
}
return i;
}
fn fn5(arg: s32) s32
{
>i = arg;
while (i < 10)
{
i = i + 1;
if (i == 5)
{
continue;
}
}
return i;
}
fn fn6(arg: s32) s32
{
>i = arg;
while (i < 10)
{
>a = i + 2;
if (a > 4)
{
break;
}
}
return i;
}
fn fn7(arg: s32) s32
{
>i = arg;
while (i < 10)
{
break;
}
return i;
}
fn fn8(arg: s32) s32
{
>a: s32 = 1;
while (1)
{
a = a + 1;
if (a < 10)
{
continue;
}
break;
}
return a;
}
fn[cc(.c)] main[export]() s32
{
return fn0(0) +
fn1(1) +
fn2(2) +
fn3(3) +
fn4(4) +
fn5(5) +
fn6(6) +
fn7(7) +
fn8(8);
}