Implement comparisons

This commit is contained in:
David Gonzalez Martin 2024-07-06 17:41:42 +02:00
parent f6ddf827f2
commit ff258e3df3
2 changed files with 248 additions and 83 deletions

View File

@ -997,7 +997,7 @@ struct Arena
u64 commited;
u64 commit_position;
u64 granularity;
u8 reserved[4 * 8];
u8 reserved[4 * 8] = {};
global auto constexpr minimum_granularity = KB(4);
global auto constexpr middle_granularity = MB(2);
@ -1472,15 +1472,13 @@ struct NodeType
struct
{
u64 constant;
u8 bit_count;
u8 is_constant;
} integer;
struct
{
Slice<NodeType> types;
} multi;
};
} payload = {};
u8 is_simple()
{
@ -1506,6 +1504,8 @@ struct NodeType
switch (id)
{
case NodeType::Id::INTEGER:
return (payload.integer.is_constant == other.payload.integer.is_constant) & (payload.integer.constant == other.payload.integer.constant);
default:
trap();
}
@ -1518,7 +1518,7 @@ struct NodeType
case Id::VOID:
trap();
case Id::INTEGER:
return integer.is_constant;
return payload.integer.is_constant;
case Id::CONTROL:
case Id::MULTIVALUE:
case Id::BOTTOM:
@ -1572,7 +1572,7 @@ struct NodeType
}
assert(is_constant() & other.is_constant());
if (integer.constant == other.integer.constant)
if (payload.integer.constant == other.payload.integer.constant)
{
trap();
}
@ -1589,37 +1589,43 @@ struct NodeType
u8 is_bot()
{
assert(id == Id::INTEGER);
return !integer.is_constant & (integer.constant == 1);
return !payload.integer.is_constant & (payload.integer.constant == 1);
}
u8 is_top()
{
assert(id == Id::INTEGER);
return !integer.is_constant & (integer.constant == 0);
return !payload.integer.is_constant & (payload.integer.constant == 0);
}
};
may_be_unused global auto constexpr integer_top = NodeType{
.id = NodeType::Id::TOP,
.integer = {
.constant = 0,
.is_constant = 0,
.payload = {
.integer = {
.constant = 0,
.is_constant = 0,
},
},
};
may_be_unused global auto constexpr integer_bot = NodeType{
.id = NodeType::Id::TOP,
.integer = {
.constant = 1,
.is_constant = 0,
.payload = {
.integer = {
.constant = 1,
.is_constant = 0,
},
},
};
may_be_unused global auto constexpr integer_zero = NodeType{
.id = NodeType::Id::TOP,
.integer = {
.constant = 0,
.is_constant = 1,
.payload = {
.integer = {
.constant = 0,
.is_constant = 1,
},
},
};
@ -1630,7 +1636,7 @@ struct SemaType
SemaTypeId id : type_id_bit_count;
u32 resolved: 1;
u32 flags: type_flags_bit_count;
u32 reserved;
u32 reserved = 0;
String name;
u8 get_bit_count()
@ -1656,9 +1662,12 @@ struct SemaType
case SemaTypeId::INTEGER:
return NodeType{
.id = NodeType::Id::INTEGER,
.integer = {
.bit_count = get_bit_count(),
.is_constant = 0,
.payload = {
.integer = {
.constant = 0,
// .bit_count = get_bit_count(),
.is_constant = 0,
},
},
};
case SemaTypeId::ARRAY:
@ -1726,8 +1735,8 @@ struct Function;
struct Thread
{
Arena* arena;
PinnedArray<Function> functions;
u32 node_count;
PinnedArray<Function> functions = {};
u32 node_count = 0;
};
struct Unit
@ -1776,8 +1785,8 @@ typedef struct AbiInfoAttributes AbiInfoAttributes;
struct AbiInfo
{
AbiInfoPayload payload;
u16 indices[2];
AbiInfoAttributes attributes;
u16 indices[2] = {};
AbiInfoAttributes attributes = {};
AbiInfoKind kind;
};
@ -1811,7 +1820,6 @@ struct ConstantIntData
{
u64 value;
Node* input;
u32 gvn;
u8 bit_count;
};
@ -1827,8 +1835,15 @@ struct Node
PROJECTION,
RETURN,
CONSTANT_INT,
INT_ADD,
INT_SUB,
INTEGER_ADD,
INTEGER_SUB,
INTEGER_COMPARE_EQUAL,
INTEGER_COMPARE_NOT_EQUAL,
INTEGER_COMPARE_LESS,
INTEGER_COMPARE_LESS_EQUAL,
INTEGER_COMPARE_GREATER,
INTEGER_COMPARE_GREATER_EQUAL,
SCOPE,
SYMBOL_FUNCTION,
CALL,
@ -1858,9 +1873,9 @@ struct Node
Type args;
} root;
Symbol* symbol;
};
} payload;
u8 padding[40];
u8 padding[40] = {};
forceinline Slice<Node*> get_inputs()
{
@ -1896,6 +1911,7 @@ struct Node
.outputs = {},
.gvn = gvn,
.id = data.id,
.payload = {},
};
node->inputs.append(data.inputs);
@ -1970,14 +1986,21 @@ struct Node
case Id::PROJECTION:
case Id::CONSTANT_INT:
break;
case Id::INT_ADD:
case Id::INT_SUB:
case Id::INTEGER_ADD:
case Id::INTEGER_SUB:
trap();
case Id::SCOPE:
trap();
case Id::SYMBOL_FUNCTION:
case Id::CALL:
trap();
case Id::INTEGER_COMPARE_EQUAL:
case Id::INTEGER_COMPARE_NOT_EQUAL:
case Id::INTEGER_COMPARE_LESS:
case Id::INTEGER_COMPARE_LESS_EQUAL:
case Id::INTEGER_COMPARE_GREATER:
case Id::INTEGER_COMPARE_GREATER_EQUAL:
trap();
}
return is_good_id | is_projection() | cfg_is_control_projection();
@ -2022,7 +2045,7 @@ struct Node
{
switch (id)
{
case Id::INT_SUB:
case Id::INTEGER_SUB:
if (inputs[1] == inputs[2])
{
trap();
@ -2035,7 +2058,7 @@ struct Node
case Id::PROJECTION:
case Id::RETURN:
case Id::CONSTANT_INT:
case Id::INT_ADD:
case Id::INTEGER_ADD:
return 0;
case Id::SCOPE:
trap();
@ -2043,6 +2066,13 @@ struct Node
case Id::SYMBOL_FUNCTION:
case Id::CALL:
return 0;
case Id::INTEGER_COMPARE_EQUAL:
case Id::INTEGER_COMPARE_NOT_EQUAL:
case Id::INTEGER_COMPARE_LESS:
case Id::INTEGER_COMPARE_LESS_EQUAL:
case Id::INTEGER_COMPARE_GREATER:
case Id::INTEGER_COMPARE_GREATER_EQUAL:
trap();
}
}
@ -2140,9 +2170,15 @@ struct Node
switch (id)
{
case Node::Id::ROOT:
return root.args;
case Node::Id::INT_ADD:
case Node::Id::INT_SUB:
return payload.root.args;
case Node::Id::INTEGER_ADD:
case Node::Id::INTEGER_SUB:
case Node::Id::INTEGER_COMPARE_EQUAL:
case Node::Id::INTEGER_COMPARE_NOT_EQUAL:
case Node::Id::INTEGER_COMPARE_LESS:
case Node::Id::INTEGER_COMPARE_LESS_EQUAL:
case Node::Id::INTEGER_COMPARE_GREATER:
case Node::Id::INTEGER_COMPARE_GREATER_EQUAL:
{
auto left_type = inputs[1]->type;
auto right_type = inputs[2]->type;
@ -2161,20 +2197,40 @@ struct Node
case Id::SYMBOL_FUNCTION:
case Id::CALL:
trap();
case Id::INT_ADD:
result = left_type.integer.constant + right_type.integer.constant;
case Id::INTEGER_ADD:
result = left_type.payload.integer.constant + right_type.payload.integer.constant;
break;
case Id::INT_SUB:
result = left_type.integer.constant - right_type.integer.constant;
case Id::INTEGER_SUB:
result = left_type.payload.integer.constant - right_type.payload.integer.constant;
break;
case Id::INTEGER_COMPARE_EQUAL:
result = left_type.payload.integer.constant == right_type.payload.integer.constant;
break;
case Id::INTEGER_COMPARE_NOT_EQUAL:
result = left_type.payload.integer.constant != right_type.payload.integer.constant;
break;
case Id::INTEGER_COMPARE_LESS:
result = left_type.payload.integer.constant < right_type.payload.integer.constant;
break;
case Id::INTEGER_COMPARE_LESS_EQUAL:
result = left_type.payload.integer.constant <= right_type.payload.integer.constant;
break;
case Id::INTEGER_COMPARE_GREATER:
result = left_type.payload.integer.constant > right_type.payload.integer.constant;
break;
case Id::INTEGER_COMPARE_GREATER_EQUAL:
result = left_type.payload.integer.constant >= right_type.payload.integer.constant;
break;
}
return Node::Type{
.id = Node::Type::Id::INTEGER,
.integer = {
.constant = result,
.bit_count = left_type.integer.bit_count,
.is_constant = 1,
.payload = {
.integer = {
.constant = result,
// .bit_count = left_type.payload.integer.bit_count,
.is_constant = 1,
},
},
};
}
@ -2195,7 +2251,7 @@ struct Node
auto* control_node = inputs[0];
if (control_node->type.id == NodeType::Id::MULTIVALUE)
{
auto type = control_node->type.multi.types[this->projection.index];
auto type = control_node->type.payload.multi.types[this->payload.projection.index];
return type;
}
else
@ -2217,9 +2273,11 @@ struct Node
types.append_one(inputs[1]->type);
return Type{
.id = Node::Type::Id::MULTIVALUE,
.payload = {
.multi = {
.types = types.slice(),
},
},
};
}
default:
@ -2235,8 +2293,8 @@ struct Node
.inputs = { .pointer = &function->root_node, .length = 1 },
.id = Node::Id::PROJECTION,
});
projection->projection.index = index;
projection->projection.name = label;
projection->payload.projection.index = index;
projection->payload.projection.name = label;
return projection;
}
@ -2267,11 +2325,12 @@ static_assert(page_size % sizeof(Node) == 0);
.type =
{
.id = Node::Type::Id::INTEGER,
.integer =
{
.constant = data.value,
.bit_count = data.bit_count,
.is_constant = 1,
.payload = {
.integer = {
.constant = data.value,
// .bit_count = data.bit_count,
.is_constant = 1,
},
},
},
.inputs = { .pointer = &data.input, .length = 1 },
@ -2491,6 +2550,7 @@ fn void unit_initialize(Unit* unit)
// .node_arena = Arena::init(Arena::default_size, Arena::minimum_granularity, KB(64)),
// .type_arena = type_arena,
.builtin_types = builtin_types,
.generate_debug_information = 1,
};
builtin_types[void_type_index] = {
@ -2498,6 +2558,7 @@ fn void unit_initialize(Unit* unit)
.alignment = 1,
.id = SemaTypeId::VOID,
.resolved = 1,
.flags = 0,
.name = strlit("void"),
};
builtin_types[noreturn_type_index] = {
@ -2505,6 +2566,7 @@ fn void unit_initialize(Unit* unit)
.alignment = 1,
.id = SemaTypeId::NORETURN,
.resolved = 1,
.flags = 0,
.name = strlit("noreturn"),
};
builtin_types[opaque_pointer_type_index] = {
@ -2512,6 +2574,7 @@ fn void unit_initialize(Unit* unit)
.alignment = 8,
.id = SemaTypeId::POINTER,
.resolved = 1,
.flags = 0,
.name = strlit("*any"),
};
// TODO: float types
@ -2747,7 +2810,7 @@ struct File
String path;
String source_code;
FileStatus status;
Hashmap<String, Node> symbols;
Hashmap<String, Node> symbols = {};
};
fn File* add_file(Arena* arena, String file_path)
@ -2755,6 +2818,8 @@ fn File* add_file(Arena* arena, String file_path)
auto* file = arena->allocate_one<File>();
*file = {
.path = file_path,
.source_code = {},
.status = FILE_STATUS_ADDED,
};
return file;
}
@ -3077,7 +3142,7 @@ fn Node* scope_update_extended(Node* scope, String name, Node* node, s32 nesting
}
// TODO: avoid recursion
auto& map = scope->scope.stack[nesting_level];
auto& map = scope->payload.scope.stack[nesting_level];
if (auto index = map.get(name))
{
auto* old = scope->get_inputs()[*index];
@ -3103,7 +3168,7 @@ fn Node* scope_update_extended(Node* scope, String name, Node* node, s32 nesting
fn Node* scope_lookup(Analyzer* analyzer, String name)
{
if (auto* node = scope_update_extended(analyzer->scope, name, nullptr, analyzer->scope->scope.stack.length - 1))
if (auto* node = scope_update_extended(analyzer->scope, name, nullptr, analyzer->scope->payload.scope.stack.length - 1))
{
return node;
}
@ -3238,6 +3303,7 @@ fn Node* scope_lookup(Analyzer* analyzer, String name)
argument_nodes.append_one(node);
Node* call_node = Node::add(thread, {
.type = {},
.inputs = argument_nodes.slice(),
.id = Node::Id::CALL,
})->peephole(thread, function);
@ -3258,10 +3324,17 @@ fn Node* scope_lookup(Analyzer* analyzer, String name)
enum class CurrentOperation
{
NONE,
ADD,
ADD_ASSIGN,
SUB,
SUB_ASSIGN,
ASSIGN,
INTEGER_ADD,
INTEGER_ADD_ASSIGN,
INTEGER_SUB,
INTEGER_SUB_ASSIGN,
INTEGER_COMPARE_EQUAL,
INTEGER_COMPARE_NOT_EQUAL,
INTEGER_COMPARE_LESS,
INTEGER_COMPARE_LESS_EQUAL,
INTEGER_COMPARE_GREATER,
INTEGER_COMPARE_GREATER_EQUAL,
};
u64 iterations = 0;
@ -3295,22 +3368,31 @@ fn Node* scope_lookup(Analyzer* analyzer, String name)
case CurrentOperation::NONE:
previous_node = current_node;
break;
case CurrentOperation::ADD:
case CurrentOperation::SUB:
case CurrentOperation::INTEGER_ADD:
case CurrentOperation::INTEGER_SUB:
{
Node::Id id;
switch (current_operation)
{
case CurrentOperation::NONE:
trap();
case CurrentOperation::ADD:
id = Node::Id::INT_ADD;
case CurrentOperation::INTEGER_ADD:
id = Node::Id::INTEGER_ADD;
break;
case CurrentOperation::SUB:
id = Node::Id::INT_SUB;
case CurrentOperation::INTEGER_SUB:
id = Node::Id::INTEGER_SUB;
break;
case CurrentOperation::ADD_ASSIGN:
case CurrentOperation::SUB_ASSIGN:
case CurrentOperation::INTEGER_ADD_ASSIGN:
case CurrentOperation::INTEGER_SUB_ASSIGN:
trap();
case CurrentOperation::INTEGER_COMPARE_EQUAL:
trap();
case CurrentOperation::ASSIGN:
case CurrentOperation::INTEGER_COMPARE_NOT_EQUAL:
case CurrentOperation::INTEGER_COMPARE_LESS:
case CurrentOperation::INTEGER_COMPARE_LESS_EQUAL:
case CurrentOperation::INTEGER_COMPARE_GREATER:
case CurrentOperation::INTEGER_COMPARE_GREATER_EQUAL:
trap();
}
@ -3328,9 +3410,57 @@ fn Node* scope_lookup(Analyzer* analyzer, String name)
previous_node = binary;
} break;
default:
trap();
}
case CurrentOperation::INTEGER_COMPARE_EQUAL:
case CurrentOperation::INTEGER_COMPARE_NOT_EQUAL:
case CurrentOperation::INTEGER_COMPARE_LESS:
case CurrentOperation::INTEGER_COMPARE_LESS_EQUAL:
case CurrentOperation::INTEGER_COMPARE_GREATER:
case CurrentOperation::INTEGER_COMPARE_GREATER_EQUAL:
{
Node::Id id;
switch (current_operation)
{
case CurrentOperation::INTEGER_COMPARE_EQUAL:
id = Node::Id::INTEGER_COMPARE_EQUAL;
break;
case CurrentOperation::INTEGER_COMPARE_NOT_EQUAL:
id = Node::Id::INTEGER_COMPARE_NOT_EQUAL;
break;
case CurrentOperation::INTEGER_COMPARE_LESS:
id = Node::Id::INTEGER_COMPARE_LESS;
break;
case CurrentOperation::INTEGER_COMPARE_LESS_EQUAL:
id = Node::Id::INTEGER_COMPARE_LESS_EQUAL;
break;
case CurrentOperation::INTEGER_COMPARE_GREATER:
id = Node::Id::INTEGER_COMPARE_GREATER;
break;
case CurrentOperation::INTEGER_COMPARE_GREATER_EQUAL:
id = Node::Id::INTEGER_COMPARE_GREATER_EQUAL;
break;
default:
trap();
}
Node* inputs[] = {
0,
previous_node,
current_node,
};
auto* binary = Node::add(thread, {
.type = current_node->type,
.inputs = { .pointer = inputs, .length = array_length(inputs), },
.id = id,
});
previous_node = binary;
} break;
case CurrentOperation::ASSIGN:
case CurrentOperation::INTEGER_ADD_ASSIGN:
case CurrentOperation::INTEGER_SUB_ASSIGN:
trap();
}
previous_node = previous_node->peephole(thread, analyzer->function);
@ -3345,13 +3475,13 @@ fn Node* scope_lookup(Analyzer* analyzer, String name)
case bracket_close:
return previous_node;
case '+':
current_operation = CurrentOperation::ADD;
current_operation = CurrentOperation::INTEGER_ADD;
parser->i += 1;
switch (src[parser->i])
{
case '=':
current_operation = CurrentOperation::ADD_ASSIGN;
current_operation = CurrentOperation::INTEGER_ADD_ASSIGN;
parser->i += 1;
break;
default:
@ -3359,13 +3489,27 @@ fn Node* scope_lookup(Analyzer* analyzer, String name)
}
break;
case '-':
current_operation = CurrentOperation::SUB;
current_operation = CurrentOperation::INTEGER_SUB;
parser->i += 1;
switch (src[parser->i])
{
case '=':
current_operation = CurrentOperation::SUB_ASSIGN;
current_operation = CurrentOperation::INTEGER_SUB_ASSIGN;
parser->i += 1;
break;
default:
break;
}
break;
case '=':
current_operation = CurrentOperation::ASSIGN;
parser->i += 1;
switch (src[parser->i])
{
case '=':
current_operation = CurrentOperation::INTEGER_COMPARE_EQUAL;
parser->i += 1;
break;
default:
@ -3389,17 +3533,17 @@ fn Node* scope_lookup(Analyzer* analyzer, String name)
fn void push_scope(Analyzer* analyzer)
{
analyzer->scope->scope.stack.append_one({});
analyzer->scope->payload.scope.stack.append_one({});
}
fn void pop_scope(Analyzer* analyzer)
{
analyzer->scope->scope.stack.pop();
analyzer->scope->payload.scope.stack.pop();
}
fn Node* define_variable(Analyzer* analyzer, String name, Node* node)
{
auto* stack = &analyzer->scope->scope.stack;
auto* stack = &analyzer->scope->payload.scope.stack;
assert(stack->length);
auto* last = &stack->pointer[stack->length - 1];
@ -3461,7 +3605,7 @@ fn Node* analyze_local_block(Analyzer* analyzer, Parser* parser, Unit* unit, Thr
if (!statement_node)
{
auto& list = analyzer->scope->scope.stack;
auto& list = analyzer->scope->payload.scope.stack;
u32 i = list.length;
u8 found = 0;
while (i > 0)
@ -3812,7 +3956,9 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file)
.outputs = {},
.gvn = function_gvn,
.id = Node::Id::SYMBOL_FUNCTION,
.symbol = &function->symbol,
.payload = {
.symbol = &function->symbol,
},
});
parser->skip_space(src);
@ -4150,12 +4296,20 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file)
root_arg_types.append(abi_argument_types.slice());
Node::Type root_type = { .id = Node::Type::Id::MULTIVALUE, .multi = { .types = root_arg_types.slice(), }, };
Node::Type root_type = {
.id = Node::Type::Id::MULTIVALUE,
.payload = {
.multi = {
.types = root_arg_types.slice(),
},
},
};
function->root_node = Node::add(thread, {
.type = root_type,
.inputs = {},
.id = Node::Id::ROOT,
});
function->root_node->root.args = root_type;
function->root_node->payload.root.args = root_type;
function->root_node->peephole(thread, function);
auto* scope_node = Node::add(thread, {
@ -4163,7 +4317,7 @@ fn void analyze_function(Parser* parser, Thread* thread, Unit* unit, File* file)
.inputs = { .pointer = &function->root_node, .length = 1 },
.id = Node::Id::SCOPE,
});
scope_node->scope.stack = {};
scope_node->payload.scope.stack = {};
Analyzer analyzer = {
.function = function,
.scope = scope_node,
@ -4363,6 +4517,7 @@ String test_file_paths[] = {
strlit("tests/constant_prop/main.nat"),
strlit("tests/simple_variable_declaration/main.nat"),
strlit("tests/function_call_args/main.nat"),
strlit("tests/comparison/main.nat"),
};
#ifdef __linux__

10
tests/comparison/main.nat Normal file
View File

@ -0,0 +1,10 @@
fn foo(arg: s32) s32
{
return arg == 0;
}
fn[cc(.c)] main [export] () s32
{
>arg: s32 = 0;
return foo(arg);
}