Merge pull request #11 from birth-software/peephole-if

If peephole
This commit is contained in:
David 2024-07-08 19:54:50 +02:00 committed by GitHub
commit 2b776bec44
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -1478,9 +1478,9 @@ struct NodeType
INVALID,
BOTTOM,
TOP,
CONTROL,
LIVE_CONTROL,
DEAD_CONTROL,
INTEGER,
VOID,
MULTIVALUE,
MEMORY,
POINTER,
@ -1497,7 +1497,7 @@ struct NodeType
{
u64 constant;
u8 is_constant;
} integer;
} constant;
struct
{
Slice<NodeType> types;
@ -1512,7 +1512,8 @@ struct NodeType
trap();
case Id::BOTTOM:
case Id::TOP:
case Id::CONTROL:
case Id::LIVE_CONTROL:
case Id::DEAD_CONTROL:
return 1;
default:
return 0;
@ -1529,7 +1530,10 @@ struct NodeType
switch (id)
{
case NodeType::Id::INTEGER:
return (payload.integer.is_constant == other.payload.integer.is_constant) & (payload.integer.constant == other.payload.integer.constant);
return (payload.constant.is_constant == other.payload.constant.is_constant) & (payload.constant.constant == other.payload.constant.constant);
case NodeType::Id::LIVE_CONTROL:
case NodeType::Id::DEAD_CONTROL:
return 1;
default:
trap();
}
@ -1539,11 +1543,12 @@ struct NodeType
{
switch (id)
{
case Id::VOID:
trap();
case Id::DEAD_CONTROL:
case Id::TOP:
return 1;
case Id::INTEGER:
return payload.integer.is_constant;
case Id::CONTROL:
return payload.constant.is_constant;
case Id::LIVE_CONTROL:
case Id::MULTIVALUE:
case Id::BOTTOM:
return 0;
@ -1557,22 +1562,67 @@ struct NodeType
method NodeType meet(NodeType other)
{
unused(other);
if (equal(other))
{
return *this;
}
if (id == other.id)
{
return x_meet(other);
}
if (is_simple())
{
return x_meet(other);
}
if (other.is_simple())
{
return other.x_meet(*this);
}
return { .id = NodeType::Id::BOTTOM };
}
method NodeType x_meet(NodeType other)
{
switch (id)
{
case NodeType::Id::MULTIVALUE:
fail();
case NodeType::Id::INTEGER:
case Id::BOTTOM:
case Id::TOP:
case Id::LIVE_CONTROL:
case Id::DEAD_CONTROL:
{
assert(is_simple());
if ((id == Id::BOTTOM) | (other.id == Id::TOP))
{
return *this;
}
if ((id == Id::TOP) | (other.id == Id::BOTTOM))
{
return other;
}
if (!other.is_simple())
{
return NodeType{ .id = NodeType::Id::BOTTOM };
}
auto new_id = ((id == Id::LIVE_CONTROL) | (other.id == Id::LIVE_CONTROL)) ? Id::LIVE_CONTROL : Id::DEAD_CONTROL;
return { .id = new_id };
}
case Id::INTEGER:
{
if (equal(other))
{
return *this;
}
if (other.id != NodeType::Id::INTEGER)
if (other.id != Id::INTEGER)
{
return NodeType{ .id = NodeType::Id::BOTTOM };
return meet(other);
}
if (is_bot())
@ -1596,37 +1646,40 @@ struct NodeType
}
assert(is_constant() & other.is_constant());
if (payload.integer.constant == other.payload.integer.constant)
if (payload.constant.constant == other.payload.constant.constant)
{
trap();
return *this;
}
else
{
trap();
return { .id = Id::BOTTOM };
}
} break;
}
case Id::MULTIVALUE:
fail();
default:
return NodeType{ .id = NodeType::Id::BOTTOM };
trap();
}
}
method u8 is_bot()
{
assert(id == Id::INTEGER);
return !payload.integer.is_constant & (payload.integer.constant == 1);
return !payload.constant.is_constant & (payload.constant.constant == 1);
}
method u8 is_top()
{
assert(id == Id::INTEGER);
return !payload.integer.is_constant & (payload.integer.constant == 0);
return !payload.constant.is_constant & (payload.constant.constant == 0);
}
};
may_be_unused global auto constexpr integer_top = NodeType{
.id = NodeType::Id::INTEGER,
.payload = {
.integer = {
.constant = {
.constant = 0,
.is_constant = 0,
},
@ -1636,7 +1689,7 @@ may_be_unused global auto constexpr integer_top = NodeType{
may_be_unused global auto constexpr integer_bot = NodeType{
.id = NodeType::Id::INTEGER,
.payload = {
.integer = {
.constant = {
.constant = 1,
.is_constant = 0,
},
@ -1646,23 +1699,65 @@ may_be_unused global auto constexpr integer_bot = NodeType{
may_be_unused global auto constexpr integer_zero = NodeType{
.id = NodeType::Id::INTEGER,
.payload = {
.integer = {
.constant = {
.constant = 0,
.is_constant = 1,
},
},
};
global NodeType type_if_types[2] = {
{ .id = NodeType::Id::CONTROL },
{ .id = NodeType::Id::CONTROL },
global NodeType if_both_types[2] = {
{ .id = NodeType::Id::LIVE_CONTROL },
{ .id = NodeType::Id::LIVE_CONTROL },
};
global auto constexpr type_if = NodeType{
global NodeType if_neither_types[2] = {
{ .id = NodeType::Id::DEAD_CONTROL },
{ .id = NodeType::Id::DEAD_CONTROL },
};
global NodeType if_true_types[2] = {
{ .id = NodeType::Id::LIVE_CONTROL },
{ .id = NodeType::Id::DEAD_CONTROL },
};
global NodeType if_false_types[2] = {
{ .id = NodeType::Id::DEAD_CONTROL },
{ .id = NodeType::Id::LIVE_CONTROL },
};
global auto constexpr type_if_both = NodeType{
.id = NodeType::Id::MULTIVALUE,
.payload = {
.multi = {
.types = array_to_slice(type_if_types),
.types = array_to_slice(if_both_types),
},
},
};
global auto constexpr type_if_neither = NodeType{
.id = NodeType::Id::MULTIVALUE,
.payload = {
.multi = {
.types = array_to_slice(if_neither_types),
},
},
};
global auto constexpr type_if_true = NodeType{
.id = NodeType::Id::MULTIVALUE,
.payload = {
.multi = {
.types = array_to_slice(if_true_types),
},
},
};
global auto constexpr type_if_false = NodeType{
.id = NodeType::Id::MULTIVALUE,
.payload = {
.multi = {
.types = array_to_slice(if_false_types),
},
},
};
@ -1867,6 +1962,7 @@ struct Node
RETURN,
IF,
CONSTANT_INT,
CONSTANT_CONTROL,
SCOPE,
SYMBOL_FUNCTION,
CALL,
@ -1891,6 +1987,7 @@ struct Node
Array<Node*> outputs;
u32 gvn;
Id id;
s32 immediate_depth = 0;
union
{
@ -1912,9 +2009,13 @@ struct Node
{
String label;
} phi;
struct
{
Node* immediate_dominator = 0;
} region;
} payload;
u8 padding[40] = {};
u8 padding[32] = {};
method forceinline Slice<Node*> get_inputs()
{
@ -1967,13 +2068,6 @@ struct Node
return node;
}
method u8 remove_output(Node* output)
{
s32 index = outputs.slice().find_index(output);
assert(index != -1);
outputs.remove_swap(index);
return outputs.length == 0;
}
method Node* add_output(Node* output)
{
@ -1991,17 +2085,17 @@ struct Node
return input;
}
method Node* set_input(Arena* arena, s32 index, Node* input)
method Node* set_input(Arena* arena, s32 index, Node* new_input)
{
Node* old_input = inputs[index];
if (old_input == input)
if (old_input == new_input)
{
return this;
}
if (input)
if (new_input)
{
input->add_output(this);
new_input->add_output(this);
}
if (old_input && old_input->remove_output(this))
@ -2009,9 +2103,30 @@ struct Node
old_input->kill(arena);
}
inputs[index] = input;
inputs[index] = new_input;
return input;
return new_input;
}
method u8 remove_output(Node* output)
{
s32 index = outputs.slice().find_index(output);
assert(index != -1);
outputs.remove_swap(index);
return outputs.length == 0;
}
method void remove_input(Arena* arena, u32 index)
{
if (Node* old_input = inputs[index])
{
if (old_input->remove_output(this))
{
old_input->kill(arena);
}
}
inputs.remove_swap(index);
}
method Node* idealize(Thread* thread, Function* function)
@ -2027,15 +2142,72 @@ struct Node
{
return 0;
}
case Id::ROOT:
case Id::STOP:
case Id::PROJECTION:
case Id::IF:
case Id::RETURN:
case Id::CONSTANT_INT:
case Id::INTEGER_ADD:
case Id::REGION:
case Id::CONSTANT_CONTROL:
return 0;
case Id::STOP:
{
auto input_count = inputs.length;
for (u32 i = 0; i < inputs.length; i += 1)
{
if (inputs[i]->type.id == NodeType::Id::DEAD_CONTROL)
{
remove_input(thread->arena, i);
i -= 1;
}
}
if (input_count != inputs.length)
{
return this;
}
else
{
return 0;
}
}
case Id::PROJECTION:
{
auto* control = get_control();
if (control->type.id == NodeType::Id::MULTIVALUE)
{
auto control_types = control->type.payload.multi.types;
auto projection_index = payload.projection.index;
if (control_types[projection_index].id == NodeType::Id::DEAD_CONTROL)
{
trap();
}
// TODO: fix
// auto index = 1 - projection_index;
// assert(index >= 0);
// assert((u32)index < control_types.length);
// if (control_types[index].id == NodeType::Id::DEAD_CONTROL)
// {
// trap();
// }
}
return 0;
}
case Id::ROOT:
case Id::IF:
case Id::INTEGER_ADD:
return 0;
case Id::REGION:
{
// Find dead input
for (u32 i = 1; i < inputs.length; i += 1)
{
if (inputs[i]->type.id == NodeType::Id::DEAD_CONTROL)
{
trap();
}
}
return 0;
}
case Id::SCOPE:
trap();
// TODO:
@ -2101,6 +2273,17 @@ struct Node
}
}
}
case Id::RETURN:
{
if (get_control()->type.id == Node::Type::Id::DEAD_CONTROL)
{
trap();
}
else
{
return 0;
}
}
}
}
@ -2267,6 +2450,12 @@ struct Node
}
}
method Node* predicate()
{
assert(id == Id::IF);
return inputs[1];
}
method Node::Type compute()
{
switch (id)
@ -2276,7 +2465,30 @@ struct Node
case Node::Id::STOP:
return { .id = Type::Id::BOTTOM };
case Node::Id::IF:
return type_if;
{
if (get_control()->type.id != NodeType::Id::LIVE_CONTROL)
{
trap();
}
auto* this_predicate = predicate();
if ((this_predicate->type.id == Node::Type::Id::INTEGER) & this_predicate->type.is_constant())
{
trap();
}
for (Node* dom = get_immediate_dominator(), *prior = this; dom; prior = dom, dom = dom->get_immediate_dominator())
{
if ((dom->id == Id::IF) && dom->predicate() == this_predicate)
{
unused(prior);
trap();
}
}
return type_if_both;
}
// return type_if;
case Node::Id::INTEGER_ADD:
case Node::Id::INTEGER_SUB:
case Node::Id::INTEGER_COMPARE_EQUAL:
@ -2298,35 +2510,35 @@ struct Node
default:
trap();
case Id::INTEGER_ADD:
result = left_type.payload.integer.constant + right_type.payload.integer.constant;
result = left_type.payload.constant.constant + right_type.payload.constant.constant;
break;
case Id::INTEGER_SUB:
result = left_type.payload.integer.constant - right_type.payload.integer.constant;
result = left_type.payload.constant.constant - right_type.payload.constant.constant;
break;
case Id::INTEGER_COMPARE_EQUAL:
result = left_type.payload.integer.constant == right_type.payload.integer.constant;
result = left_type.payload.constant.constant == right_type.payload.constant.constant;
break;
case Id::INTEGER_COMPARE_NOT_EQUAL:
result = left_type.payload.integer.constant != right_type.payload.integer.constant;
result = left_type.payload.constant.constant != right_type.payload.constant.constant;
break;
case Id::INTEGER_COMPARE_LESS:
result = left_type.payload.integer.constant < right_type.payload.integer.constant;
result = left_type.payload.constant.constant < right_type.payload.constant.constant;
break;
case Id::INTEGER_COMPARE_LESS_EQUAL:
result = left_type.payload.integer.constant <= right_type.payload.integer.constant;
result = left_type.payload.constant.constant <= right_type.payload.constant.constant;
break;
case Id::INTEGER_COMPARE_GREATER:
result = left_type.payload.integer.constant > right_type.payload.integer.constant;
result = left_type.payload.constant.constant > right_type.payload.constant.constant;
break;
case Id::INTEGER_COMPARE_GREATER_EQUAL:
result = left_type.payload.integer.constant >= right_type.payload.integer.constant;
result = left_type.payload.constant.constant >= right_type.payload.constant.constant;
break;
}
return Node::Type{
.id = Node::Type::Id::INTEGER,
.payload = {
.integer = {
.constant = {
.constant = result,
// .bit_count = left_type.payload.integer.bit_count,
.is_constant = 1,
@ -2345,6 +2557,7 @@ struct Node
}
}
case Node::Id::CONSTANT_INT:
case Node::Id::CONSTANT_CONTROL:
return type;
case Node::Id::PROJECTION:
{
@ -2381,7 +2594,15 @@ struct Node
};
}
case Node::Id::REGION:
return { .id = Type::Id::CONTROL };
{
Type ty = { .id = Type::Id::DEAD_CONTROL };
for (u32 i = 1; i < inputs.length; i += 1)
{
ty = ty.meet(inputs[i]->type);
}
return ty;
}
case Node::Id::PHI:
return { .id = Type::Id::BOTTOM };
default:
@ -2389,6 +2610,22 @@ struct Node
}
}
method u8 is_associative()
{
switch (id)
{
case Id::INTEGER_ADD:
return 1;
default:
return 0;
}
}
method Node* associative_phi_constant(u8 should_rotate)
{
assert(is_associative());
trap();
}
method Node* project(Thread* thread, Node* control, s32 index, String label)
{
assert(type.id == Node::Type::Id::MULTIVALUE);
@ -2424,12 +2661,11 @@ struct Node
trap();
case NodeType::Id::TOP:
trap();
case NodeType::Id::CONTROL:
case NodeType::Id::LIVE_CONTROL:
case NodeType::Id::DEAD_CONTROL:
trap();
case NodeType::Id::INTEGER:
return unit->get_integer_type(32, 1);
case NodeType::Id::VOID:
trap();
case NodeType::Id::MULTIVALUE:
trap();
case NodeType::Id::MEMORY:
@ -2448,11 +2684,108 @@ struct Node
switch (id)
{
case Node::Id::SCOPE:
case Node::Id::RETURN:
case Node::Id::IF:
case Node::Id::PROJECTION:
return inputs[0];
default:
trap();
}
}
method Node* swap_inputs_1_2()
{
Node* temporal = inputs[1];
inputs[1] = inputs[2];
inputs[2] = temporal;
return this;
}
method u8 all_constants()
{
for (u32 i = 0; i < inputs.length; i += 1)
{
if (!inputs[i]->type.is_constant())
{
return 0;
}
}
return 1;
}
method Node* get_immediate_dominator()
{
switch (id)
{
case Id::ROOT:
return 0;
case Id::REGION:
if (payload.region.immediate_dominator)
{
return payload.region.immediate_dominator;
}
else
{
if (inputs.length == 3)
{
Node* left = inputs[1]->get_immediate_dominator();
Node* right = inputs[2]->get_immediate_dominator();
while (left != right)
{
if (!left || !right)
{
return 0;
}
else
{
auto comp = left->immediate_depth - right->immediate_depth;
if (comp >= 0)
{
left = left->get_immediate_dominator();
}
if (comp <= 0)
{
right = right->get_immediate_dominator();
}
}
}
if (left)
{
immediate_depth = left->immediate_depth + 1;
payload.region.immediate_dominator = left;
return left;
}
else
{
return 0;
}
}
else
{
trap();
}
}
default:
{
Node* result = inputs[0];
if (result->immediate_depth == 0)
{
result->get_immediate_dominator();
}
if (immediate_depth == 0)
{
immediate_depth = result->immediate_depth + 1;
}
return result;
}
}
}
};
static_assert(sizeof(Node) == 128);
@ -2465,7 +2798,7 @@ static_assert(page_size % sizeof(Node) == 0);
{
.id = Node::Type::Id::INTEGER,
.payload = {
.integer = {
.constant = {
.constant = data.value,
// .bit_count = data.bit_count,
.is_constant = 1,
@ -3027,7 +3360,13 @@ struct Analyzer
auto* node = function->stop_node->add_input(return_node);
kill_control(thread->arena);
// Kill control
auto* dead_control = Node::add(thread, {
.type = { .id = Node::Type::Id::DEAD_CONTROL },
.inputs = { .pointer = &function->root_node, .length = 1 },
.id = Node::Id::CONSTANT_CONTROL,
})->peephole(thread, function);
set_control(thread->arena, dead_control);
return node;
}
@ -3808,6 +4147,7 @@ fn Node* analyze_statement(Analyzer* analyzer, Parser* parser, Unit* unit, Threa
})->keep()->peephole(thread, function);
Node* if_true = if_node->project(thread, if_node, 0, strlit("True"))->peephole(thread, function);
if_node->unkeep();
Node* if_false = if_node->project(thread, if_node, 1, strlit("False"))->peephole(thread, function);
u32 original_input_count = analyzer->scope->inputs.length;
@ -4549,7 +4889,7 @@ fn Node* analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file
Array<Node::Type> abi_argument_types = {};
Array<Node::Type> root_arg_types = {};
root_arg_types.append_one({ .id = Node::Type::Id::CONTROL });
root_arg_types.append_one({ .id = Node::Type::Id::LIVE_CONTROL });
for (u32 i = 0; i < argument_type_abis.length; i += 1)
{
@ -4642,7 +4982,6 @@ fn Node* analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file
case ABI_INFO_DIRECT:
{
auto* argument_node = function->root_node->project(thread, function->root_node, next_index, argument_name)->peephole(thread, function);
assert(argument_node->type.id != Node::Type::Id::CONTROL);
define_variable(&analyzer, argument_name, argument_node);
next_index += 1;
} break;