diff --git a/bootstrap/main.c b/bootstrap/main.c index a68efbe..a37c26c 100644 --- a/bootstrap/main.c +++ b/bootstrap/main.c @@ -62,6 +62,7 @@ typedef u64 Hash; #define auto __auto_type #define bad_ex(file, line) do { print("Bad exit at {cstr}:{u32}\n", file, line); __builtin_trap(); } while(0) +#define todo() do { print("TODO at {cstr}:{u32}\n", __FILE__, __LINE__); __builtin_trap(); } while(0) may_be_unused fn void print(const char* format, ...); @@ -1783,9 +1784,11 @@ declare_slice(TypeIndex); struct TypeInteger { u64 constant; - u64 bit_count:7; - u64 is_constant:1; - u64 is_signed:1; + u8 bit_count; + u8 is_constant; + u8 is_signed; + u8 padding1; + u32 padding; }; typedef struct TypeInteger TypeInteger; static_assert(sizeof(TypeInteger) == 16); @@ -2035,13 +2038,12 @@ fn void bitset_ensure_length(Bitset* bitset, u64 max) } } -fn void bitset_set_assert_unset(Bitset* bitset, u64 index) +fn void bitset_set_value(Bitset* bitset, u64 index, u8 value) { bitset_ensure_length(bitset, index + 1); auto element_index = index / element_bitsize; auto bit_index = index % element_bitsize; - assert((bitset->arr.pointer[element_index] & (1 << bit_index)) == 0); - bitset->arr.pointer[element_index] |= 1 << bit_index; + bitset->arr.pointer[element_index] |= (!!value) << bit_index; } fn void bitset_clear(Bitset* bitset) @@ -2096,9 +2098,41 @@ struct Thread u64 nop; } iteration; WorkList worklist; + s64 main_function; }; typedef struct Thread Thread; +fn NodeIndex thread_worklist_push(Thread* thread, NodeIndex node_index) +{ + if (validi(node_index)) + { + if (!bitset_get(&thread->worklist.bitset, geti(node_index))) + { + bitset_set_value(&thread->worklist.bitset, geti(node_index), 1); + *vb_add(&thread->worklist.nodes, 1) = node_index; + } + } + + return node_index; +} + +fn NodeIndex thread_worklist_pop(Thread* thread) +{ + auto result = invalidi(Node); + + auto len = thread->worklist.nodes.length; + if (len) + { + auto index = len - 1; + auto node_index = thread->worklist.nodes.pointer[index]; + thread->worklist.nodes.length = index; + bitset_set_value(&thread->worklist.bitset, index, 0); + result = node_index; + } + + return result; +} + fn void thread_worklist_clear(Thread* thread) { bitset_clear(&thread->worklist.visited); @@ -2360,6 +2394,40 @@ struct NodeCreate }; typedef struct NodeCreate NodeCreate; +fn String node_id_to_string(Node* node) +{ + switch (node->id) + { + case_to_name(NODE_, START); + case_to_name(NODE_, STOP); + case_to_name(NODE_, CONTROL_PROJECTION); + case_to_name(NODE_, DEAD_CONTROL); + case_to_name(NODE_, SCOPE); + case_to_name(NODE_, PROJECTION); + case_to_name(NODE_, RETURN); + case_to_name(NODE_, REGION); + case_to_name(NODE_, REGION_LOOP); + case_to_name(NODE_, IF); + case_to_name(NODE_, PHI); + case_to_name(NODE_, INTEGER_ADD); + case_to_name(NODE_, INTEGER_SUBSTRACT); + case_to_name(NODE_, INTEGER_MULTIPLY); + case_to_name(NODE_, INTEGER_UNSIGNED_DIVIDE); + case_to_name(NODE_, INTEGER_SIGNED_DIVIDE); + case_to_name(NODE_, INTEGER_UNSIGNED_REMAINDER); + case_to_name(NODE_, INTEGER_SIGNED_REMAINDER); + case_to_name(NODE_, INTEGER_UNSIGNED_SHIFT_LEFT); + case_to_name(NODE_, INTEGER_SIGNED_SHIFT_LEFT); + case_to_name(NODE_, INTEGER_UNSIGNED_SHIFT_RIGHT); + case_to_name(NODE_, INTEGER_SIGNED_SHIFT_RIGHT); + case_to_name(NODE_, INTEGER_AND); + case_to_name(NODE_, INTEGER_OR); + case_to_name(NODE_, INTEGER_XOR); + case_to_name(NODE_, CONSTANT); + case_to_name(NODE_, COUNT); + } +} + fn NodeIndex thread_node_add(Thread* thread, NodeCreate data) { auto input_result = thread_get_node_reference_array(thread, data.inputs.length); @@ -2373,15 +2441,20 @@ fn NodeIndex thread_node_add(Thread* thread, NodeCreate data) node->input_count = data.inputs.length; node->type = invalidi(Type); + print("[NODE CREATION] #{u32} {s} | INPUTS: { ", node_index.index, node_id_to_string(node)); + for (u32 i = 0; i < data.inputs.length; i += 1) { NodeIndex input = data.inputs.pointer[i]; + print("{u32} ", input.index); if (validi(input)) { node_add_output(thread, input, node_index); } } + print("}\n"); + return node_index; } @@ -2463,11 +2536,6 @@ fn NodeIndex scope_define(Thread* thread, FunctionBuilder* builder, String name, return result; } -fn NodeIndex node_keep(Thread* thread, NodeIndex node_index) -{ - return node_add_output(thread, node_index, invalidi(Node)); -} - fn u8 type_equal(Thread* thread, Type* a, Type* b) { u8 result = 0; @@ -2515,7 +2583,7 @@ fn u8 type_equal(Thread* thread, Type* a, Type* b) } fn Hash hash_type(Thread* thread, Type* type); -fn Hash node_get_hash_default(Thread* thread, Node* node) +fn Hash node_get_hash_default(Thread* thread, Node* node, NodeIndex node_index) { return fnv_offset; } @@ -2525,24 +2593,81 @@ fn Hash node_get_hash_projection(Thread* thread, Node* node) trap(); } -fn Hash node_get_hash_control_projection(Thread* thread, Node* node) +fn Hash node_get_hash_control_projection(Thread* thread, Node* node, NodeIndex node_index) { auto projection_index = node->control_projection.projection.index; auto proj_index_bytes = struct_to_bytes(projection_index); return hash_bytes(proj_index_bytes); } -fn Hash node_get_hash_constant(Thread* thread, Node* node) +fn Hash node_get_hash_constant(Thread* thread, Node* node, NodeIndex node_index) { + auto type_index = node->type; auto* type = thread_type_get(thread, node->type); auto type_hash = hash_type(thread, type); + // print("Hashing node #{u32} (constant) (type: #{u32}) (hash: {u64:x})\n", node_index.index, type_index.index, type_hash); return type_hash; } +struct TypeGetOrPut +{ + TypeIndex index; + u8 existing; +}; + +typedef struct TypeGetOrPut TypeGetOrPut; + +fn TypeGetOrPut intern_pool_get_or_put_new_type(Thread* thread, Type* type); + typedef NodeIndex NodeIdealize(Thread* thread, NodeIndex node_index); typedef TypeIndex NodeComputeType(Thread* thread, NodeIndex node_index); typedef Hash TypeGetHash(Thread* thread, Type* type); -typedef Hash NodeGetHash(Thread* thread, Node* node); +typedef Hash NodeGetHash(Thread* thread, Node* node, NodeIndex node_index); + +fn TypeIndex thread_get_integer_type(Thread* thread, TypeInteger type_integer) +{ + Type type; + memset(&type, 0, sizeof(Type)); + type.integer = type_integer; + type.id = TYPE_INTEGER; + + auto result = intern_pool_get_or_put_new_type(thread, &type); + return result.index; +} + +fn NodeIndex peephole(Thread* thread, Function* function, NodeIndex node_index); +fn NodeIndex constant_int_create_with_type(Thread* thread, Function* function, TypeIndex type_index) +{ + auto node_index = thread_node_add(thread, (NodeCreate){ + .id = NODE_CONSTANT, + .inputs = array_to_slice(((NodeIndex []) { + function->start, + })) + }); + auto* node = thread_node_get(thread, node_index); + + node->constant = (NodeConstant) { + .type = type_index, + }; + + print("Creating constant integer node #{u32} with value: {u64:x}\n", node_index.index, thread_type_get(thread, type_index)->integer.constant); + + auto result = peephole(thread, function, node_index); + return result; +} + +fn NodeIndex constant_int_create(Thread* thread, Function* function, u64 value) +{ + auto type_index = thread_get_integer_type(thread, (TypeInteger){ + .constant = value, + .bit_count = 0, + .is_constant = 1, + .is_signed = 0, + }); + + auto constant_int = constant_int_create_with_type(thread, function, type_index); + return constant_int; +} struct NodeVirtualTable { @@ -2573,11 +2698,19 @@ fn TypeIndex compute_type_constant(Thread* thread, NodeIndex node_index) fn Hash type_get_hash_default(Thread* thread, Type* type) { + assert(!type->hash); Hash hash = fnv_offset; + + u32 i = 0; for (auto* it = (u8*)type; it < (u8*)(type + 1); it += 1) { hash = hash_byte(hash, *it); + if (type->id == TYPE_INTEGER) + { + // print("Byte [{u32}] = 0x{u32:x}\n", i, (u32)*it); + i += 1; + } } return hash; @@ -2723,6 +2856,14 @@ fn TypeIndex intern_pool_put_new_type_at_assume_not_existent_assume_capacity(Thr auto* result = vb_add(&thread->buffer.types, 1); auto buffer_index = result - thread->buffer.types.pointer; auto type_index = Index(Type, buffer_index); + if (type->id == TYPE_INTEGER) + { + auto diff = (u8*)(&type->hash) - (u8*)(&type->integer.is_signed + 1); + for (u32 i = 0; i < diff; i += 1) + { + assert(*(&type->integer.is_signed + 1 + i) == 0); + } + } *result = *type; thread->interned.types.pointer[index] = *(u32*)&type_index; @@ -2745,13 +2886,6 @@ fn TypeIndex intern_pool_put_new_type_assume_not_existent(Thread* thread, Hash h return intern_pool_put_new_type_assume_not_existent_assume_capacity(thread, hash, type); } -struct TypeGetOrPut -{ - TypeIndex index; - u8 existing; -}; -typedef struct TypeGetOrPut TypeGetOrPut; - fn s32 intern_pool_find_type_slot(Thread* thread, u32 original_index, Type* type) { auto it_index = original_index; @@ -2929,6 +3063,71 @@ fn TypeIndex compute_type_start(Thread* thread, NodeIndex node_index) return node->start.arguments; } +fn u8 type_is_constant(Type* type) +{ + switch (type->id) + { + case TYPE_INTEGER: + return type->integer.is_constant; + default: + return 0; + } +} + +fn u8 node_is_unused(Node* node) +{ + return node->output_count == 0; +} + +fn u8 node_is_dead(Node* node) +{ + return node_is_unused(node) & ((node->input_count == 0) & (!validi(node->type))); +} + +fn TypeIndex compute_type_integer_binary(Thread* thread, NodeIndex node_index) +{ + auto* node = thread_node_get(thread, node_index); + auto inputs = node_get_inputs(thread, node); + auto* left = thread_node_get(thread, inputs.pointer[1]); + auto* right = thread_node_get(thread, inputs.pointer[2]); + assert(!node_is_dead(left)); + assert(!node_is_dead(right)); + auto* left_type = thread_type_get(thread, left->type); + auto* right_type = thread_type_get(thread, right->type); + + if (((left_type->id == TYPE_INTEGER) & (right_type->id == TYPE_INTEGER)) & (type_is_constant(left_type) & type_is_constant(right_type))) + { + auto left_value = left_type->integer.constant; + auto right_value = right_type->integer.constant; + assert(left_type->integer.bit_count == 0); + assert(right_type->integer.bit_count == 0); + assert(!left_type->integer.is_signed); + assert(!right_type->integer.is_signed); + + u64 result; + TypeInteger type_integer = left_type->integer; + + switch (node->id) + { + case NODE_INTEGER_ADD: + result = left_value + right_value; + break; + case NODE_INTEGER_SUBSTRACT: + result = left_value - right_value; + break; + default: + trap(); + } + type_integer.constant = result; + auto new_type = thread_get_integer_type(thread, type_integer); + return new_type; + } + else + { + trap(); + } +} + global const TypeVirtualTable type_functions[TYPE_COUNT] = { [TYPE_BOTTOM] = { .get_hash = &type_get_hash_default }, [TYPE_TOP] = { .get_hash = &type_get_hash_default }, @@ -2964,6 +3163,16 @@ global const NodeVirtualTable node_functions[NODE_COUNT] = { .idealize = &idealize_return, .get_hash = &node_get_hash_default, }, + + // Integer operations + [NODE_INTEGER_ADD] = { + .compute_type = &compute_type_integer_binary, + }, + [NODE_INTEGER_SUBSTRACT] = { + .compute_type = &compute_type_integer_binary, + }, + + // Constant [NODE_CONSTANT] = { .compute_type = &compute_type_constant, .idealize = &idealize_null, @@ -2971,6 +3180,20 @@ global const NodeVirtualTable node_functions[NODE_COUNT] = { }, }; +fn String type_id_to_string(Type* type) +{ + switch (type->id) + { + case_to_name(TYPE_, BOTTOM); + case_to_name(TYPE_, TOP); + case_to_name(TYPE_, LIVE_CONTROL); + case_to_name(TYPE_, DEAD_CONTROL); + case_to_name(TYPE_, INTEGER); + case_to_name(TYPE_, TUPLE); + case_to_name(TYPE_, COUNT); + } +} + fn Hash hash_type(Thread* thread, Type* type) { Hash hash = type->hash; @@ -2978,6 +3201,7 @@ fn Hash hash_type(Thread* thread, Type* type) if (!hash) { hash = type_functions[type->id].get_hash(thread, type); + // print("Hashing type id {s}: {u64:x}\n", type_id_to_string(type), hash); } assert(hash != 0); @@ -3016,6 +3240,64 @@ struct NodeGetOrPut }; typedef struct NodeGetOrPut NodeGetOrPut; +// This assumes the indices are not equal +fn u8 node_equal(Thread* thread, Node* a, Node* b) +{ + u8 result = 0; + assert(a != b); + assert(a->hash); + assert(b->hash); + + if (((a->id == b->id) & (a->hash == b->hash)) & (a->input_count == b->input_count)) + { + auto inputs_a = node_get_inputs(thread, a); + auto inputs_b = node_get_inputs(thread, b); + result = 1; + + for (u16 i = 0; i < a->input_count; i += 1) + { + if (!index_equal(inputs_a.pointer[i], inputs_b.pointer[i])) + { + result = 0; + break; + } + } + + if (result) + { + switch (a->id) + { + case NODE_CONSTANT: + result = index_equal(a->constant.type, b->constant.type); + break; + default: + trap(); + } + } + } + + return result; +} + + +fn u8 node_index_equal(Thread* thread, NodeIndex a, NodeIndex b) +{ + u8 result = 0; + if (index_equal(a, b)) + { + result = 1; + } + else + { + auto* node_a = thread_node_get(thread, a); + auto* node_b = thread_node_get(thread, b); + assert(node_a != node_b); + result = node_equal(thread, node_a, node_b); + } + + return result; +} + fn s32 intern_pool_find_node_slot(Thread* thread, u32 original_index, NodeIndex node_index) { assert(validi(node_index)); @@ -3037,19 +3319,12 @@ fn s32 intern_pool_find_node_slot(Thread* thread, u32 original_index, NodeIndex else { NodeIndex existing_node_index = *(NodeIndex*)&key; - if (index_equal(existing_node_index, node_index)) + // Exhaustive comparation, shortcircuit when possible + if (node_index_equal(thread, existing_node_index, node_index)) { result = index; break; } - else - { - auto* existing_node = &thread->buffer.nodes.pointer[geti(existing_node_index)]; - if (existing_node->id == node->id) - { - trap(); - } - } } it_index += 1; @@ -3058,15 +3333,17 @@ fn s32 intern_pool_find_node_slot(Thread* thread, u32 original_index, NodeIndex return result; } -fn Hash hash_node(Thread* thread, Node* node) +fn Hash hash_node(Thread* thread, Node* node, NodeIndex node_index) { auto hash = node->hash; if (!hash) { - hash = node_functions[node->id].get_hash(thread, node); + hash = node_functions[node->id].get_hash(thread, node, node_index); + // print("[HASH #{u32}] Received hash from callback: {u64:x}\n", node_index.index, hash); hash = hash_byte(hash, node->id); + auto inputs = node_get_inputs(thread, node); for (u32 i = 0; i < inputs.length; i += 1) { @@ -3080,6 +3357,8 @@ fn Hash hash_node(Thread* thread, Node* node) } } + // print("[HASH] Node #{u32}, {s}: {u64:x}\n", node_index.index, node_id_to_string(node), hash); + node->hash = hash; } @@ -3092,17 +3371,25 @@ fn NodeGetOrPut intern_pool_get_or_put_node(Thread* thread, NodeIndex node_index { auto existing_capacity = thread->interned.nodes.capacity; auto* node = &thread->buffer.nodes.pointer[geti(node_index)]; - auto hash = hash_node(thread, node); + auto hash = hash_node(thread, node, node_index); auto original_index = hash & (existing_capacity - 1); auto slot = intern_pool_find_node_slot(thread, original_index, node_index); if (slot != -1) { u32 index = slot; - u8 existing = thread->interned.nodes.pointer[index] != 0; - auto result = intern_pool_put_node_at_assume_not_existent_assume_capacity(thread, node_index, index); + auto* existing_ptr = &thread->interned.nodes.pointer[index]; + NodeIndex existing_value = *(NodeIndex*)existing_ptr; + u8 existing = validi(existing_value); + NodeIndex new_value = existing_value; + if (!existing) + { + new_value = intern_pool_put_node_at_assume_not_existent_assume_capacity(thread, node_index, index); + assert(!index_equal(new_value, existing_value)); + assert(index_equal(new_value, node_index)); + } return (NodeGetOrPut) { - .index = result, + .index = new_value, .existing = existing, }; } @@ -3131,7 +3418,7 @@ fn NodeIndex intern_pool_remove_node(Thread* thread, NodeIndex node_index) { auto existing_capacity = thread->interned.nodes.capacity; auto* node = thread_node_get(thread, node_index); - auto hash = hash_node(thread, node); + auto hash = hash_node(thread, node, node_index); auto original_index = hash & (existing_capacity - 1); auto slot = intern_pool_find_node_slot(thread, original_index, node_index); @@ -3154,7 +3441,7 @@ fn NodeIndex intern_pool_remove_node(Thread* thread, NodeIndex node_index) } auto existing_node_index = *(NodeIndex*)&existing; auto* existing_node = thread_node_get(thread, existing_node_index); - auto existing_node_hash = hash_node(thread, existing_node); + auto existing_node_hash = hash_node(thread, existing_node, existing_node_index); auto existing_index = existing_node_hash & (existing_capacity - 1); if (slot_index <= existing_index) @@ -3451,31 +3738,6 @@ fn u8 type_is_a(Thread* thread, TypeIndex a, TypeIndex b) auto m = type_meet(thread, a, b); return index_equal(m, b); } - - -fn void set_type(Thread* thread, Node* node, TypeIndex new_type) -{ - auto old_type = node->type; - assert(!validi(old_type) || type_is_a(thread, new_type, old_type)); - if (!index_equal(old_type, new_type)) - { - node->type = new_type; - auto outputs = node_get_outputs(thread, node); - thread_add_jobs(thread, outputs); - move_dependencies_to_worklist(thread, node); - } -} - -fn u8 node_is_unused(Node* node) -{ - return node->output_count == 0; -} - -fn u8 node_is_dead(Node* node) -{ - return node_is_unused(node) & ((node->input_count == 0) & (!validi(node->type))); -} - union NodePair { struct @@ -3487,6 +3749,49 @@ union NodePair }; typedef union NodePair NodePair; +fn NodeIndex node_keep(Thread* thread, NodeIndex node_index) +{ + return node_add_output(thread, node_index, invalidi(Node)); +} + +fn NodeIndex node_unkeep(Thread* thread, NodeIndex node_index) +{ + node_remove_output(thread, node_index, invalidi(Node)); + return node_index; +} + +fn void node_kill(Thread* thread, NodeIndex node_index) +{ + node_unlock(thread, node_index); + auto* node = thread_node_get(thread, node_index); + print("[NODE KILLING] (#{u32}, {s}) START\n", node_index.index, node_id_to_string(node)); + assert(node_is_unused(node)); + node->type = invalidi(Type); + + auto inputs = node_get_inputs(thread, node); + while (node->input_count > 0) + { + auto input_index = node->input_count - 1; + node->input_count = input_index; + auto old_input_index = inputs.pointer[input_index]; + + print("[NODE KILLING] (#{u32}, {s}) Removing input #{u32} at slot {u32}\n", node_index.index, node_id_to_string(node), old_input_index.index, input_index); + if (validi(old_input_index)) + { + thread_worklist_push(thread, old_input_index); + u8 no_more_outputs = node_remove_output(thread, old_input_index, node_index); + if (no_more_outputs) + { + print("[NODE KILLING] (#{u32}, {s}) (NO MORE OUTPUTS - KILLING) Input #{u32}\n", node_index.index, node_id_to_string(node), old_input_index.index); + node_kill(thread, old_input_index); + } + } + } + + assert(node_is_dead(node)); + // print("[NODE KILLING] (#{u32}, {s}) END\n", node_index.index, node_id_to_string(node)); +} + fn NodeIndex dead_code_elimination(Thread* thread, NodePair nodes) { NodeIndex old = nodes.old; @@ -3494,10 +3799,13 @@ fn NodeIndex dead_code_elimination(Thread* thread, NodePair nodes) if (!index_equal(old, new)) { + print("[DCE] old: #{u32} != new: #{u32}. Proceeding to eliminate\n", old.index, new.index); auto* old_node = thread_node_get(thread, old); if (node_is_unused(old_node) & !node_is_dead(old_node)) { - trap(); + node_keep(thread, new); + node_kill(thread, old); + node_unkeep(thread, new); } } @@ -3523,34 +3831,83 @@ fn u8 type_is_high_or_const(Thread* thread, TypeIndex type_index) return result; } +fn TypeIndex type_join(Thread* thread, TypeIndex a, TypeIndex b) +{ + TypeIndex result; + if (index_equal(a, b)) + { + result = a; + } + else + { + trap(); + } + + return result; +} + +fn void node_set_type(Thread* thread, Node* node, TypeIndex new_type) +{ + auto old_type = node->type; + assert(!validi(old_type) || type_is_a(thread, new_type, old_type)); + if (!index_equal(old_type, new_type)) + { + node->type = new_type; + auto outputs = node_get_outputs(thread, node); + thread_add_jobs(thread, outputs); + move_dependencies_to_worklist(thread, node); + } +} + global auto enable_peephole = 1; -fn NodeIndex peephole_optimize(Thread* thread, NodeIndex node_index) +fn NodeIndex peephole_optimize(Thread* thread, Function* function, NodeIndex node_index) { assert(enable_peephole); auto result = node_index; auto* node = thread_node_get(thread, node_index); + print("Peepholing node #{u32} ({s})\n", node_index.index, node_id_to_string(node)); auto old_type = node->type; auto new_type = node_functions[node->id].compute_type(thread, node_index); if (enable_peephole) { thread->iteration.total += 1; - set_type(thread, node, new_type); + node_set_type(thread, node, new_type); if (node->id != NODE_CONSTANT && node->id != NODE_DEAD_CONTROL && type_is_high_or_const(thread, node->type)) { - trap(); + if (index_equal(node->type, thread->types.dead_control)) + { + trap(); + } + else + { + auto constant_node = constant_int_create_with_type(thread, function, node->type); + return constant_node; + } } auto idealize = 1; if (!node->hash) { auto gop = intern_pool_get_or_put_node(thread, node_index); - if (gop.existing) + idealize = !gop.existing; + + if (gop.existing) { - idealize = 0; - trap(); + auto interned_node_index = gop.index; + auto* interned_node = thread_node_get(thread, interned_node_index); + auto new_type = type_join(thread, interned_node->type, node->type); + node_set_type(thread, interned_node, new_type); + node->hash = 0; + print("[peephole_optimize] Eliminating #{u32} because an existing node was found: #{u32}\n", node_index.index, interned_node_index.index); + auto dce_node = dead_code_elimination(thread, (NodePair) { + .old = node_index, + .new = interned_node_index, + }); + + result = dce_node; } } @@ -3578,15 +3935,16 @@ fn NodeIndex peephole_optimize(Thread* thread, NodeIndex node_index) return result; } -fn NodeIndex peephole(Thread* thread, NodeIndex node_index) +fn NodeIndex peephole(Thread* thread, Function* function, NodeIndex node_index) { NodeIndex result; if (enable_peephole) { - NodeIndex new_node = peephole_optimize(thread, node_index); + NodeIndex new_node = peephole_optimize(thread, function, node_index); if (validi(new_node)) { - NodeIndex peephole_new_node = peephole(thread, new_node); + NodeIndex peephole_new_node = peephole(thread, function, new_node); + print("[peephole] Eliminating #{u32} because a better node was found: #{u32}\n", node_index.index, new_node.index); auto dce_node = dead_code_elimination(thread, (NodePair) { .old = node_index, @@ -3611,16 +3969,6 @@ fn NodeIndex peephole(Thread* thread, NodeIndex node_index) return result; } -fn TypeIndex thread_get_integer_type(Thread* thread, TypeInteger type_integer) -{ - Type type; - memset(&type, 0, sizeof(Type)); - type.integer = type_integer; - type.id = TYPE_INTEGER; - - auto result = intern_pool_get_or_put_new_type(thread, &type); - return result.index; -} fn TypeIndex analyze_type(Thread* thread, Parser* parser, String src) { @@ -3732,29 +4080,8 @@ fn TypeIndex analyze_type(Thread* thread, Parser* parser, String src) trap(); } -fn NodeIndex constant_int_create(Thread* thread, Function* function, u64 value) -{ - auto node_index = thread_node_add(thread, (NodeCreate){ - .id = NODE_CONSTANT, - .inputs = array_to_slice(((NodeIndex []) { - function->start, - })) - }); - auto* node = thread_node_get(thread, node_index); - - auto type_index = thread_get_integer_type(thread, (TypeInteger){ - .constant = value, - .bit_count = 0, - .is_constant = 1, - .is_signed = 0, - }); - node->constant = (NodeConstant) { - .type = type_index, - }; - - auto result = peephole(thread, node_index); - return result; -} +fn NodeIndex analyze_addition(Thread* thread, Parser* parser, Function* function, String src); +fn NodeIndex analyze_multiplication(Thread* thread, Parser* parser, Function* function, String src); fn NodeIndex analyze_primary_expression(Thread* thread, Parser* parser, Function* function, String src) { @@ -3835,9 +4162,93 @@ fn NodeIndex analyze_primary_expression(Thread* thread, Parser* parser, Function } } +fn NodeIndex analyze_unary(Thread* thread, Parser* parser, Function* function, String src) +{ + // TODO: postfix + switch (src.pointer[parser->i]) + { + case '-': + trap(); + case '!': + trap(); + default: + { + auto expression = analyze_primary_expression(thread, parser, function, src); + return expression; + } + } + trap(); +} + +fn NodeIndex analyze_addition(Thread* thread, Parser* parser, Function* function, String src) +{ + auto left = analyze_unary(thread, parser, function, src); + + while (1) + { + skip_space(parser, src); + + NodeId node_id; + auto skip_count = 1; + + switch (src.pointer[parser->i]) + { + case '+': + node_id = NODE_INTEGER_ADD; + break; + case '-': + node_id = NODE_INTEGER_SUBSTRACT; + break; + default: + node_id = NODE_COUNT; + break; + } + + if (node_id == NODE_COUNT) + { + break; + } + + parser->i += skip_count; + skip_space(parser, src); + + auto new_node_index = thread_node_add(thread, (NodeCreate) { + .id = node_id, + .inputs = array_to_slice(((NodeIndex[]) { + invalidi(Node), + left, + invalidi(Node), + })), + }); + + print("Before right: LEFT is #{u32}\n", left.index); + print("Left code:\n```\n{s}\n```\n", s_get_slice(u8, src, parser->i, src.length)); + auto right = analyze_multiplication(thread, parser, function, src); + print("Addition: left: #{u32}, right: #{u32}\n", left.index, right.index); + print("Left code:\n```\n{s}\n```\n", s_get_slice(u8, src, parser->i, src.length)); + + node_set_input(thread, new_node_index, 2, right); + + print("Addition new node #{u32}\n", new_node_index.index); + print("Left code:\n```\n{s}\n```\n", s_get_slice(u8, src, parser->i, src.length)); + + left = peephole(thread, function, new_node_index); + } + + print("Analyze addition returned node #{u32}\n", left.index); + + return left; +} + +fn NodeIndex analyze_multiplication(Thread* thread, Parser* parser, Function* function, String src) +{ + // TODO: + return analyze_addition(thread, parser, function, src); +} + fn NodeIndex analyze_expression(Thread* thread, Parser* parser, Function* function, String src, TypeIndex result_type) { - NodeIndex result = analyze_primary_expression(thread, parser, function, src); + NodeIndex result = analyze_addition(thread, parser, function, src); return result; } @@ -3883,6 +4294,10 @@ fn void analyze_file(Thread* thread, File* file) builder->function = function; function->name = parse_identifier(parser, src); + if (s_equal(function->name, strlit("main"))) + { + thread->main_function = thread->buffer.functions.length - 1; + } skip_space(parser, src); @@ -3942,7 +4357,7 @@ fn void analyze_file(Thread* thread, File* file) .id = NODE_DEAD_CONTROL, .inputs = { .pointer = &function->start, .length = 1 }, }); - dead_control = peephole(thread, dead_control); + dead_control = peephole(thread, function, dead_control); dead_control = node_keep(thread, dead_control); @@ -3971,7 +4386,7 @@ fn void analyze_file(Thread* thread, File* file) .label = control_name, .index = 0, }; - control_node_index = peephole(thread, control_node_index); + control_node_index = peephole(thread, function, control_node_index); scope_define(thread, builder, control_name, thread->types.live_control, control_node_index); } @@ -4015,7 +4430,7 @@ fn void analyze_file(Thread* thread, File* file) // trap(); } - return_node_index = peephole(thread, return_node_index); + return_node_index = peephole(thread, function, return_node_index); builder_add_return(thread, builder, return_node_index); @@ -4036,7 +4451,7 @@ fn void analyze_file(Thread* thread, File* file) parser->i += 1; scope_pop(thread, builder); - function->stop = peephole(thread, function->stop); + function->stop = peephole(thread, function, function->stop); } else { @@ -4050,9 +4465,9 @@ fn void analyze_file(Thread* thread, File* file) } } -typedef NodeIndex NodeCallback(Thread* thread, NodeIndex node_index); +typedef NodeIndex NodeCallback(Thread* thread, Function* function, NodeIndex node_index); -fn NodeIndex node_walk_internal(Thread* thread, NodeIndex node_index, NodeCallback* callback) +fn NodeIndex node_walk_internal(Thread* thread, Function* function, NodeIndex node_index, NodeCallback* callback) { if (bitset_get(&thread->worklist.visited, geti(node_index))) { @@ -4060,8 +4475,8 @@ fn NodeIndex node_walk_internal(Thread* thread, NodeIndex node_index, NodeCallba } else { - bitset_set_assert_unset(&thread->worklist.visited, geti(node_index)); - auto callback_result = callback(thread, node_index); + bitset_set_value(&thread->worklist.visited, geti(node_index), 1); + auto callback_result = callback(thread, function, node_index); if (validi(callback_result)) { return callback_result; @@ -4076,7 +4491,7 @@ fn NodeIndex node_walk_internal(Thread* thread, NodeIndex node_index, NodeCallba auto n = inputs.pointer[i]; if (validi(n)) { - auto n_result = node_walk_internal(thread, n, callback); + auto n_result = node_walk_internal(thread, function, n, callback); if (validi(n_result)) { return n_result; @@ -4089,7 +4504,7 @@ fn NodeIndex node_walk_internal(Thread* thread, NodeIndex node_index, NodeCallba auto n = outputs.pointer[i]; if (validi(n)) { - auto n_result = node_walk_internal(thread, n, callback); + auto n_result = node_walk_internal(thread, function, n, callback); if (validi(n_result)) { return n_result; @@ -4101,44 +4516,61 @@ fn NodeIndex node_walk_internal(Thread* thread, NodeIndex node_index, NodeCallba } } -fn NodeIndex node_walk(Thread* thread, NodeIndex node_index, NodeCallback* callback) +fn NodeIndex node_walk(Thread* thread, Function* function, NodeIndex node_index, NodeCallback* callback) { assert(thread->worklist.visited.length == 0); - NodeIndex result = node_walk_internal(thread, node_index, callback); + NodeIndex result = node_walk_internal(thread, function, node_index, callback); bitset_clear(&thread->worklist.visited); return result; } -fn NodeIndex progress_on_list_callback(Thread* thread, NodeIndex node_index) +fn NodeIndex progress_on_list_callback(Thread* thread, Function* function, NodeIndex node_index) { if (bitset_get(&thread->worklist.bitset, geti(node_index))) { - trap(); + return invalidi(Node); } else { - NodeIndex new_node = peephole_optimize(thread, node_index); + NodeIndex new_node = peephole_optimize(thread, function, node_index); return new_node; } } -fn u8 progress_on_list(Thread* thread, NodeIndex stop_node) +fn u8 progress_on_list(Thread* thread, Function* function, NodeIndex stop_node) { thread->worklist.mid_assert = 1; - NodeIndex changed = node_walk(thread, stop_node, &progress_on_list_callback); + NodeIndex changed = node_walk(thread, function, stop_node, &progress_on_list_callback); thread->worklist.mid_assert = 0; return !validi(changed); } -fn void iterate_peepholes(Thread* thread, NodeIndex stop_node_index) +fn void iterate_peepholes(Thread* thread, Function* function, NodeIndex stop_node_index) { - assert(progress_on_list(thread, stop_node_index)); + assert(progress_on_list(thread, function, stop_node_index)); if (thread->worklist.nodes.length > 0) { - trap(); + while (1) + { + auto node_index = thread_worklist_pop(thread); + if (!validi(node_index)) + { + break; + } + + auto* node = thread_node_get(thread, node_index); + if (!node_is_dead(node)) + { + auto new_node_index = peephole_optimize(thread, function, node_index); + if (validi(new_node_index)) + { + trap(); + } + } + } } thread_worklist_clear(thread); @@ -4166,7 +4598,7 @@ fn void rpo_cfg(Thread* thread, NodeIndex node_index) auto* node = thread_node_get(thread, node_index); if (node_is_cfg(node) && !bitset_get(&thread->worklist.visited, geti(node_index))) { - bitset_set_assert_unset(&thread->worklist.visited, geti(node_index)); + bitset_set_value(&thread->worklist.visited, geti(node_index), 1); auto outputs = node_get_outputs(thread, node); for (u64 i = 0; i < outputs.length; i += 1) { @@ -4262,7 +4694,7 @@ fn void schedule_early(Thread* thread, NodeIndex node_index, NodeIndex start_nod { if (validi(node_index) && !bitset_get(&thread->worklist.visited, geti(node_index))) { - bitset_set_assert_unset(&thread->worklist.visited, geti(node_index)); + bitset_set_value(&thread->worklist.visited, geti(node_index), 1); auto* node = thread_node_get(thread, node_index); auto inputs = node_get_inputs(thread, node); for (u64 i = 0; i < inputs.length; i += 1) @@ -4431,54 +4863,6 @@ fn void gcm_build_cfg(Thread* thread, NodeIndex start_node_index, NodeIndex stop } } -fn String node_id_to_string(Node* node) -{ - switch (node->id) - { - case_to_name(NODE_, START); - case_to_name(NODE_, STOP); - case_to_name(NODE_, CONTROL_PROJECTION); - case_to_name(NODE_, DEAD_CONTROL); - case_to_name(NODE_, SCOPE); - case_to_name(NODE_, PROJECTION); - case_to_name(NODE_, RETURN); - case_to_name(NODE_, REGION); - case_to_name(NODE_, REGION_LOOP); - case_to_name(NODE_, IF); - case_to_name(NODE_, PHI); - case_to_name(NODE_, INTEGER_ADD); - case_to_name(NODE_, INTEGER_SUBSTRACT); - case_to_name(NODE_, INTEGER_MULTIPLY); - case_to_name(NODE_, INTEGER_UNSIGNED_DIVIDE); - case_to_name(NODE_, INTEGER_SIGNED_DIVIDE); - case_to_name(NODE_, INTEGER_UNSIGNED_REMAINDER); - case_to_name(NODE_, INTEGER_SIGNED_REMAINDER); - case_to_name(NODE_, INTEGER_UNSIGNED_SHIFT_LEFT); - case_to_name(NODE_, INTEGER_SIGNED_SHIFT_LEFT); - case_to_name(NODE_, INTEGER_UNSIGNED_SHIFT_RIGHT); - case_to_name(NODE_, INTEGER_SIGNED_SHIFT_RIGHT); - case_to_name(NODE_, INTEGER_AND); - case_to_name(NODE_, INTEGER_OR); - case_to_name(NODE_, INTEGER_XOR); - case_to_name(NODE_, CONSTANT); - case_to_name(NODE_, COUNT); - } -} - -fn String type_id_to_string(Type* type) -{ - switch (type->id) - { - case_to_name(TYPE_, BOTTOM); - case_to_name(TYPE_, TOP); - case_to_name(TYPE_, LIVE_CONTROL); - case_to_name(TYPE_, DEAD_CONTROL); - case_to_name(TYPE_, INTEGER); - case_to_name(TYPE_, TUPLE); - case_to_name(TYPE_, COUNT); - } -} - fn void print_function(Thread* thread, Function* function) { print("fn {s}\n====\n", function->name); @@ -4695,6 +5079,7 @@ fn void thread_init(Thread* thread) { memset(thread, 0, sizeof(Thread)); thread->arena = arena_init_default(KB(64)); + thread->main_function = -1; Type top, bot, live_control, dead_control; memset(&top, 0, sizeof(Type)); @@ -4749,19 +5134,98 @@ fn void unit_tests() Slice(String) arguments; +typedef enum ExecutionEngine : u8 +{ + EXECUTION_ENGINE_C = 'c', + EXECUTION_ENGINE_INTERPRETER = 'i', +} ExecutionEngine; + +struct Interpreter +{ + Function* function; +}; +typedef struct Interpreter Interpreter; + +fn Interpreter* interpreter_create(Thread* thread) +{ + auto* interpreter = arena_allocate(thread->arena, Interpreter, 1); + *interpreter = (Interpreter){}; + return interpreter; +} + +fn s32 interpreter_run(Interpreter* interpreter, Thread* thread) +{ + Function* function = interpreter->function; + auto start_node_index = function->start; + auto* start_node = thread_node_get(thread, start_node_index); + assert(start_node->output_count > 0); + auto stop_node_index = function->stop; + + auto proj_node_index = node_output_get(thread, start_node, 1); + auto it_node_index = proj_node_index; + auto current_statement_margin = 1; + + s32 result = -1; + + while (!index_equal(it_node_index, stop_node_index)) + { + auto* it_node = thread_node_get(thread, it_node_index); + auto outputs = node_get_outputs(thread, it_node); + auto inputs = node_get_inputs(thread, it_node); + + switch (it_node->id) + { + case NODE_CONTROL_PROJECTION: + break; + case NODE_RETURN: + { + auto return_value = thread_node_get(thread, inputs.pointer[1]); + if (return_value->id == NODE_CONSTANT) + { + auto constant_type_index = return_value->constant.type; + auto* constant_type = thread_type_get(thread, constant_type_index); + switch (constant_type->id) + { + case TYPE_INTEGER: + { + assert(constant_type->integer.is_constant); + result = constant_type->integer.constant; + } break; + default: + trap(); + } + } + else + { + trap(); + } + } break; + case NODE_STOP: + break; + default: + trap(); + } + + assert(outputs.length == 1); + it_node_index = outputs.pointer[0]; + } + + return result; +} + #if LINK_LIBC int main(int argc, const char* argv[], char* envp[]) { #else void entry_point(int argc, const char* argv[]) { - char** envp = &argv[argc + 1]; + char** envp = (char**)&argv[argc + 1]; #endif #if DO_UNIT_TESTS unit_tests(); #endif - if (argc < 2) + if (argc < 3) { fail(); } @@ -4782,6 +5246,7 @@ void entry_point(int argc, const char* argv[]) } String source_file_path = arguments.pointer[1]; + ExecutionEngine execution_engine = arguments.pointer[2].pointer[0]; Thread* thread = arena_allocate(global_arena, Thread, 1); thread_init(thread); @@ -4801,7 +5266,7 @@ void entry_point(int argc, const char* argv[]) Function* function = &thread->buffer.functions.pointer[function_i]; NodeIndex start_node_index = function->start; NodeIndex stop_node_index = function->stop; - iterate_peepholes(thread, stop_node_index); + iterate_peepholes(thread, function, stop_node_index); // print_string(strlit("Before optimizations\n")); // print_function(thread, function); gcm_build_cfg(thread, start_node_index, stop_node_index); @@ -4809,6 +5274,11 @@ void entry_point(int argc, const char* argv[]) // print_function(thread, function); } + if (thread->main_function == -1) + { + fail(); + } + auto lowered_source = c_lower(thread); // print("Transpiled to C:\n```\n{s}\n```\n", lowered_source); @@ -4826,14 +5296,29 @@ void entry_point(int argc, const char* argv[]) exe_path[exe_path_view.length] = 0; auto command = (char*[]) { - "clang", "-g", + "/usr/bin/cc", "-g", "-o", exe_path, (char*)c_source_path.pointer, 0, }; - int res = syscall_execve("/usr/bin/clang", command, envp); - assert(0); + switch (execution_engine) + { + case EXECUTION_ENGINE_C: + { + int res = syscall_execve("/usr/bin/cc", command, envp); + assert(0); + } break; + case EXECUTION_ENGINE_INTERPRETER: + { + auto* main_function = &thread->buffer.functions.pointer[thread->main_function]; + auto* interpreter = interpreter_create(thread); + interpreter->function = main_function; + auto exit_code = interpreter_run(interpreter, thread); + print("Interpreter exited with exit code: {u32}\n", exit_code); + syscall_exit(exit_code); + } break; + } thread_clear(thread); #if LINK_LIBC == 0 @@ -4844,7 +5329,7 @@ void entry_point(int argc, const char* argv[]) #if LINK_LIBC == 0 [[gnu::naked]] [[noreturn]] void _start() { - asm( + __asm__ __volatile__( "\nxor %ebp, %ebp" "\npopq %rdi" "\nmov %rsp, %rsi" diff --git a/debug.sh b/debug.sh index fdc2730..c6c935c 100755 --- a/debug.sh +++ b/debug.sh @@ -15,7 +15,7 @@ exe_name="nest" exe_path=$build_dir/$exe_name debug_flags="-g" optimization_flags="-O0" -bootstrap_args="$path" +bootstrap_args="$path i" case "$OSTYPE" in darwin*) static=0;; linux*) static=1;; @@ -26,6 +26,6 @@ compile $build_dir $exe_name $debug_flags $optimization_flags $static case "$OSTYPE" in darwin*) lldb -- $exe_path $bootstrap_args;; - linux*) gf2 -ex b entry_point -ex r --args $exe_path $bootstrap_args;; + linux*) gf2 -ex "set auto-solib-add off" -ex "r" --args $exe_path $bootstrap_args;; *) echo "unknown: $OSTYPE" ;; esac diff --git a/run_tests.sh b/run_tests.sh index 1ab2bfa..f4c8363 100755 --- a/run_tests.sh +++ b/run_tests.sh @@ -8,7 +8,7 @@ build_dir="build" base_exe_name="nest" debug_flags="-g" no_optimization_flags="" -test_names="first" +test_names=("first" "add_sub") if [ "$all" == "1" ] then @@ -18,6 +18,7 @@ then linux*) linking_modes=("0" "1") ;; *) echo "unknown: $OSTYPE"; exit 1 ;; esac + execution_engines=("c", "i") else optimization_modes=("-O0") case "$OSTYPE" in @@ -25,13 +26,13 @@ else linux*) linking_modes=("1") ;; *) echo "unknown: $OSTYPE"; exit 1 ;; esac + execution_engines=("i") fi for linking_mode in "${linking_modes[@]}" do for optimization_mode in "${optimization_modes[@]}" do - printf "\n===========================\n" echo "TESTS (STATIC=$linking_mode, $optimization_mode)" printf "===========================\n\n" @@ -48,16 +49,25 @@ do printf "\n===========================\n" echo "$test_name..." printf "===========================\n\n" - cmd="build/$exe_name tests/$test_name.nat" - echo "Run command: $cmd" - eval "$cmd" - printf "\n===========================\n" - echo "$test_name [COMPILATION] [OK]" - printf "===========================\n\n" - nest/$test_name - printf "\n===========================\n" - echo "$test_name [RUN] [OK]" - printf "===========================\n\n" + + for execution_engine in "${execution_engines[@]}" + do + cmd="build/$exe_name tests/$test_name.nat $execution_engine" + echo "Run command: $cmd" + eval "$cmd" + printf "\n===========================\n" + echo "$test_name [COMPILATION] [EXECUTION ENGINE: $execution_engine] [OK]" + printf "===========================\n\n" + + if [ "$execution_engine" != "i" ] + then + nest/$test_name + fi + + printf "\n===========================\n" + echo "$test_name [RUN] [OK]" + printf "===========================\n\n" + done done done done diff --git a/tests/add_sub.nat b/tests/add_sub.nat new file mode 100644 index 0000000..2f9ef00 --- /dev/null +++ b/tests/add_sub.nat @@ -0,0 +1,4 @@ +fn main() s32 +{ + return 1 - 1 + 1 - 1; +}