Merge pull request #3 from birth-software/constant-prop

Implement constant propagation
This commit is contained in:
David 2024-06-30 10:04:49 +02:00 committed by GitHub
commit 2d06e4a632
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 399 additions and 117 deletions

View File

@ -1146,68 +1146,6 @@ enum class Side : u8
right,
};
// Compact description of the value a node produces: a kind tag plus a bit width.
struct NodeDataType
{
// Kind of value the node carries.
enum class Id : u8
{
VOID,
INTEGER,
TUPLE,
CONTROL,
MEMORY,
POINTER,
};
Id id;
// 5-bit field: only the values 0..31 can be stored here.
u8 bit_count:5;
};
// Data describing how a value is passed according to the ABI.
// NOTE(review): which member is active is presumably selected by AbiInfo::kind - confirm against the users of AbiInfo.
union AbiInfoPayload
{
NodeDataType direct;
NodeDataType direct_pair[2];
NodeDataType direct_coerce;
// Indirect passing: the type together with an alignment value.
struct
{
NodeDataType type;
u32 alignment;
} indirect;
};
typedef union AbiInfoPayload AbiInfoPayload;
// Five independent one-bit flags stored alongside an AbiInfo.
struct AbiInfoAttributes
{
u8 by_reg: 1;
u8 zero_extend: 1;
u8 sign_extend: 1;
u8 realign: 1;
u8 by_value: 1;
};
typedef struct AbiInfoAttributes AbiInfoAttributes;
// ABI classification of a single argument or return value: payload, flags and kind.
struct AbiInfo
{
AbiInfoPayload payload;
// NOTE(review): meaning of the two indices is not visible here - presumably positions in the lowered argument list; confirm.
u16 indices[2];
AbiInfoAttributes attributes;
AbiInfoKind kind;
};
// Signature of a function: the source-level argument/return types and their ABI classification.
struct FunctionPrototype
{
AbiInfo* argument_type_abis; // The count for this array is "original_argument_count", not "abi_argument_count"
SemaType** original_argument_types;
// TODO: are these needed?
// Node::DataType* abi_argument_types;
// u32 abi_argument_count;
SemaType* original_return_type;
AbiInfo return_type_abi;
u32 original_argument_count;
// TODO: is this needed?
// Node::DataType abi_return_type;
// NOTE(review): "varags" is presumably a misspelling of "varargs"; renaming it would touch every user of the field.
u8 varags:1;
};
struct Function;
global auto constexpr void_type_index = 0;
global auto constexpr noreturn_type_index = 1;
@ -1218,6 +1156,7 @@ global auto constexpr integer_type_offset = 5;
global auto constexpr integer_type_count = 64 * 2;
global auto constexpr builtin_type_count = integer_type_count + integer_type_offset + 1;
struct Function;
struct Thread
{
Arena* arena;
@ -1242,30 +1181,123 @@ struct Unit
}
};
struct Node;
struct FunctionPrototype;
// Type attached to every node. For INTEGER types it records the bit width
// and, when `integer.is_constant` is set, the known constant value.
struct NodeType
{
    enum class Id: u8
    {
        INVALID,
        BOTTOM,
        TOP,
        VOID,
        INTEGER,
        TUPLE,
        CONTROL,
        MEMORY,
        POINTER,
    };

    Id id;
    union
    {
        struct
        {
            u64 constant;    // Only meaningful when is_constant is set
            u8 bit_count;
            u8 is_constant;  // Non-zero when `constant` holds the known value
        } integer;
    };

    // Returns 1 when both types describe the same type, 0 otherwise.
    // Only INTEGER has a payload comparison implemented; every other kind
    // with a matching id still traps, exactly as before.
    u8 equal(NodeType other)
    {
        if (id != other.id)
        {
            return 0;
        }

        switch (id)
        {
            case Id::INTEGER:
            {
                u8 this_is_constant = integer.is_constant != 0;
                u8 other_is_constant = other.integer.is_constant != 0;
                u8 same_width = integer.bit_count == other.integer.bit_count;
                u8 same_constness = this_is_constant == other_is_constant;
                // The stored value only matters when the types are constants.
                u8 same_value = (!this_is_constant) | (integer.constant == other.integer.constant);
                return same_width & same_constness & same_value;
            }
            default: trap();
        }
    }

