Implement while

This commit is contained in:
David Gonzalez Martin 2024-07-09 15:54:03 +02:00
parent 2b776bec44
commit 7a86ca2f79
2 changed files with 505 additions and 33 deletions

View File

@ -1209,10 +1209,12 @@ struct PinnedArray
{
if (index >= 0 & index < length)
{
auto original_len = length;
T item = pointer[index];
T last = pointer[length - 1];
pointer[index] = last;
pop();
assert(length == original_len - 1);
return item;
}
@ -1533,6 +1535,7 @@ struct NodeType
return (payload.constant.is_constant == other.payload.constant.is_constant) & (payload.constant.constant == other.payload.constant.constant);
case NodeType::Id::LIVE_CONTROL:
case NodeType::Id::DEAD_CONTROL:
case NodeType::Id::BOTTOM:
return 1;
default:
trap();
@ -1967,6 +1970,7 @@ struct Node
SYMBOL_FUNCTION,
CALL,
REGION,
REGION_LOOP,
PHI,
INTEGER_ADD,
@ -2193,9 +2197,28 @@ struct Node
}
case Id::ROOT:
case Id::IF:
case Id::INTEGER_ADD:
return 0;
case Id::INTEGER_ADD:
{
auto* left = inputs[1];
auto* right = inputs[2];
assert(!(left->type.is_constant() && right->type.is_constant()));
if (right->type.id == NodeType::Id::INTEGER && right->type.payload.constant.constant == 0)
{
return left;
}
if (left == right)
{
trap();
}
return 0;
}
case Id::REGION_LOOP:
case Id::REGION:
if (!region_in_progress())
{
// Find dead input
for (u32 i = 1; i < inputs.length; i += 1)
@ -2205,9 +2228,9 @@ struct Node
trap();
}
}
return 0;
}
return 0;
case Id::SCOPE:
trap();
// TODO:
@ -2230,21 +2253,51 @@ struct Node
}
case Id::PHI:
{
if (phi_same_inputs())
auto* region = phi_get_region();
auto is_r = region->is_region();
if (!is_r || region->region_in_progress())
{
return inputs[1];
return 0;
}
else
{
// Single unique input search
Node* live = 0;
Node* region = phi_get_region();
for (u32 i = 1; i < inputs.length; i += 1)
{
if (region->inputs[i]->type.id != NodeType::Id::DEAD_CONTROL && inputs[i] != this)
{
if (!live || live == inputs[i])
{
live = inputs[i];
}
else
{
live = 0;
break;
}
}
}
if (live)
{
return live;
}
Node* operand = inputs[1];
if (operand->inputs.length == 3 && !operand->inputs[0] && !operand->is_cfg() && phi_same_operand())
{
auto lefts = thread->arena->allocate_slice<Node*>(inputs.length);
auto rights = thread->arena->allocate_slice<Node*>(inputs.length);
lefts[0] = rights[0] = inputs[0];
u32 input_count = inputs.length;
auto lefts = thread->arena->allocate_slice<Node*>(input_count);
auto rights = thread->arena->allocate_slice<Node*>(input_count);
for (u32 i = 1; i < inputs.length; i += 1)
lefts[0] = inputs[0];
rights[0] = inputs[0];
for (u32 i = 1; i < input_count; i += 1)
{
lefts[i] = inputs[i]->inputs[1];
rights[i] = inputs[i]->inputs[2];
@ -2265,12 +2318,12 @@ struct Node
});
right_phi->payload.phi.label = payload.phi.label;
right_phi = right_phi->peephole(thread, function);
return operand->copy(thread, left_phi, right_phi);
}
else
{
return 0;
auto* result = operand->copy(thread, left_phi, right_phi);
return result;
}
return 0;
}
}
case Id::RETURN:
@ -2320,10 +2373,11 @@ struct Node
case Id::STOP:
case Id::RETURN:
case Id::REGION:
case Id::REGION_LOOP:
case Id::IF:
return 1;
case Id::PROJECTION:
trap();
return (payload.projection.index == 0) || (get_control()->id == Node::Id::IF);
default:
return 0;
}
@ -2411,7 +2465,7 @@ struct Node
auto* constant_int = Node::add(thread, {
.type = type,
.inputs = { .pointer = &function->root_node, .length = 1 },
.id = Node::Id::CONSTANT_INT,
.id = Id::CONSTANT_INT,
});
auto* result = constant_int->peephole(thread, function);
return dead_code_elimination(thread->arena, result);
@ -2446,6 +2500,7 @@ struct Node
default:
return 0;
case Id::CONSTANT_INT:
case Id::CONSTANT_CONTROL:
return 1;
}
}
@ -2593,7 +2648,13 @@ struct Node
},
};
}
case Node::Id::REGION_LOOP:
case Node::Id::REGION:
if (region_in_progress())
{
return { .id = Type::Id::LIVE_CONTROL };
}
else
{
Type ty = { .id = Type::Id::DEAD_CONTROL };
for (u32 i = 1; i < inputs.length; i += 1)
@ -2604,12 +2665,41 @@ struct Node
return ty;
}
case Node::Id::PHI:
return { .id = Type::Id::BOTTOM };
{
auto* region = phi_get_region();
auto is_r = region->is_region();
if (!is_r || region->region_in_progress())
{
return { .id = Type::Id::BOTTOM };
}
else
{
Node::Type ty = { .id = Type::Id::TOP };
for (u32 i = 1; i < inputs.length; i += 1)
{
ty = ty.meet(inputs[i]->type);
}
return ty;
}
}
default:
trap();
}
}
u8 is_region()
{
switch (id)
{
case Id::REGION: case Id::REGION_LOOP:
return 1;
default:
return 0;
}
}
method u8 is_associative()
{
switch (id)
@ -2622,6 +2712,7 @@ struct Node
}
method Node* associative_phi_constant(u8 should_rotate)
{
unused(should_rotate);
assert(is_associative());
trap();
}
@ -2658,13 +2749,15 @@ struct Node
case NodeType::Id::INVALID:
trap();
case NodeType::Id::BOTTOM:
trap();
// TODO:
return unit->get_integer_type(32, 1);
case NodeType::Id::TOP:
trap();
case NodeType::Id::LIVE_CONTROL:
case NodeType::Id::DEAD_CONTROL:
trap();
case NodeType::Id::INTEGER:
// TODO:
return unit->get_integer_type(32, 1);
case NodeType::Id::MULTIVALUE:
trap();
@ -2703,7 +2796,17 @@ struct Node
method u8 all_constants()
{
for (u32 i = 0; i < inputs.length; i += 1)
if (id == Id::PHI)
{
auto* region = phi_get_region();
auto is_r = region->is_region();
if (!is_r || region->region_in_progress())
{
return 0;
}
}
for (u32 i = 1; i < inputs.length; i += 1)
{
if (!inputs[i]->type.is_constant())
{
@ -2720,6 +2823,8 @@ struct Node
{
case Id::ROOT:
return 0;
case Id::REGION_LOOP:
return loop_entry();
case Id::REGION:
if (payload.region.immediate_dominator)
{
@ -2786,6 +2891,82 @@ struct Node
}
}
}
method u8 region_in_progress()
{
assert(is_region());
return !(inputs[inputs.length - 1]);
}
method Node* loop_entry()
{
assert(id == Id::REGION_LOOP);
return inputs[1];
}
method Node* loop_backedge()
{
assert(id == Id::REGION_LOOP);
return inputs[2];
}
method Node* set_control(Arena* arena, Node* node)
{
assert(id == Id::SCOPE);
return set_input(arena, 0, node);
}
method Node* phi_get_region()
{
assert(id == Id::PHI);
return inputs[0];
}
method void scope_end_loop(Thread* thread, Function* function, Node* back, Node* exit)
{
assert(id == Id::SCOPE);
assert(back->id == Id::SCOPE);
assert(exit->id == Id::SCOPE);
Node* control_node = get_control();
assert(control_node->id == Id::REGION_LOOP);
assert(control_node->region_in_progress());
control_node->set_input(thread->arena, 2, back->get_control());
for (u32 i = 1; i < inputs.length; i += 1)
{
auto* phi_node = inputs[i];
assert(phi_node->id == Id::PHI);
assert(phi_node->phi_get_region() == control_node);
assert(!(phi_node->inputs[2]));
phi_node->set_input(thread->arena, 2, back->inputs[2]);
Node* input = phi_node->peephole(thread, function);
if (input != phi_node)
{
phi_node->subsume(thread->arena, input);
}
}
back->kill(thread->arena);
}
method void subsume(Arena* arena, Node* node)
{
assert(node != this);
while (outputs.length > 0)
{
Node* n = outputs.pop();
s32 index = n->inputs.slice().find_index(this);
assert(index != -1);
n->inputs[index] = node;
node->add_output(n);
}
kill(arena);
}
};
static_assert(sizeof(Node) == 128);
@ -3340,7 +3521,7 @@ struct Analyzer
method Node* set_control(Arena* arena, Node* node)
{
return scope->set_input(arena, 0, node);
return scope->set_control(arena, node);
}
method void kill_control(Arena* arena)
@ -3376,13 +3557,12 @@ struct Analyzer
return scope->get_control();
}
method Node* duplicate_scope(Thread* thread)
method Node* duplicate_scope(Thread* thread, u8 loop)
{
auto original_input_count = scope->inputs.length;
auto* duplicate = create_scope(thread);
duplicate->payload.scope.stack.ensure_capacity(scope->payload.scope.stack.capacity);
// TODO: make this more efficient
// // TODO: make this more efficient
for (auto& hashmap: scope->payload.scope.stack.slice())
{
Hashmap<String, u16> duplicate_hashmap = {};
@ -3397,14 +3577,36 @@ struct Analyzer
duplicate->payload.scope.stack.append_one(duplicate_hashmap);
}
duplicate->add_input(get_control());
for (u32 i = 1; i < scope->inputs.length; i += 1)
{
duplicate->add_input(scope->inputs[i]);
if (loop)
{
auto names = scope_reverse_names(thread->arena, scope);
Node* inputs[] = {
get_control(),
scope->inputs[i],
0,
};
String label = names[i];
auto* phi_node = Node::add(thread, {
.type = {},
.inputs = array_to_slice(inputs),
.id = Node::Id::PHI,
});
phi_node->payload.phi.label = label;
phi_node = phi_node->peephole(thread, function);
duplicate->add_input(phi_node);
scope->set_input(thread->arena, i, duplicate->inputs[i]);
}
else
{
duplicate->add_input(scope->inputs[i]);
}
}
assert(duplicate->inputs.length == original_input_count);
return duplicate;
}
@ -3423,7 +3625,7 @@ struct Analyzer
.type = {},
.inputs = array_to_slice(inputs),
.id = Node::Id::REGION,
})->peephole(thread, function));
})->keep());
auto names = scope_reverse_names(thread->arena, scope_a);
// Skip input[0] ($ctrl)
@ -3453,7 +3655,7 @@ struct Analyzer
}
scope_b->kill(thread->arena);
return region_node;
return region_node->unkeep()->peephole(thread, function);
}
};
@ -3873,6 +4075,10 @@ fn Node* scope_lookup(Analyzer* analyzer, Arena* arena, String name)
INTEGER_COMPARE_LESS_EQUAL,
INTEGER_COMPARE_GREATER,
INTEGER_COMPARE_GREATER_EQUAL,
INTEGER_SHIFT_LEFT,
INTEGER_SHIFT_LEFT_ASSIGN,
INTEGER_SHIFT_RIGHT,
INTEGER_SHIFT_RIGHT_ASSIGN,
};
u64 iterations = 0;
@ -3931,6 +4137,10 @@ fn Node* scope_lookup(Analyzer* analyzer, Arena* arena, String name)
case CurrentOperation::INTEGER_COMPARE_LESS_EQUAL:
case CurrentOperation::INTEGER_COMPARE_GREATER:
case CurrentOperation::INTEGER_COMPARE_GREATER_EQUAL:
case CurrentOperation::INTEGER_SHIFT_LEFT:
case CurrentOperation::INTEGER_SHIFT_LEFT_ASSIGN:
case CurrentOperation::INTEGER_SHIFT_RIGHT:
case CurrentOperation::INTEGER_SHIFT_RIGHT_ASSIGN:
trap();
}
@ -3995,10 +4205,14 @@ fn Node* scope_lookup(Analyzer* analyzer, Arena* arena, String name)
previous_node = binary;
} break;
case CurrentOperation::ASSIGN:
case CurrentOperation::INTEGER_SHIFT_LEFT:
case CurrentOperation::INTEGER_SHIFT_LEFT_ASSIGN:
case CurrentOperation::INTEGER_SHIFT_RIGHT:
case CurrentOperation::INTEGER_SHIFT_RIGHT_ASSIGN:
case CurrentOperation::INTEGER_ADD_ASSIGN:
case CurrentOperation::INTEGER_SUB_ASSIGN:
trap();
}
}
previous_node = previous_node->peephole(thread, analyzer->function);
@ -4054,6 +4268,64 @@ fn Node* scope_lookup(Analyzer* analyzer, Arena* arena, String name)
break;
}
break;
case '<':
current_operation = CurrentOperation::INTEGER_COMPARE_LESS;
parser->i += 1;
switch (src[parser->i])
{
case '=':
current_operation = CurrentOperation::INTEGER_COMPARE_LESS_EQUAL;
parser->i += 1;
break;
case '<': // Shift left
current_operation = CurrentOperation::INTEGER_SHIFT_LEFT;
parser->i += 1;
switch (src[parser->i])
{
case '=':
current_operation = CurrentOperation::INTEGER_SHIFT_LEFT_ASSIGN;
parser->i += 1;
break;
default:
break;
}
break;
default:
break;
}
break;
case '>':
current_operation = CurrentOperation::INTEGER_COMPARE_GREATER;
parser->i += 1;
switch (src[parser->i])
{
case '=':
current_operation = CurrentOperation::INTEGER_COMPARE_GREATER_EQUAL;
parser->i += 1;
break;
case '>': // Shift right
current_operation = CurrentOperation::INTEGER_SHIFT_RIGHT;
parser->i += 1;
switch (src[parser->i])
{
case '=':
current_operation = CurrentOperation::INTEGER_SHIFT_RIGHT_ASSIGN;
parser->i += 1;
break;
default:
break;
}
break;
default:
break;
}
break;
case function_argument_start:
{
assert(previous_node->id == Node::Id::SYMBOL_FUNCTION);
@ -4151,7 +4423,7 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa
Node* if_false = if_node->project(thread, if_node, 1, strlit("False"))->peephole(thread, function);
u32 original_input_count = analyzer->scope->inputs.length;
auto* false_scope = analyzer->duplicate_scope(thread);
auto* false_scope = analyzer->duplicate_scope(thread, 0);
analyzer->set_control(thread->arena, if_true);
assert(analyzer->scope->get_control());
@ -4175,6 +4447,8 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa
parser->skip_space(src);
analyze_statement(analyzer, parser, unit, thread, src);
false_scope = analyzer->scope;
}
else
{
@ -4193,6 +4467,67 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa
statement_node = analyzer->set_control(thread->arena, merged_scope);
assert(statement_node);
}
else if (identifier.equal(strlit("while")))
{
parser->skip_space(src);
parser->expect_character(src, parenthesis_open);
Node* loop_inputs[] = {
0,
analyzer->get_control(),
0,
};
auto* loop_node = Node::add(thread, {
.type = {},
.inputs = array_to_slice(loop_inputs),
.id = Node::Id::REGION_LOOP,
})->peephole(thread, function);
analyzer->set_control(thread->arena, loop_node);
Node* head = analyzer->scope->keep();
auto is_loop = 1;
analyzer->scope = analyzer->duplicate_scope(thread, is_loop);
parser->skip_space(src);
auto* predicate_node = analyze_expression(analyzer, parser, unit, thread, src, 0, Side::right);
parser->skip_space(src);
parser->expect_character(src, parenthesis_close);
parser->skip_space(src);
Node* if_inputs[] = {
analyzer->get_control(),
predicate_node,
};
auto* if_node = Node::add(thread, {
.type = {},
.inputs = array_to_slice(if_inputs),
.id = Node::Id::IF,
})->keep()->peephole(thread, function);
Node* if_true = if_node->project(thread, if_node, 0, strlit("True"))->peephole(thread, function);
if_node->unkeep();
Node* if_false = if_node->project(thread, if_node, 1, strlit("False"))->peephole(thread, function);
auto* exit_scope = analyzer->duplicate_scope(thread, 0);
exit_scope->set_control(thread->arena, if_false);
analyzer->set_control(thread->arena, if_true);
analyze_statement(analyzer, parser, unit, thread, src);
head->scope_end_loop(thread, function, analyzer->scope, exit_scope);
head->unkeep()->kill(thread->arena);
analyzer->scope = exit_scope;
statement_node = exit_scope;
assert(statement_node);
}
if (statement_node)
{
@ -5068,6 +5403,7 @@ global String test_file_paths[] = {
strlit("tests/function_call_args/main.nat"),
strlit("tests/comparison/main.nat"),
strlit("tests/if/main.nat"),
strlit("tests/while/main.nat"),
};
#ifdef __linux__

136
tests/while/main.nat Normal file
View File

@ -0,0 +1,136 @@
fn while0(arg: s32) s32
{
>a: s32 = arg;
while (a < 10)
{
a = a + 1;
}
return a;
}
fn while1(arg: s32) s32
{
>a: s32 = 1;
if (arg)
{
}
else
{
while (a < 10)
{
a = a + 1;
}
}
return a;
}
fn while2(arg: s32) s32
{
>sum: s32 = 0;
>i: s32 = 0;
while (i < arg)
{
i = i + 1;
>j: s32 = 0;
while (j < arg)
{
sum = sum + j;
j = j + 1;
}
}
return sum;
}
fn while3(arg: s32) s32
{
>a: s32 = 1;
>b: s32 = 2;
while (a < 10)
{
if (a == 2)
{
a = 3;
}
else
{
b = 4;
}
}
return b;
}
fn while4(arg: s32) s32
{
>a: s32 = 1;
>b: s32 = 2;
while (a < 10)
{
if (a == 2)
{
a = 3;
}
else
{
b = 4;
}
b = b + 1;
a = a + 1;
}
return b;
}
fn while5(arg: s32) s32
{
>a: s32 = 1;
while (a < 10)
{
a = a + 1;
a = a + 2;
}
return a;
}
fn while6(arg: s32) s32
{
>a: s32 = 1;
while (arg)
{
a = 2;
}
return a;
}
fn while7(arg: s32) s32
{
>a: s32 = 1;
while (a < 10)
{
>b: s32 = a + 1;
a = b + 2;
}
return a;
}
fn[cc(.c)] main[export]() s32
{
return while0(0) +
while1(1) +
while2(2) +
while3(3) +
while4(4) +
while5(5) +
while6(6) +
while7(7);
}