From 0ee4e907b60053cb3d463f6a227f5ed2fa55fcaf Mon Sep 17 00:00:00 2001 From: David Gonzalez Martin Date: Fri, 7 Jun 2024 13:08:45 -0600 Subject: [PATCH] Implement function pointer --- bootstrap/compiler.zig | 330 +++++++++++++------- retest/standalone/function_pointer/main.nat | 10 + 2 files changed, 235 insertions(+), 105 deletions(-) create mode 100644 retest/standalone/function_pointer/main.nat diff --git a/bootstrap/compiler.zig b/bootstrap/compiler.zig index dc68947..006a752 100644 --- a/bootstrap/compiler.zig +++ b/bootstrap/compiler.zig @@ -88,18 +88,13 @@ const Side = enum { const GlobalSymbol = struct{ attributes: Attributes = .{}, global_declaration: GlobalDeclaration, + type: *Type, alignment: u32, value: Value, id: GlobalSymbol.Id, pub fn get_type(global_symbol: *GlobalSymbol) *Type { - return switch (global_symbol.id) { - .global_variable => block: { - const global_variable = global_symbol.get_payload(.global_variable); - break :block global_variable.type; - }, - else => |t| @panic(@tagName(t)), - }; + return global_symbol.type; } const Id = enum{ @@ -139,7 +134,6 @@ const GlobalSymbol = struct{ const GlobalVariable = struct { global_symbol: GlobalSymbol, - type: *Type, initial_value: *Value, }; @@ -170,6 +164,7 @@ const LocalSymbol = struct { appointee_type: ?*Type = null, instruction: Instruction, alignment: u32, + initial_value: *Value, const Attributes = struct{ mutability: Mutability = .@"const", @@ -716,71 +711,117 @@ const Parser = struct{ parser.i += 1; parser.skip_space(src); - switch (lookup_result.declaration.*.id) { - .local => unreachable, - .global => { - const global_declaration = lookup_result.declaration.*.get_payload(.global); - switch (global_declaration.id) { - .global_symbol => { - const global_symbol = global_declaration.to_symbol(); - switch (global_symbol.id) { - .function_definition => { - const function_definition = global_symbol.get_payload(.function_definition); - const declaration_argument_count = 
function_definition.declaration.argument_types.len; - var argument_values = PinnedArray(*Value){}; - while (true) { - parser.skip_space(src); - - if (src[parser.i] == ')') { - break; - } - - const argument_index = argument_values.length; - if (argument_index >= declaration_argument_count) { - exit(1); - } - const expected_argument_type = function_definition.declaration.argument_types[argument_index]; - const passed_argument_value = parser.parse_expression(analyzer, thread, file, expected_argument_type, .right); - _ = argument_values.append(passed_argument_value); - - parser.skip_space(src); - - switch (src[parser.i]) { - ',' => parser.i += 1, - ')' => {}, - else => unreachable, - } - } - - parser.i += 1; - - const call = thread.calls.append(.{ + const FunctionCallData = struct{ + type: *Type.Function, + value: *Value, + }; + const function_call_data: FunctionCallData = switch (lookup_result.declaration.*.id) { + .local => local: { + const local_declaration = lookup_result.declaration.*.get_payload(.local); + const local_symbol = local_declaration.to_symbol(); + break :local switch (local_symbol.type.sema.id) { + .pointer => p: { + const appointee_type = local_symbol.appointee_type.?; + break :p switch (appointee_type.sema.id) { + .function => f: { + const function_type = appointee_type.get_payload(.function); + const load = thread.loads.append(.{ .instruction = .{ + .id = .load, .value = .{ .sema = .{ .thread = thread.get_index(), .resolved = true, .id = .instruction, }, - }, - .id = .call, + }, }, - .callable = &function_definition.declaration.global_symbol.value, - .arguments = argument_values.const_slice(), + .value = &local_symbol.instruction.value, + .type = local_symbol.type, + .alignment = 8, + .is_volatile = false, }); - _ = analyzer.current_basic_block.instructions.append(&call.instruction); - return &call.instruction.value; + _ = analyzer.current_basic_block.instructions.append(&load.instruction); + break :f .{ + .type = function_type, + .value = 
&load.instruction.value, + }; }, else => |t| @panic(@tagName(t)), - } - unreachable; + }; }, else => |t| @panic(@tagName(t)), - } + }; + }, + .global => g: { + const global_declaration = lookup_result.declaration.*.get_payload(.global); + break :g switch (global_declaration.id) { + .global_symbol => gs: { + const global_symbol = global_declaration.to_symbol(); + break :gs switch (global_symbol.id) { + .function_definition => f: { + const function_definition = global_symbol.get_payload(.function_definition); + const function_type = function_definition.declaration.get_type(); + break :f .{ + .type = function_type, + .value = &function_definition.declaration.global_symbol.value, + }; + }, + else => |t| @panic(@tagName(t)), + }; + }, + else => |t| @panic(@tagName(t)), + }; }, .argument => unreachable, + }; + + const function_type = function_call_data.type; + const function_value = function_call_data.value; + const declaration_argument_count = function_type.argument_types.len; + var argument_values = PinnedArray(*Value){}; + while (true) { + parser.skip_space(src); + + if (src[parser.i] == ')') { + break; + } + + const argument_index = argument_values.length; + if (argument_index >= declaration_argument_count) { + exit(1); + } + const expected_argument_type = function_type.argument_types[argument_index]; + const passed_argument_value = parser.parse_expression(analyzer, thread, file, expected_argument_type, .right); + _ = argument_values.append(passed_argument_value); + + parser.skip_space(src); + + switch (src[parser.i]) { + ',' => parser.i += 1, + ')' => {}, + else => unreachable, + } } + parser.i += 1; + + const call = thread.calls.append(.{ + .instruction = .{ + .value = .{ + .sema = .{ + .thread = thread.get_index(), + .resolved = true, + .id = .instruction, + }, + }, + .id = .call, + }, + .callable = function_value, + .arguments = argument_values.const_slice(), + }); + _ = analyzer.current_basic_block.instructions.append(&call.instruction); + return 
&call.instruction.value; }, '.' => { switch (lookup_result.declaration.*.id) { @@ -920,9 +961,19 @@ const Parser = struct{ '&' => { parser.i += 1; - const local_declaration = lookup_result.declaration.*.get_payload(.local); - const local_symbol = local_declaration.to_symbol(); - return &local_symbol.instruction.value; + switch (lookup_result.declaration.*.id) { + .local => { + const local_declaration = lookup_result.declaration.*.get_payload(.local); + const local_symbol = local_declaration.to_symbol(); + return &local_symbol.instruction.value; + }, + .global => { + const global_declaration = lookup_result.declaration.*.get_payload(.global); + const global_symbol = global_declaration.to_symbol(); + return &global_symbol.value; + }, + else => |t| @panic(@tagName(t)), + } }, '@' => { parser.i += 1; @@ -1791,19 +1842,8 @@ const Value = struct { }, .call => { const call = instruction.get_payload(.call); - switch (call.callable.sema.id) { - .global_symbol => { - const global_symbol = call.callable.get_payload(.global_symbol); - switch (global_symbol.id) { - .function_definition => { - const function_declaration = global_symbol.get_payload(.function_declaration); - return function_declaration.return_type; - }, - else => |t| @panic(@tagName(t)), - } - }, - else => |t| @panic(@tagName(t)), - } + const function_type = call.get_function_type(); + return function_type.return_type; }, .integer_compare => { return &instance.threads[value.sema.thread].integers[0].type; @@ -1827,9 +1867,7 @@ const Value = struct { return constant_int.type; }, .global_symbol => { - const global_symbol = value.get_payload(.global_symbol); - const global_type = global_symbol.get_type(); - return global_type; + return &instance.threads[value.sema.thread].pointer; }, else => |t| @panic(@tagName(t)), }; @@ -1854,6 +1892,7 @@ const Type = struct { integer, array, pointer, + function, }; const Integer = struct { @@ -1877,12 +1916,19 @@ const Type = struct { }; }; + const Function = struct{ + type: Type, 
+ argument_types: []const *Type, + return_type: *Type, + }; + const id_to_type_map = std.EnumArray(Id, type).init(.{ .unresolved = void, .void = void, .integer = Integer, .array = Array, .pointer = void, + .function = Type.Function, }); fn get_payload(ty: *Type, comptime id: Id) *id_to_type_map.get(id) { @@ -2091,9 +2137,13 @@ const Function = struct{ const Declaration = struct { attributes: Attributes = .{}, global_symbol: GlobalSymbol, - return_type: *Type, - argument_types: []const *Type = &.{}, file: u32, + + fn get_type(declaration: *Function.Declaration) *Type.Function { + const ty = declaration.global_symbol.type; + const function_type = ty.get_payload(.function); + return function_type; + } }; const Scope = struct { @@ -2234,6 +2284,53 @@ const Call = struct{ instruction: Instruction, callable: *Value, arguments: []const *Value, + + fn get_function_type(call: *Call) *Type.Function{ + switch (call.callable.sema.id) { + .global_symbol => { + const global_symbol = call.callable.get_payload(.global_symbol); + switch (global_symbol.id) { + .function_definition => { + const function_declaration = global_symbol.get_payload(.function_declaration); + const function_type = function_declaration.get_type(); + return function_type; + }, + else => |t| @panic(@tagName(t)), + } + }, + .instruction => { + const callable_instruction = call.callable.get_payload(.instruction); + switch (callable_instruction.id) { + .load => { + const load = callable_instruction.get_payload(.load); + switch (load.value.sema.id) { + .instruction => { + const load_instruction = load.value.get_payload(.instruction); + switch (load_instruction.id) { + .local_symbol => { + const local_symbol = load_instruction.get_payload(.local_symbol); + assert(local_symbol.type.sema.id == .pointer); + const app = local_symbol.appointee_type.?; + switch (app.sema.id) { + .function => { + const function_type = app.get_payload(.function); + return function_type; + }, + else => |t| @panic(@tagName(t)), + } + }, + 
else => |t| @panic(@tagName(t)), + } + }, + else => |t| @panic(@tagName(t)), + } + }, + else => |t| @panic(@tagName(t)), + } + }, + else => |t| @panic(@tagName(t)), + } + } }; const Load = struct { @@ -2332,6 +2429,7 @@ const Thread = struct{ unreachables: PinnedArray(Unreachable) = .{}, leading_zeroes: PinnedArray(LeadingZeroes) = .{}, trailing_zeroes: PinnedArray(TrailingZeroes) = .{}, + function_types: PinnedArray(Type.Function) = .{}, array_type_map: PinnedHashMap(Type.Array.Descriptor, *Type) = .{}, array_types: PinnedArray(Type.Array) = .{}, analyzed_file_count: u32 = 0, @@ -3355,14 +3453,15 @@ fn worker_thread(thread_index: u32, cpu_count: *u32) void { const function_definition = global_symbol.get_payload(.function_definition); assert(function_definition.declaration.global_symbol.value.sema.resolved); assert(function_definition.declaration.global_symbol.value.sema.resolved); - assert(function_definition.declaration.return_type.sema.thread == thread.get_index()); - // TODO: here we are duplicating the function declaration, but not the types. It could be interesting to duplicate the types so in the LLVM IR no special case has to take place to deduplicate work done in different threads + const function_type = function_definition.declaration.get_type(); + assert(function_type.return_type.sema.thread == thread.get_index()); + // TODO: here we are duplicating the function declaration, but not the types. 
It could be interesting to duplicate the types so in the LLVM IR no special case has to take place to deduplicate work done in different threads const external_fn = thread.external_functions.append(function_definition.declaration); external_fn.global_symbol.attributes.@"export" = false; external_fn.global_symbol.attributes.@"extern" = true; external_fn.global_symbol.value.sema.thread = thread.get_index(); external_fn.global_symbol.value.llvm = null; - + call.callable = &external_fn.global_symbol.value; value.sema.resolved = true; }, @@ -3513,7 +3612,7 @@ fn worker_thread(thread_index: u32, cpu_count: *u32) void { } for (thread.global_variables.slice()) |*nat_global| { - const global_type = llvm_get_type(thread, nat_global.type); + const global_type = llvm_get_type(thread, nat_global.global_symbol.type); const linkage: LLVM.Linkage = switch (nat_global.global_symbol.attributes.@"export") { true => .@"extern", false => .internal, @@ -3619,9 +3718,8 @@ fn worker_thread(thread_index: u32, cpu_count: *u32) void { }, .call => block: { const call = instruction.get_payload(.call); + const function_type = llvm_get_type(thread, &call.get_function_type().type); const callee = llvm_get_value(thread, call.callable); - const callee_function = callee.toFunction() orelse unreachable; - const function_type = callee_function.getType(); var arguments = std.BoundedArray(*LLVM.Value, 512){}; for (call.arguments) |argument| { @@ -3631,7 +3729,7 @@ fn worker_thread(thread_index: u32, cpu_count: *u32) void { const args = arguments.constSlice(); - const call_i = builder.createCall(function_type, callee, args.ptr, args.len, "", "".len, null); + const call_i = builder.createCall(function_type.toFunction() orelse unreachable, callee, args.ptr, args.len, "", "".len, null); break :block call_i.toValue(); }, .integer_compare => block: { @@ -3948,6 +4046,19 @@ fn llvm_get_type(thread: *Thread, ty: *Type) *LLVM.Type { const pointer_type = thread.llvm.context.getPointerType(0); break :b 
pointer_type.toType(); }, + .function => b: { + const nat_function_type = ty.get_payload(.function); + const return_type = llvm_get_type(thread, nat_function_type.return_type); + var argument_types = PinnedArray(*LLVM.Type){}; + _ = &argument_types; + for (nat_function_type.argument_types) |argument_type| { + const llvm_arg_type = llvm_get_type(thread, argument_type); + _ = argument_types.append(llvm_arg_type); + } + const is_var_args = false; + const function_type = LLVM.getFunctionType(return_type, argument_types.pointer, argument_types.length, is_var_args); + break :b function_type.toType(); + }, else => |t| @panic(@tagName(t)), }; @@ -3991,22 +4102,15 @@ fn llvm_get_function(thread: *Thread, nat_function: *Function.Declaration, overr if (nat_function.global_symbol.value.llvm) |llvm| return llvm.toFunction() orelse unreachable else { _ = override_extern; // autofix const function_name = thread.identifiers.get(nat_function.global_symbol.global_declaration.declaration.name) orelse unreachable; - const return_type = llvm_get_type(thread, nat_function.return_type); - var argument_types = PinnedArray(*LLVM.Type){}; - _ = &argument_types; - for (nat_function.argument_types) |argument_type| { - const llvm_arg_type = llvm_get_type(thread, argument_type); - _ = argument_types.append(llvm_arg_type); - } - const is_var_args = false; - const function_type = LLVM.getFunctionType(return_type, argument_types.pointer, argument_types.length, is_var_args); + const nat_function_type = nat_function.get_type(); + const function_type = llvm_get_type(thread, &nat_function_type.type); const is_extern_function = nat_function.global_symbol.attributes.@"extern"; const export_or_extern = nat_function.global_symbol.attributes.@"export" or is_extern_function; const linkage: LLVM.Linkage = switch (export_or_extern) { true => .@"extern", false => .internal, }; - const function = thread.llvm.module.createFunction(function_type, linkage, address_space, function_name.ptr, function_name.len); + 
const function = thread.llvm.module.createFunction(function_type.toFunction() orelse unreachable, linkage, address_space, function_name.ptr, function_name.len); const debug_info = false; if (debug_info) { @@ -4155,7 +4259,6 @@ pub fn analyze_local_block(thread: *Thread, analyzer: *Analyzer, parser: *Parser exit_with_error("Existing declaration with the same name"); } - const has_local_attributes = src[parser.i] == '['; parser.i += @intFromBool(has_local_attributes); @@ -4230,6 +4333,7 @@ pub fn analyze_local_block(thread: *Thread, analyzer: *Analyzer, parser: *Parser .id = .local_symbol, }, .alignment = result.type.alignment, + .initial_value = result.initial_value, }); if (local_symbol.type.sema.id == .pointer) { @@ -4244,6 +4348,10 @@ pub fn analyze_local_block(thread: *Thread, analyzer: *Analyzer, parser: *Parser else => |t| @panic(@tagName(t)), } }, + .global_symbol => { + const global_symbol = result.initial_value.get_payload(.global_symbol); + local_symbol.appointee_type = global_symbol.type; + }, else => |t| @panic(@tagName(t)), } } @@ -4365,8 +4473,9 @@ pub fn analyze_local_block(thread: *Thread, analyzer: *Analyzer, parser: *Parser if (byte_equal(identifier, "return")) { parser.skip_space(src); - if (function.declaration.return_type.sema.id != .unresolved) { - const return_type = function.declaration.return_type; + const function_type = function.declaration.get_type(); + if (function_type.return_type.sema.id != .unresolved) { + const return_type = function_type.return_type; const return_value = parser.parse_expression(analyzer, thread, file, return_type, .right); parser.expect_character(src, ';'); @@ -4691,8 +4800,8 @@ pub fn analyze_file(thread: *Thread, file_index: u32) void { }, .alignment = global_type.alignment, .id = .global_variable, + .type = global_type, }, - .type = global_type, .initial_value = global_initial_value, }); @@ -4707,7 +4816,6 @@ pub fn analyze_file(thread: *Thread, file_index: u32) void { const entry_block = 
create_basic_block(thread); function.* = .{ .declaration = .{ - .return_type = undefined, .global_symbol = .{ .global_declaration = .{ .declaration = .{ @@ -4725,6 +4833,7 @@ pub fn analyze_file(thread: *Thread, file_index: u32) void { }, }, .id = .function_definition, + .type = undefined, }, .file = file_index, }, @@ -4912,13 +5021,27 @@ pub fn analyze_file(thread: *Thread, file_index: u32) void { } } - function.declaration.argument_types = argument_types.const_slice(); - parser.expect_character(src, ')'); parser.skip_space(src); - function.declaration.return_type = parser.parse_type_expression(thread, file); + const return_type = parser.parse_type_expression(thread, file); + + const function_type = thread.function_types.append(.{ + .type = .{ + .sema = .{ + .id = .function, + .resolved = true, + .thread = thread.get_index(), + }, + .size = 0, + .alignment = 0, + }, + .argument_types = argument_types.const_slice(), + .return_type = return_type, + }); + + function.declaration.global_symbol.type = &function_type.type; parser.skip_space(src); @@ -4992,9 +5115,6 @@ pub fn analyze_file(thread: *Thread, file_index: u32) void { analyzer.current_basic_block = current_basic_block; } - const return_type = function.declaration.return_type; - _ = return_type; - if (!current_basic_block.is_terminated and (current_basic_block.instructions.length > 0 or current_basic_block.predecessors.length > 0)) { unreachable; } diff --git a/retest/standalone/function_pointer/main.nat b/retest/standalone/function_pointer/main.nat new file mode 100644 index 0000000..1101912 --- /dev/null +++ b/retest/standalone/function_pointer/main.nat @@ -0,0 +1,10 @@ +>n: s32 = 5; +fn foo() s32 { + return n; +} + +fn[cc(.c)] main[export]() s32 { + >fn_pointer = foo&; + >a = fn_pointer(); + return a - n; +}