Implement constant propagation

This commit is contained in:
David Gonzalez Martin 2024-06-30 10:04:01 +02:00
parent beb3af95fc
commit fe92747a41
2 changed files with 399 additions and 117 deletions

View File

@ -1146,68 +1146,6 @@ enum class Side : u8
right, right,
}; };
struct NodeDataType
{
enum class Id : u8
{
VOID,
INTEGER,
TUPLE,
CONTROL,
MEMORY,
POINTER,
};
Id id;
u8 bit_count:5;
};
union AbiInfoPayload
{
NodeDataType direct;
NodeDataType direct_pair[2];
NodeDataType direct_coerce;
struct
{
NodeDataType type;
u32 alignment;
} indirect;
};
typedef union AbiInfoPayload AbiInfoPayload;
struct AbiInfoAttributes
{
u8 by_reg: 1;
u8 zero_extend: 1;
u8 sign_extend: 1;
u8 realign: 1;
u8 by_value: 1;
};
typedef struct AbiInfoAttributes AbiInfoAttributes;
struct AbiInfo
{
AbiInfoPayload payload;
u16 indices[2];
AbiInfoAttributes attributes;
AbiInfoKind kind;
};
struct FunctionPrototype
{
AbiInfo* argument_type_abis; // The count for this array is "original_argument_count", not "abi_argument_count"
SemaType** original_argument_types;
// TODO: are these needed?
// Node::DataType* abi_argument_types;
// u32 abi_argument_count;
SemaType* original_return_type;
AbiInfo return_type_abi;
u32 original_argument_count;
// TODO: is this needed?
// Node::DataType abi_return_type;
u8 varags:1;
};
struct Function;
global auto constexpr void_type_index = 0; global auto constexpr void_type_index = 0;
global auto constexpr noreturn_type_index = 1; global auto constexpr noreturn_type_index = 1;
@ -1218,6 +1156,7 @@ global auto constexpr integer_type_offset = 5;
global auto constexpr integer_type_count = 64 * 2; global auto constexpr integer_type_count = 64 * 2;
global auto constexpr builtin_type_count = integer_type_count + integer_type_offset + 1; global auto constexpr builtin_type_count = integer_type_count + integer_type_offset + 1;
struct Function;
struct Thread struct Thread
{ {
Arena* arena; Arena* arena;
@ -1242,30 +1181,123 @@ struct Unit
} }
}; };
struct Node; struct Node;
struct FunctionPrototype;
// Type attached to every node. `id` selects the kind of type; the `integer`
// payload is only meaningful when `id == Id::INTEGER`.
struct NodeType
{
    enum class Id: u8
    {
        INVALID, // first enumerator, i.e. the id of a value-initialized `NodeType{}`
        BOTTOM,
        TOP,
        VOID,
        INTEGER,
        TUPLE,
        CONTROL,
        MEMORY,
        POINTER,
    };

    Id id;

    union
    {
        struct
        {
            u64 constant;   // known value; only meaningful when `is_constant` is non-zero
            u8 bit_count;   // width of the integer in bits
            u8 is_constant; // non-zero when `constant` holds a known value
        } integer;
    };

    // Returns 0 when the two types have different ids.
    // Comparing two types that share an id is not implemented yet and traps.
    u8 equal(NodeType other)
    {
        if (id != other.id)
        {
            return 0;
        }

        switch (id)
        {
            default: trap();
        }
    }