    // Returns 1 when the type denotes a single known value.
    // An INTEGER is constant only if its `is_constant` flag is set; reporting
    // every integer as constant would make callers fold values that are not known.
    // The result is normalized to 0/1 because callers combine it with bitwise `&`.
    u8 is_constant()
    {
        switch (id)
        {
            case Id::VOID:
                return 1;
            case Id::INTEGER:
                return integer.is_constant != 0;
            default:
                trap();
        }
    }
};
// Data describing how a value is passed according to the ABI.
// NOTE(review): which member is active is presumably selected by AbiInfo::kind - confirm against the users of AbiInfo.
union AbiInfoPayload
{
NodeType direct;
NodeType direct_pair[2];
NodeType direct_coerce;
// Indirect passing: the type together with an alignment value.
struct
{
NodeType type;
u32 alignment;
} indirect;
};
typedef union AbiInfoPayload AbiInfoPayload;
// Five independent one-bit flags stored alongside an AbiInfo.
struct AbiInfoAttributes
{
u8 by_reg: 1;
u8 zero_extend: 1;
u8 sign_extend: 1;
u8 realign: 1;
u8 by_value: 1;
};
typedef struct AbiInfoAttributes AbiInfoAttributes;
// ABI classification of a single argument or return value: payload, flags and kind.
struct AbiInfo
{
AbiInfoPayload payload;
// NOTE(review): meaning of the two indices is not visible here - presumably positions in the lowered argument list; confirm.
u16 indices[2];
AbiInfoAttributes attributes;
AbiInfoKind kind;
};
// A function being compiled: its symbol, its prototype and the bookkeeping for its node graph.
struct Function
{
// Signature of the function: the source-level argument/return types and their ABI classification.
struct Prototype
{
AbiInfo* argument_type_abis; // The count for this array is "original_argument_count", not "abi_argument_count"
SemaType** original_argument_types;
// TODO: are these needed?
// Node::DataType* abi_argument_types;
// u32 abi_argument_count;
SemaType* original_return_type;
AbiInfo return_type_abi;
u32 original_argument_count;
// TODO: is this needed?
// Node::DataType abi_return_type;
// NOTE(review): "varags" is presumably a misspelling of "varargs"; renaming it would touch every user of the field.
u8 varags:1;
};
Symbol symbol;
// Root of the node graph; constant nodes created by Node::peephole use it as input 0.
Node* root_node;
Node** parameters;
FunctionPrototype prototype;
Function::Prototype prototype;
// Number of nodes created so far; Node::peephole reads it as the gvn of a new node and then increments it.
u32 node_count;
u16 parameter_count;
};
struct ProjectionData
// Arguments consumed by add_constant_integer.
struct ConstantIntData
{
NodeDataType type;
u16 index;
u64 value; // Constant stored in the new node's integer type
Node* input; // Node wired as input 0 of the new constant
u32 gvn; // gvn given to the new node
u8 bit_count; // Bit width recorded in the new node's integer type
};
// One entry of a node's output list.
// NOTE(review): presumably `node` is the user and `slot` is the input index of that user which refers back to the owner - confirm in add_output.
struct Output
{
Node* node;
u16 slot;
};
[[nodiscard]] fn Node* add_constant_integer(Arena* arena, ConstantIntData data);
// This is a node in the "sea of nodes" sense:
// https://en.wikipedia.org/wiki/Sea_of_nodes
@ -1277,18 +1309,27 @@ struct Node
PROJECTION,
RETURN,
CONSTANT_INT,
INT_ADD,
INT_SUB,
};
static_assert(sizeof(NodeDataType) <= 2);
using Type = NodeType;
// One entry of a node's output list.
// NOTE(review): presumably `node` is the user and `slot` is the input index of that user which refers back to the owner - confirm in add_output.
struct Output
{
Node* node;
u16 slot;
};
Node** inputs;
Output* outputs;
u32 gvn;
Type type;
u16 input_count;
u16 input_capacity;
u16 output_count;
u16 output_capacity;
NodeDataType data_type;
Id id;
union
@ -1297,7 +1338,6 @@ struct Node
{
u32 index;
} projection;
u64 constant_int;
};
forceinline Slice<Node*> get_inputs()
@ -1318,7 +1358,7 @@ struct Node
// Creation parameters of a node: its type, its number of inputs and its operation id.
struct NodeData
{
NodeDataType type;
Type type;
u16 input_count;
Id id;
};
@ -1339,11 +1379,11 @@ struct Node
.inputs = arena->allocate_many<Node*>(data.input_capacity),
.outputs = arena->allocate_many<Output>(output_capacity),
.gvn = data.gvn,
.type = data.s.type,
.input_count = data.s.input_count,
.input_capacity = data.input_capacity,
.output_count = output_count,
.output_capacity = output_capacity,
.data_type = data.s.type,
.id = data.s.id,
};
@ -1370,41 +1410,53 @@ struct Node
{
return add_from_function_dynamic(arena, function, data, data.input_count);
}
// Arguments of Node::project: the type given to the projection node and the index it records.
struct ProjectionData
{
Node::Type type;
u16 index;
};
// Creates a PROJECTION node whose single input is this node.
// This node must have TUPLE type (asserted). The new node gets `data.type` as its
// type and stores `data.index` in `projection.index`. Returns the new node.
[[nodiscard]] Node* project(Arena* arena, Function* function, ProjectionData data)
{
assert(data_type.id == NodeDataType::Id::TUPLE);
assert(type.id == Type::Id::TUPLE);
Node* projection = Node::add_from_function(arena, function, {
.input_count = 1,
});
assert(projection != this);
projection->id = Node::Id::PROJECTION;
projection->data_type = data.type;
projection->type = data.type;
// projection->reallocate_edges(unit, 4);
projection->input_count = 1;
// Wire this node as input 0 of the projection.
projection->set_input(this, 0);
projection->set_input(arena, this, 0);
projection->projection.index = data.index;
return projection;
}
// Makes input `slot` of this node refer to `input`, which may be null.
// `slot` must be below `input_count` (asserted). A non-null `input` also gets
// an entry added to its output list; `arena` is forwarded to add_output.
void set_input(Node* input, u16 slot)
void set_input(Arena* arena, Node* input, u16 slot)
{
assert(slot < input_count);
// NOTE(review): presumably drops the output entry held by the previous input of this slot - confirm in remove_output.
remove_output(slot);
inputs[slot] = input;
if (input)
{
add_output(input, slot);
add_output(arena, input, slot);
}
}
void add_output(Node* input, u16 slot)
void add_output(Arena* arena, Node* input, u16 slot)
{
if (input->output_count >= input->output_capacity)
{
trap();
auto new_capacity = max<u32>(input->output_count, (u32)input->output_capacity * 2);
assert(new_capacity <= 0xffff);
auto* new_array = arena->allocate_many<Output>(new_capacity);
memcpy(new_array, input->outputs, sizeof(Output) * input->output_count);
memset(new_array + input->output_count, 0, sizeof(Output) * (new_capacity - input->output_count));
input->outputs = new_array;
input->output_capacity = new_capacity;
}
auto index = input->output_count;
@ -1439,6 +1491,9 @@ struct Node
case Id::PROJECTION:
case Id::CONSTANT_INT:
break;
case Id::INT_ADD:
case Id::INT_SUB:
trap();
}
return is_good_id | is_projection() | cfg_is_control_projection();
@ -1457,16 +1512,16 @@ struct Node
// Returns non-zero when this node is a projection and its type is CONTROL.
u8 cfg_is_control_projection()
{
return is_projection() & (data_type.id == NodeDataType::Id::CONTROL);
return is_projection() & (type.id == Node::Type::Id::CONTROL);
}
u8 is_cfg_control()
{
switch (data_type.id)
switch (type.id)
{
case NodeDataType::Id::CONTROL:
case Node::Type::Id::CONTROL:
return 1;
case NodeDataType::Id::TUPLE:
case Node::Type::Id::TUPLE:
for (Output& output : get_outputs())
{
if (output.node->cfg_is_control_projection())
@ -1478,8 +1533,164 @@ struct Node
return 0;
}
}
// Node-local rewrite hook used by peephole().
// Returns a replacement node, or null when this node should be kept as is.
// No rewrite is implemented yet, so every id yields null.
Node* idealize()
{
    switch (id)
    {
        case Id::ROOT:
        case Id::PROJECTION:
        case Id::RETURN:
        case Id::CONSTANT_INT:
        case Id::INT_ADD:
        case Id::INT_SUB:
            return nullptr;
    }

    // Reached only when `id` holds a value that is not listed above. Returning
    // explicitly avoids flowing off the end of a non-void function, and the
    // switch stays default-free so the compiler can still flag a missing case.
    return nullptr;
}
// Non-zero when this node's output list is empty.
u8 is_unused()
{
    return !output_count;
}
// Non-zero when the node has no outputs, no inputs and an INVALID type,
// which is the state kill() leaves a node in.
u8 is_dead()
{
    u8 no_outputs = is_unused();
    u8 no_inputs = input_count == 0;
    u8 no_type = type.id == Node::Type::Id::INVALID;
    return no_outputs & no_inputs & no_type;
}
// Detaches an unused node from the graph: every input slot is cleared,
// the input count drops to zero and the type is reset.
// The node must have no outputs (asserted).
void kill(Arena* arena)
{
    assert(is_unused());

    u16 slot = 0;
    while (slot < input_count)
    {
        // Clearing the slot goes through set_input so edge bookkeeping stays consistent.
        set_input(arena, nullptr, slot);
        slot += 1;
    }

    input_count = 0;
    type = {};

    assert(is_dead());
}
static auto constexpr enable_peephole = 1;
// Recomputes this node's type and applies local simplifications.
// If the node is not a constant node but its computed type is constant, the
// node is killed and replaced by a fresh CONSTANT_INT node of that type.
// Otherwise the result of idealize() is returned, or this node when
// idealize() has nothing to offer.
Node* peephole(Arena* arena, Function* function)
{
    Node::Type computed = compute();
    this->type = computed;

    if (!enable_peephole)
    {
        return this;
    }

    // Both operands are evaluated on purpose (bitwise `&`, as elsewhere in this file).
    u8 foldable = (!is_constant()) & computed.is_constant();
    if (!foldable)
    {
        Node* replacement = idealize();
        return replacement ? replacement : this;
    }

    // kill() resets this->type, so the new node is built from the local copy.
    this->kill(arena);

    auto constant_gvn = function->node_count;
    function->node_count += 1;

    auto* folded = Node::add(arena, {
        .s =
        {
            .type = computed,
            .input_count = 1,
            .id = Node::Id::CONSTANT_INT,
        },
        .gvn = constant_gvn,
        .input_capacity = 1,
    });
    folded->set_input(arena, function->root_node, 0);

    return folded->peephole(arena, function);
}
// Non-zero only for CONSTANT_INT nodes.
u8 is_constant()
{
    return id == Id::CONSTANT_INT;
}
// Computes the type of this node from its inputs.
// CONSTANT_INT keeps its current type. INT_ADD / INT_SUB read inputs 1 and 2;
// when both are constant integers the result is a constant integer with the
// left operand's bit width. Every other situation traps.
Node::Type compute()
{
    if (id == Node::Id::CONSTANT_INT)
    {
        return type;
    }

    u8 is_add = id == Node::Id::INT_ADD;
    u8 is_sub = id == Node::Id::INT_SUB;
    if (!(is_add | is_sub))
    {
        trap();
    }

    Node::Type lhs = inputs[1]->type;
    Node::Type rhs = inputs[2]->type;

    u8 both_integers = (lhs.id == Node::Type::Id::INTEGER) & (rhs.id == Node::Type::Id::INTEGER);
    if (!both_integers)
    {
        trap();
    }

    // Only constant operands are supported so far.
    if (!(lhs.is_constant() & rhs.is_constant()))
    {
        trap();
    }

    u64 folded = is_add
        ? lhs.integer.constant + rhs.integer.constant
        : lhs.integer.constant - rhs.integer.constant;

    return Node::Type{
        .id = Node::Type::Id::INTEGER,
        .integer = {
            .constant = folded,
            .bit_count = lhs.integer.bit_count,
            .is_constant = 1,
        },
    };
}
};
// Creates a CONSTANT_INT node holding `data.value` with `data.bit_count` bits,
// numbered `data.gvn`, and wires `data.input` as its input 0.
[[nodiscard]] fn Node* add_constant_integer(Arena* arena, ConstantIntData data)
{
    Node::Type integer_type = {
        .id = Node::Type::Id::INTEGER,
        .integer = {
            .constant = data.value,
            .bit_count = data.bit_count,
            .is_constant = 1,
        },
    };

    auto* node = Node::add(arena, {
        .s = {
            .type = integer_type,
            .input_count = 1,
            .id = Node::Id::CONSTANT_INT,
        },
        .gvn = data.gvn,
        .input_capacity = 1,
    });

    node->set_input(arena, data.input, 0);
    return node;
}
struct WorkList
{
using BitsetBackingType = u32;
@ -2254,34 +2465,25 @@ fn u64 parse_hex(String string)
return value;
}
struct ConstantIntData
// Parses a string made only of ASCII decimal digits into a u64.
// A non-digit character is caught by the assert only; a value that does not
// fit in 64 bits wraps around silently.
fn u64 parse_decimal(String string)
{
u64 value;
Node* input;
u32 gvn;
u8 bit_count;
};
u64 value = 0;
for (u8 ch : string)
{
assert(((ch >= '0') & (ch <= '9')));
// Shift the accumulated value one decimal place and add the new digit.
value = (value * 10) + (ch - '0');
}
// Creates a CONSTANT_INT node of `data.bit_count` bits, stores `data.value` in it and wires `data.input` as input 0.
[[nodiscard]] fn Node* add_constant_integer(Arena* arena, ConstantIntData data)
{
auto* constant_int = Node::add(arena, {
.s = {
.type = { .id = NodeDataType::Id::INTEGER, .bit_count = data.bit_count, },
.input_count = 1,
.id = Node::Id::CONSTANT_INT,
},
.gvn = data.gvn,
.input_capacity = 1,
});
constant_int->constant_int = data.value;
constant_int->set_input(data.input, 0);
return constant_int;
return value;
}
[[nodiscard]] fn Node* parse_constant_integer(Parser* parser, Arena* arena, String src, SemaType* type, u32 gvn, Node* input)
{
u64 value = 0;
auto starting_ch = src[parser->i];
auto starting_index = parser->i;
auto starting_ch = src[starting_index];
if (starting_ch == '0')
{
@ -2336,7 +2538,13 @@ struct ConstantIntData
}
else
{
trap();
while (is_decimal_digit(src[parser->i]))
{
parser->i += 1;
}
auto slice = src.slice(starting_index, parser->i);
value = parse_decimal(slice);
}
Node* result = add_constant_integer(arena, {
@ -2345,6 +2553,7 @@ struct ConstantIntData
.gvn = gvn,
.bit_count = type->get_bit_count(),
});
return result;
}
@ -2439,6 +2648,10 @@ struct ConstantIntData
// Binary operator pending between the operand parsed so far and the next one.
// NONE means no operator has been seen yet.
enum class CurrentOperation
{
NONE,
ADD,
ADD_ASSIGN,
SUB,
SUB_ASSIGN,
};
u64 iterations = 0;
@ -2472,8 +2685,42 @@ struct ConstantIntData
case CurrentOperation::NONE:
previous_node = current_node;
break;
case CurrentOperation::ADD:
case CurrentOperation::SUB:
{
Node::Id id;
switch (current_operation)
{
case CurrentOperation::NONE:
trap();
case CurrentOperation::ADD:
id = Node::Id::INT_ADD;
break;
case CurrentOperation::SUB:
id = Node::Id::INT_SUB;
break;
case CurrentOperation::ADD_ASSIGN:
case CurrentOperation::SUB_ASSIGN:
trap();
}
auto* binary = Node::add_from_function(arena, analyzer->function, {
.type = current_node->type,
.input_count = 3,
.id = id,
});
binary->set_input(arena, 0, 0);
binary->set_input(arena, previous_node, 1);
binary->set_input(arena, current_node, 2);
previous_node = binary;
} break;
default:
trap();
}
previous_node = previous_node->peephole(arena, analyzer->function);
auto original_index = parser->i;
u8 original = src[original_index];
@ -2484,10 +2731,40 @@ struct ConstantIntData
case parenthesis_close:
case bracket_close:
return previous_node;
case '+':
current_operation = CurrentOperation::ADD;
parser->i += 1;
switch (src[parser->i])
{
case '=':
current_operation = CurrentOperation::ADD_ASSIGN;
parser->i += 1;
break;
default:
break;
}
break;
case '-':
current_operation = CurrentOperation::SUB;
parser->i += 1;
switch (src[parser->i])
{
case '=':
current_operation = CurrentOperation::SUB_ASSIGN;
parser->i += 1;
break;
default:
break;
}
break;
default:
trap();
}
skip_space(parser, src);
iterations += 1;
}
}
@ -2520,12 +2797,12 @@ fn void analyze_local_block(Analyzer* analyzer, Parser* parser, Unit* unit, Aren
Function* function = analyzer->function;
Node* ret_node = Node::add_from_function(arena, function, {
.type = { .id = NodeDataType::Id::CONTROL },
.type = { .id = Node::Type::Id::CONTROL },
.input_count = 2,
.id = Node::Id::RETURN,
});
ret_node->set_input(function->root_node, 0);
ret_node->set_input(return_value, 1);
ret_node->set_input(arena, function->root_node, 0);
ret_node->set_input(arena, return_value, 1);
}
else
{
@ -3022,7 +3299,7 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, String src)
};
function->root_node = Node::add_from_function_dynamic(thread->arena, function, {
.type = { .id = NodeDataType::Id::TUPLE },
.type = { .id = Node::Type::Id::TUPLE },
.input_count = 2,
.id = Node::Id::ROOT,
}, 4);
@ -3030,7 +3307,7 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, String src)
// TODO: revisit
// auto* control_node = root_node->project(unit, function, {
// .type = { .id = NodeDataType::Id::CONTROL },
// .type = { .id = Node::Type::Id::CONTROL },
// });
// auto* memory_node = root_node->project(unit, function, {});
// auto* pointer_node = root_node->project(unit, function, {});
@ -3048,7 +3325,7 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, String src)
// TODO: revisit
// Node* ret_node = Node::add_from_function(unit, function);
// ret_node->id = Node::Id::RETURN;
// ret_node->data_type = { .id = NodeDataType::Id::CONTROL };
// ret_node->data_type = { .id = Node::Type::Id::CONTROL };
// ret_node->reallocate_edges(unit, 4);
// ret_node->input_count = 2;
// ret_node->set_input(unit, function, root_node, 0);
@ -3186,7 +3463,7 @@ global Instance instance;
// continue;
// }
//
// if (node->data_type.id == NodeDataType::Id::MEMORY)
// if (node->data_type.id == Node::Type::Id::MEMORY)
// {
// trap();
// }
@ -3219,6 +3496,7 @@ global Instance instance;
String test_file_paths[] = {
strlit("tests/first/main.nat"),
strlit("tests/constant_prop/main.nat"),
};
extern "C" void entry_point()

View File

@ -0,0 +1,4 @@
fn[cc(.c)] main [export] () s32
{
return 2 + 4 - 1 - 5;
}