    // Reports whether this type denotes a compile-time constant.
    //  - VOID is always constant.
    //  - INTEGER is constant only when its payload flag says so. Previously
    //    every INTEGER was reported as constant and the flag was ignored,
    //    which would let a non-constant integer be folded as a constant.
    //  - Any other id is not handled yet and traps.
    // The result is normalized to 0/1 because callers combine it with `&`.
    u8 is_constant()
    {
        switch (id)
        {
            case Id::VOID:
                return 1;
            case Id::INTEGER:
                return integer.is_constant != 0;
            default:
                trap();
        }
    }
};
// Payload carried by an AbiInfo: one NodeType, a pair of NodeTypes, or a
// NodeType plus an alignment.
// NOTE(review): the active member is presumably selected by AbiInfo::kind —
// confirm against the ABI classification code, which is not visible here.
union AbiInfoPayload
{
NodeType direct;
NodeType direct_pair[2];
NodeType direct_coerce;
struct
{
NodeType type;
u32 alignment;
} indirect;
};
// Redundant in C++ (the union name is already a type name); kept as in the original.
typedef union AbiInfoPayload AbiInfoPayload;
// Five single-bit flags stored alongside an AbiInfo.
// NOTE(review): the exact meaning of each flag is defined by the ABI lowering
// code, which is not visible here — names suggest register passing,
// zero/sign extension, realignment and by-value passing; confirm before relying on it.
struct AbiInfoAttributes
{
u8 by_reg: 1;
u8 zero_extend: 1;
u8 sign_extend: 1;
u8 realign: 1;
u8 by_value: 1;
};
// Redundant in C++ (the struct name is already a type name); kept as in the original.
typedef struct AbiInfoAttributes AbiInfoAttributes;
// ABI description of a single value: used for each function argument and for
// the return value (see Function::Prototype).
struct AbiInfo
{
// Kind-specific data (type(s) and, for the indirect case, an alignment).
AbiInfoPayload payload;
// NOTE(review): meaning of the two indices is not visible in this chunk — confirm.
u16 indices[2];
AbiInfoAttributes attributes;
// AbiInfoKind is declared elsewhere in the file.
AbiInfoKind kind;
};
struct Function struct Function
{ {
// Signature of a function: the source-level argument/return types together
// with their ABI descriptions.
struct Prototype
{
AbiInfo* argument_type_abis; // The count for this array is "original_argument_count", not "abi_argument_count"
SemaType** original_argument_types;
// TODO: are these needed?
// Node::DataType* abi_argument_types;
// u32 abi_argument_count;
SemaType* original_return_type;
AbiInfo return_type_abi;
u32 original_argument_count;
// TODO: is this needed?
// Node::DataType abi_return_type;
// NOTE(review): "varags" is presumably a misspelling of "varargs"; renaming
// it would touch every user of the field, so it is left as is.
u8 varags:1;
};
Symbol symbol; Symbol symbol;
Node* root_node; Node* root_node;
Node** parameters; Node** parameters;
FunctionPrototype prototype; Function::Prototype prototype;
u32 node_count; u32 node_count;
u16 parameter_count; u16 parameter_count;
}; };
struct ProjectionData struct ConstantIntData
{ {
NodeDataType type; u64 value;
u16 index; Node* input;
u32 gvn;
u8 bit_count;
}; };
struct Output [[nodiscard]] fn Node* add_constant_integer(Arena* arena, ConstantIntData data);
{
Node* node;
u16 slot;
};
// This is a node in the "sea of nodes" sense: // This is a node in the "sea of nodes" sense:
// https://en.wikipedia.org/wiki/Sea_of_nodes // https://en.wikipedia.org/wiki/Sea_of_nodes
@ -1277,18 +1309,27 @@ struct Node
PROJECTION, PROJECTION,
RETURN, RETURN,
CONSTANT_INT, CONSTANT_INT,
INT_ADD,
INT_SUB,
}; };
static_assert(sizeof(NodeDataType) <= 2); using Type = NodeType;
// One entry of a node's `outputs` array.
// NOTE(review): presumably `node` is the node that uses this node as an input
// and `slot` is the index in that user's `inputs` array — confirm against
// add_output/remove_output.
struct Output
{
Node* node;
u16 slot;
};
Node** inputs; Node** inputs;
Output* outputs; Output* outputs;
u32 gvn; u32 gvn;
Type type;
u16 input_count; u16 input_count;
u16 input_capacity; u16 input_capacity;
u16 output_count; u16 output_count;
u16 output_capacity; u16 output_capacity;
NodeDataType data_type;
Id id; Id id;
union union
@ -1297,7 +1338,6 @@ struct Node
{ {
u32 index; u32 index;
} projection; } projection;
u64 constant_int;
}; };
forceinline Slice<Node*> get_inputs() forceinline Slice<Node*> get_inputs()
@ -1318,7 +1358,7 @@ struct Node
struct NodeData struct NodeData
{ {
NodeDataType type; Type type;
u16 input_count; u16 input_count;
Id id; Id id;
}; };
@ -1339,11 +1379,11 @@ struct Node
.inputs = arena->allocate_many<Node*>(data.input_capacity), .inputs = arena->allocate_many<Node*>(data.input_capacity),
.outputs = arena->allocate_many<Output>(output_capacity), .outputs = arena->allocate_many<Output>(output_capacity),
.gvn = data.gvn, .gvn = data.gvn,
.type = data.s.type,
.input_count = data.s.input_count, .input_count = data.s.input_count,
.input_capacity = data.input_capacity, .input_capacity = data.input_capacity,
.output_count = output_count, .output_count = output_count,
.output_capacity = output_capacity, .output_capacity = output_capacity,
.data_type = data.s.type,
.id = data.s.id, .id = data.s.id,
}; };
@ -1370,41 +1410,53 @@ struct Node
{ {
return add_from_function_dynamic(arena, function, data, data.input_count); return add_from_function_dynamic(arena, function, data, data.input_count);
} }
// Arguments for Node::project().
struct ProjectionData
{
// Type assigned to the newly created projection node.
Node::Type type;
// Stored in the projection node's `projection.index`.
u16 index;
};
[[nodiscard]] Node* project(Arena* arena, Function* function, ProjectionData data) [[nodiscard]] Node* project(Arena* arena, Function* function, ProjectionData data)
{ {
assert(data_type.id == NodeDataType::Id::TUPLE); assert(type.id == Type::Id::TUPLE);
Node* projection = Node::add_from_function(arena, function, { Node* projection = Node::add_from_function(arena, function, {
.input_count = 1, .input_count = 1,
}); });
assert(projection != this); assert(projection != this);
projection->id = Node::Id::PROJECTION; projection->id = Node::Id::PROJECTION;
projection->data_type = data.type; projection->type = data.type;
// projection->reallocate_edges(unit, 4); // projection->reallocate_edges(unit, 4);
projection->input_count = 1; projection->input_count = 1;
projection->set_input(this, 0); projection->set_input(arena, this, 0);
projection->projection.index = data.index; projection->projection.index = data.index;
return projection; return projection;
} }
void set_input(Node* input, u16 slot) void set_input(Arena* arena, Node* input, u16 slot)
{ {
assert(slot < input_count); assert(slot < input_count);
remove_output(slot); remove_output(slot);
inputs[slot] = input; inputs[slot] = input;
if (input) if (input)
{ {
add_output(input, slot); add_output(arena, input, slot);
} }
} }
void add_output(Node* input, u16 slot) void add_output(Arena* arena, Node* input, u16 slot)
{ {
if (input->output_count >= input->output_capacity) if (input->output_count >= input->output_capacity)
{ {
trap(); auto new_capacity = max<u32>(input->output_count, (u32)input->output_capacity * 2);
assert(new_capacity <= 0xffff);
auto* new_array = arena->allocate_many<Output>(new_capacity);
memcpy(new_array, input->outputs, sizeof(Output) * input->output_count);
memset(new_array + input->output_count, 0, sizeof(Output) * (new_capacity - input->output_count));
input->outputs = new_array;
input->output_capacity = new_capacity;
} }
auto index = input->output_count; auto index = input->output_count;
@ -1439,6 +1491,9 @@ struct Node
case Id::PROJECTION: case Id::PROJECTION:
case Id::CONSTANT_INT: case Id::CONSTANT_INT:
break; break;
case Id::INT_ADD:
case Id::INT_SUB:
trap();
} }
return is_good_id | is_projection() | cfg_is_control_projection(); return is_good_id | is_projection() | cfg_is_control_projection();
@ -1457,16 +1512,16 @@ struct Node
u8 cfg_is_control_projection() u8 cfg_is_control_projection()
{ {
return is_projection() & (data_type.id == NodeDataType::Id::CONTROL); return is_projection() & (type.id == Node::Type::Id::CONTROL);
} }
u8 is_cfg_control() u8 is_cfg_control()
{ {
switch (data_type.id) switch (type.id)
{ {
case NodeDataType::Id::CONTROL: case Node::Type::Id::CONTROL:
return 1; return 1;
case NodeDataType::Id::TUPLE: case Node::Type::Id::TUPLE:
for (Output& output : get_outputs()) for (Output& output : get_outputs())
{ {
if (output.node->cfg_is_control_projection()) if (output.node->cfg_is_control_projection())
@ -1478,8 +1533,164 @@ struct Node
return 0; return 0;
} }
} }
// Hook for rewriting this node into a better equivalent node.
// No rewrite exists yet for any node kind, so a null pointer is returned and
// the caller (peephole) keeps the current node.
Node* idealize()
{
    Node* replacement = nullptr;

    // Kept as an exhaustive switch so that each node kind has an explicit
    // place for its future rewrite rules.
    switch (id)
    {
        case Id::ROOT:
        case Id::PROJECTION:
        case Id::RETURN:
        case Id::CONSTANT_INT:
        case Id::INT_ADD:
        case Id::INT_SUB:
            break;
    }

    return replacement;
}
// Non-zero when no output edge is recorded on this node.
u8 is_unused()
{
    return !output_count;
}
// Non-zero when the node is in the state kill() leaves behind:
// no outputs, no inputs and an INVALID type.
u8 is_dead()
{
    u8 has_no_users = is_unused();
    u8 has_no_inputs = input_count == 0;
    u8 has_no_type = type.id == Node::Type::Id::INVALID;

    return has_no_users & has_no_inputs & has_no_type;
}
// Detaches this node from the graph. The node must have no outputs.
// Every input slot is cleared through set_input, then the input count and
// the type are reset so that is_dead() holds afterwards.
void kill(Arena* arena)
{
    assert(is_unused());

    // Slots are cleared in ascending order, one set_input call per slot.
    for (u16 slot = 0; slot < input_count; ++slot)
    {
        set_input(arena, nullptr, slot);
    }

    input_count = 0;
    type = Node::Type{};

    assert(is_dead());
}
static auto constexpr enable_peephole = 1;

// Recomputes and stores this node's type, then tries to simplify the node.
//  - With peepholes disabled, the node itself is returned.
//  - If the node is not a constant node but its computed type is constant,
//    the node is killed and replaced by a new CONSTANT_INT node carrying
//    that type, attached to the function's root node; the replacement is
//    itself run through peephole.
//  - Otherwise the result of idealize() is returned, or this node when
//    idealize() has nothing better to offer.
Node* peephole(Arena* arena, Function* function)
{
    Node::Type computed_type = compute();
    this->type = computed_type;

    if (!enable_peephole)
    {
        return this;
    }

    u8 should_fold = (!is_constant()) & computed_type.is_constant();
    if (!should_fold)
    {
        Node* idealized = idealize();
        if (idealized)
        {
            return idealized;
        }
        return this;
    }

    // kill() resets this->type, which is why the computed type was copied
    // into a local first.
    this->kill(arena);

    auto constant_gvn = function->node_count;
    function->node_count += 1;

    auto* replacement = Node::add(arena, {
        .s =
        {
            .type = computed_type,
            .input_count = 1,
            .id = Node::Id::CONSTANT_INT,
        },
        .gvn = constant_gvn,
        .input_capacity = 1,
    });
    replacement->set_input(arena, function->root_node, 0);

    return replacement->peephole(arena, function);
}
// Non-zero only for CONSTANT_INT nodes; every other node kind yields 0.
u8 is_constant()
{
    return id == Id::CONSTANT_INT;
}
// Computes the type of this node from its inputs.
//  - CONSTANT_INT: returns the type the node already carries.
//  - INT_ADD / INT_SUB: when both operands (inputs 1 and 2) have constant
//    INTEGER types, returns a constant INTEGER type holding the folded value
//    and the bit count of the left operand. Anything else traps.
//  - Every other node kind traps.
// NOTE(review): the folded value is a plain u64 add/sub and is not truncated
// to `bit_count`, and the operands' bit counts are not compared — confirm
// whether consumers expect a masked value.
Node::Type compute()
{
    switch (id)
    {
        case Node::Id::CONSTANT_INT:
            return type;
        case Node::Id::INT_ADD:
        case Node::Id::INT_SUB:
            break;
        default:
            trap();
    }

    // Only INT_ADD and INT_SUB reach this point.
    auto lhs = inputs[1]->type;
    auto rhs = inputs[2]->type;

    u8 both_integers = (lhs.id == Node::Type::Id::INTEGER) & (rhs.id == Node::Type::Id::INTEGER);
    if (!both_integers)
    {
        trap();
    }

    u8 both_constant = lhs.is_constant() & rhs.is_constant();
    if (!both_constant)
    {
        trap();
    }

    u64 folded = (id == Node::Id::INT_ADD)
        ? lhs.integer.constant + rhs.integer.constant
        : lhs.integer.constant - rhs.integer.constant;

    return Node::Type{
        .id = Node::Type::Id::INTEGER,
        .integer = {
            .constant = folded,
            .bit_count = lhs.integer.bit_count,
            .is_constant = 1,
        },
    };
}
}; };
// Creates a CONSTANT_INT node whose type is a constant INTEGER with the given
// value and bit count, assigns it `data.gvn`, and wires `data.input` into its
// single input slot. Returns the new node.
[[nodiscard]] fn Node* add_constant_integer(Arena* arena, ConstantIntData data)
{
    Node::Type integer_type = {
        .id = Node::Type::Id::INTEGER,
        .integer =
        {
            .constant = data.value,
            .bit_count = data.bit_count,
            .is_constant = 1,
        },
    };

    auto* node = Node::add(arena, {
        .s = {
            .type = integer_type,
            .input_count = 1,
            .id = Node::Id::CONSTANT_INT,
        },
        .gvn = data.gvn,
        .input_capacity = 1,
    });
    node->set_input(arena, data.input, 0);

    return node;
}
struct WorkList struct WorkList
{ {
using BitsetBackingType = u32; using BitsetBackingType = u32;
@ -2254,34 +2465,25 @@ fn u64 parse_hex(String string)
return value; return value;
} }
struct ConstantIntData fn u64 parse_decimal(String string)
{ {
u64 value; u64 value = 0;
Node* input; for (u8 ch : string)
u32 gvn; {
u8 bit_count; assert(((ch >= '0') & (ch <= '9')));
}; value = (value * 10) + (ch - '0');
}
[[nodiscard]] fn Node* add_constant_integer(Arena* arena, ConstantIntData data) return value;
{
auto* constant_int = Node::add(arena, {
.s = {
.type = { .id = NodeDataType::Id::INTEGER, .bit_count = data.bit_count, },
.input_count = 1,
.id = Node::Id::CONSTANT_INT,
},
.gvn = data.gvn,
.input_capacity = 1,
});
constant_int->constant_int = data.value;
constant_int->set_input(data.input, 0);
return constant_int;
} }
[[nodiscard]] fn Node* parse_constant_integer(Parser* parser, Arena* arena, String src, SemaType* type, u32 gvn, Node* input) [[nodiscard]] fn Node* parse_constant_integer(Parser* parser, Arena* arena, String src, SemaType* type, u32 gvn, Node* input)
{ {
u64 value = 0; u64 value = 0;
auto starting_ch = src[parser->i]; auto starting_index = parser->i;
auto starting_ch = src[starting_index];
if (starting_ch == '0') if (starting_ch == '0')
{ {
@ -2336,7 +2538,13 @@ struct ConstantIntData
} }
else else
{ {
trap(); while (is_decimal_digit(src[parser->i]))
{
parser->i += 1;
}
auto slice = src.slice(starting_index, parser->i);
value = parse_decimal(slice);
} }
Node* result = add_constant_integer(arena, { Node* result = add_constant_integer(arena, {
@ -2345,6 +2553,7 @@ struct ConstantIntData
.gvn = gvn, .gvn = gvn,
.bit_count = type->get_bit_count(), .bit_count = type->get_bit_count(),
}); });
return result; return result;
} }
@ -2439,6 +2648,10 @@ struct ConstantIntData
enum class CurrentOperation enum class CurrentOperation
{ {
NONE, NONE,
ADD,
ADD_ASSIGN,
SUB,
SUB_ASSIGN,
}; };
u64 iterations = 0; u64 iterations = 0;
@ -2472,8 +2685,42 @@ struct ConstantIntData
case CurrentOperation::NONE: case CurrentOperation::NONE:
previous_node = current_node; previous_node = current_node;
break; break;
case CurrentOperation::ADD:
case CurrentOperation::SUB:
{
Node::Id id;
switch (current_operation)
{
case CurrentOperation::NONE:
trap();
case CurrentOperation::ADD:
id = Node::Id::INT_ADD;
break;
case CurrentOperation::SUB:
id = Node::Id::INT_SUB;
break;
case CurrentOperation::ADD_ASSIGN:
case CurrentOperation::SUB_ASSIGN:
trap();
}
auto* binary = Node::add_from_function(arena, analyzer->function, {
.type = current_node->type,
.input_count = 3,
.id = id,
});
binary->set_input(arena, 0, 0);
binary->set_input(arena, previous_node, 1);
binary->set_input(arena, current_node, 2);
previous_node = binary;
} break;
default:
trap();
} }
previous_node = previous_node->peephole(arena, analyzer->function);
auto original_index = parser->i; auto original_index = parser->i;
u8 original = src[original_index]; u8 original = src[original_index];
@ -2484,10 +2731,40 @@ struct ConstantIntData
case parenthesis_close: case parenthesis_close:
case bracket_close: case bracket_close:
return previous_node; return previous_node;
case '+':
current_operation = CurrentOperation::ADD;
parser->i += 1;
switch (src[parser->i])
{
case '=':
current_operation = CurrentOperation::ADD_ASSIGN;
parser->i += 1;
break;
default:
break;
}
break;
case '-':
current_operation = CurrentOperation::SUB;
parser->i += 1;
switch (src[parser->i])
{
case '=':
current_operation = CurrentOperation::SUB_ASSIGN;
parser->i += 1;
break;
default:
break;
}
break;
default: default:
trap(); trap();
} }
skip_space(parser, src);
iterations += 1; iterations += 1;
} }
} }
@ -2520,12 +2797,12 @@ fn void analyze_local_block(Analyzer* analyzer, Parser* parser, Unit* unit, Aren
Function* function = analyzer->function; Function* function = analyzer->function;
Node* ret_node = Node::add_from_function(arena, function, { Node* ret_node = Node::add_from_function(arena, function, {
.type = { .id = NodeDataType::Id::CONTROL }, .type = { .id = Node::Type::Id::CONTROL },
.input_count = 2, .input_count = 2,
.id = Node::Id::RETURN, .id = Node::Id::RETURN,
}); });
ret_node->set_input(function->root_node, 0); ret_node->set_input(arena, function->root_node, 0);
ret_node->set_input(return_value, 1); ret_node->set_input(arena, return_value, 1);
} }
else else
{ {
@ -3022,7 +3299,7 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, String src)
}; };
function->root_node = Node::add_from_function_dynamic(thread->arena, function, { function->root_node = Node::add_from_function_dynamic(thread->arena, function, {
.type = { .id = NodeDataType::Id::TUPLE }, .type = { .id = Node::Type::Id::TUPLE },
.input_count = 2, .input_count = 2,
.id = Node::Id::ROOT, .id = Node::Id::ROOT,
}, 4); }, 4);
@ -3030,7 +3307,7 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, String src)
// TODO: revisit // TODO: revisit
// auto* control_node = root_node->project(unit, function, { // auto* control_node = root_node->project(unit, function, {
// .type = { .id = NodeDataType::Id::CONTROL }, // .type = { .id = Node::Type::Id::CONTROL },
// }); // });
// auto* memory_node = root_node->project(unit, function, {}); // auto* memory_node = root_node->project(unit, function, {});
// auto* pointer_node = root_node->project(unit, function, {}); // auto* pointer_node = root_node->project(unit, function, {});
@ -3048,7 +3325,7 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, String src)
// TODO: revisit // TODO: revisit
// Node* ret_node = Node::add_from_function(unit, function); // Node* ret_node = Node::add_from_function(unit, function);
// ret_node->id = Node::Id::RETURN; // ret_node->id = Node::Id::RETURN;
// ret_node->data_type = { .id = NodeDataType::Id::CONTROL }; // ret_node->data_type = { .id = Node::Type::Id::CONTROL };
// ret_node->reallocate_edges(unit, 4); // ret_node->reallocate_edges(unit, 4);
// ret_node->input_count = 2; // ret_node->input_count = 2;
// ret_node->set_input(unit, function, root_node, 0); // ret_node->set_input(unit, function, root_node, 0);
@ -3186,7 +3463,7 @@ global Instance instance;
// continue; // continue;
// } // }
// //
// if (node->data_type.id == NodeDataType::Id::MEMORY) // if (node->data_type.id == Node::Type::Id::MEMORY)
// { // {
// trap(); // trap();
// } // }
@ -3219,6 +3496,7 @@ global Instance instance;
String test_file_paths[] = { String test_file_paths[] = {
strlit("tests/first/main.nat"), strlit("tests/first/main.nat"),
strlit("tests/constant_prop/main.nat"),
}; };
extern "C" void entry_point() extern "C" void entry_point()

View File

@ -0,0 +1,4 @@
fn[cc(.c)] main [export] () s32
{
return 2 + 4 - 1 - 5;
}