Implement function pointer

This commit is contained in:
David Gonzalez Martin 2024-06-07 13:08:45 -06:00
parent 6cd7c28efb
commit 0ee4e907b6
2 changed files with 235 additions and 105 deletions

View File

@ -88,18 +88,13 @@ const Side = enum {
const GlobalSymbol = struct{
attributes: Attributes = .{},
global_declaration: GlobalDeclaration,
type: *Type,
alignment: u32,
value: Value,
id: GlobalSymbol.Id,
pub fn get_type(global_symbol: *GlobalSymbol) *Type {
return switch (global_symbol.id) {
.global_variable => block: {
const global_variable = global_symbol.get_payload(.global_variable);
break :block global_variable.type;
},
else => |t| @panic(@tagName(t)),
};
return global_symbol.type;
}
const Id = enum{
@ -139,7 +134,6 @@ const GlobalSymbol = struct{
const GlobalVariable = struct {
global_symbol: GlobalSymbol,
type: *Type,
initial_value: *Value,
};
@ -170,6 +164,7 @@ const LocalSymbol = struct {
appointee_type: ?*Type = null,
instruction: Instruction,
alignment: u32,
initial_value: *Value,
const Attributes = struct{
mutability: Mutability = .@"const",
@ -716,17 +711,74 @@ const Parser = struct{
parser.i += 1;
parser.skip_space(src);
switch (lookup_result.declaration.*.id) {
.local => unreachable,
.global => {
const FunctionCallData = struct{
type: *Type.Function,
value: *Value,
};
const function_call_data: FunctionCallData = switch (lookup_result.declaration.*.id) {
.local => local: {
const local_declaration = lookup_result.declaration.*.get_payload(.local);
const local_symbol = local_declaration.to_symbol();
break :local switch (local_symbol.type.sema.id) {
.pointer => p: {
const appointee_type = local_symbol.appointee_type.?;
break :p switch (appointee_type.sema.id) {
.function => f: {
const function_type = appointee_type.get_payload(.function);
const load = thread.loads.append(.{
.instruction = .{
.id = .load,
.value = .{
.sema = .{
.thread = thread.get_index(),
.resolved = true,
.id = .instruction,
},
},
},
.value = &local_symbol.instruction.value,
.type = local_symbol.type,
.alignment = 8,
.is_volatile = false,
});
_ = analyzer.current_basic_block.instructions.append(&load.instruction);
break :f .{
.type = function_type,
.value = &load.instruction.value,
};
},
else => |t| @panic(@tagName(t)),
};
},
else => |t| @panic(@tagName(t)),
};
},
.global => g: {
const global_declaration = lookup_result.declaration.*.get_payload(.global);
switch (global_declaration.id) {
.global_symbol => {
break :g switch (global_declaration.id) {
.global_symbol => gs: {
const global_symbol = global_declaration.to_symbol();
switch (global_symbol.id) {
.function_definition => {
break :gs switch (global_symbol.id) {
.function_definition => f: {
const function_definition = global_symbol.get_payload(.function_definition);
const declaration_argument_count = function_definition.declaration.argument_types.len;
const function_type = function_definition.declaration.get_type();
break :f .{
.type = function_type,
.value = &function_definition.declaration.global_symbol.value,
};
},
else => |t| @panic(@tagName(t)),
};
},
else => |t| @panic(@tagName(t)),
};
},
.argument => unreachable,
};
const function_type = function_call_data.type;
const function_value = function_call_data.value;
const declaration_argument_count = function_type.argument_types.len;
var argument_values = PinnedArray(*Value){};
while (true) {
parser.skip_space(src);
@ -739,7 +791,7 @@ const Parser = struct{
if (argument_index >= declaration_argument_count) {
exit(1);
}
const expected_argument_type = function_definition.declaration.argument_types[argument_index];
const expected_argument_type = function_type.argument_types[argument_index];
const passed_argument_value = parser.parse_expression(analyzer, thread, file, expected_argument_type, .right);
_ = argument_values.append(passed_argument_value);
@ -765,23 +817,12 @@ const Parser = struct{
},
.id = .call,
},
.callable = &function_definition.declaration.global_symbol.value,
.callable = function_value,
.arguments = argument_values.const_slice(),
});
_ = analyzer.current_basic_block.instructions.append(&call.instruction);
return &call.instruction.value;
},
else => |t| @panic(@tagName(t)),
}
unreachable;
},
else => |t| @panic(@tagName(t)),
}
},
.argument => unreachable,
}
},
'.' => {
switch (lookup_result.declaration.*.id) {
.global => {
@ -920,10 +961,20 @@ const Parser = struct{
'&' => {
parser.i += 1;
switch (lookup_result.declaration.*.id) {
.local => {
const local_declaration = lookup_result.declaration.*.get_payload(.local);
const local_symbol = local_declaration.to_symbol();
return &local_symbol.instruction.value;
},
.global => {
const global_declaration = lookup_result.declaration.*.get_payload(.global);
const global_symbol = global_declaration.to_symbol();
return &global_symbol.value;
},
else => |t| @panic(@tagName(t)),
}
},
'@' => {
parser.i += 1;
@ -1791,19 +1842,8 @@ const Value = struct {
},
.call => {
const call = instruction.get_payload(.call);
switch (call.callable.sema.id) {
.global_symbol => {
const global_symbol = call.callable.get_payload(.global_symbol);
switch (global_symbol.id) {
.function_definition => {
const function_declaration = global_symbol.get_payload(.function_declaration);
return function_declaration.return_type;
},
else => |t| @panic(@tagName(t)),
}
},
else => |t| @panic(@tagName(t)),
}
const function_type = call.get_function_type();
return function_type.return_type;
},
.integer_compare => {
return &instance.threads[value.sema.thread].integers[0].type;
@ -1827,9 +1867,7 @@ const Value = struct {
return constant_int.type;
},
.global_symbol => {
const global_symbol = value.get_payload(.global_symbol);
const global_type = global_symbol.get_type();
return global_type;
return &instance.threads[value.sema.thread].pointer;
},
else => |t| @panic(@tagName(t)),
};
@ -1854,6 +1892,7 @@ const Type = struct {
integer,
array,
pointer,
function,
};
const Integer = struct {
@ -1877,12 +1916,19 @@ const Type = struct {
};
};
const Function = struct{
type: Type,
argument_types: []const *Type,
return_type: *Type,
};
const id_to_type_map = std.EnumArray(Id, type).init(.{
.unresolved = void,
.void = void,
.integer = Integer,
.array = Array,
.pointer = void,
.function = Type.Function,
});
fn get_payload(ty: *Type, comptime id: Id) *id_to_type_map.get(id) {
@ -2091,9 +2137,13 @@ const Function = struct{
const Declaration = struct {
attributes: Attributes = .{},
global_symbol: GlobalSymbol,
return_type: *Type,
argument_types: []const *Type = &.{},
file: u32,
fn get_type(declaration: *Function.Declaration) *Type.Function {
const ty = declaration.global_symbol.type;
const function_type = ty.get_payload(.function);
return function_type;
}
};
const Scope = struct {
@ -2234,6 +2284,53 @@ const Call = struct{
instruction: Instruction,
callable: *Value,
arguments: []const *Value,
fn get_function_type(call: *Call) *Type.Function{
switch (call.callable.sema.id) {
.global_symbol => {
const global_symbol = call.callable.get_payload(.global_symbol);
switch (global_symbol.id) {
.function_definition => {
const function_declaration = global_symbol.get_payload(.function_declaration);
const function_type = function_declaration.get_type();
return function_type;
},
else => |t| @panic(@tagName(t)),
}
},
.instruction => {
const callable_instruction = call.callable.get_payload(.instruction);
switch (callable_instruction.id) {
.load => {
const load = callable_instruction.get_payload(.load);
switch (load.value.sema.id) {
.instruction => {
const load_instruction = load.value.get_payload(.instruction);
switch (load_instruction.id) {
.local_symbol => {
const local_symbol = load_instruction.get_payload(.local_symbol);
assert(local_symbol.type.sema.id == .pointer);
const app = local_symbol.appointee_type.?;
switch (app.sema.id) {
.function => {
const function_type = app.get_payload(.function);
return function_type;
},
else => |t| @panic(@tagName(t)),
}
},
else => |t| @panic(@tagName(t)),
}
},
else => |t| @panic(@tagName(t)),
}
},
else => |t| @panic(@tagName(t)),
}
},
else => |t| @panic(@tagName(t)),
}
}
};
const Load = struct {
@ -2332,6 +2429,7 @@ const Thread = struct{
unreachables: PinnedArray(Unreachable) = .{},
leading_zeroes: PinnedArray(LeadingZeroes) = .{},
trailing_zeroes: PinnedArray(TrailingZeroes) = .{},
function_types: PinnedArray(Type.Function) = .{},
array_type_map: PinnedHashMap(Type.Array.Descriptor, *Type) = .{},
array_types: PinnedArray(Type.Array) = .{},
analyzed_file_count: u32 = 0,
@ -3355,8 +3453,9 @@ fn worker_thread(thread_index: u32, cpu_count: *u32) void {
const function_definition = global_symbol.get_payload(.function_definition);
assert(function_definition.declaration.global_symbol.value.sema.resolved);
assert(function_definition.declaration.global_symbol.value.sema.resolved);
assert(function_definition.declaration.return_type.sema.thread == thread.get_index());
// TODO: here we are duplicating the function declaration, but not the types. It could be interesting to duplicate the types so in the LLVM IR no special case has to take place to deduplicate work done in different threads
const function_type = function_definition.declaration.get_type();
assert(function_type.return_type.sema.thread == thread.get_index());
// // TODO: here we are duplicating the function declaration, but not the types. It could be interesting to duplicate the types so in the LLVM IR no special case has to take place to deduplicate work done in different threads
const external_fn = thread.external_functions.append(function_definition.declaration);
external_fn.global_symbol.attributes.@"export" = false;
external_fn.global_symbol.attributes.@"extern" = true;
@ -3513,7 +3612,7 @@ fn worker_thread(thread_index: u32, cpu_count: *u32) void {
}
for (thread.global_variables.slice()) |*nat_global| {
const global_type = llvm_get_type(thread, nat_global.type);
const global_type = llvm_get_type(thread, nat_global.global_symbol.type);
const linkage: LLVM.Linkage = switch (nat_global.global_symbol.attributes.@"export") {
true => .@"extern",
false => .internal,
@ -3619,9 +3718,8 @@ fn worker_thread(thread_index: u32, cpu_count: *u32) void {
},
.call => block: {
const call = instruction.get_payload(.call);
const function_type = llvm_get_type(thread, &call.get_function_type().type);
const callee = llvm_get_value(thread, call.callable);
const callee_function = callee.toFunction() orelse unreachable;
const function_type = callee_function.getType();
var arguments = std.BoundedArray(*LLVM.Value, 512){};
for (call.arguments) |argument| {
@ -3631,7 +3729,7 @@ fn worker_thread(thread_index: u32, cpu_count: *u32) void {
const args = arguments.constSlice();
const call_i = builder.createCall(function_type, callee, args.ptr, args.len, "", "".len, null);
const call_i = builder.createCall(function_type.toFunction() orelse unreachable, callee, args.ptr, args.len, "", "".len, null);
break :block call_i.toValue();
},
.integer_compare => block: {
@ -3948,6 +4046,19 @@ fn llvm_get_type(thread: *Thread, ty: *Type) *LLVM.Type {
const pointer_type = thread.llvm.context.getPointerType(0);
break :b pointer_type.toType();
},
.function => b: {
const nat_function_type = ty.get_payload(.function);
const return_type = llvm_get_type(thread, nat_function_type.return_type);
var argument_types = PinnedArray(*LLVM.Type){};
_ = &argument_types;
for (nat_function_type.argument_types) |argument_type| {
const llvm_arg_type = llvm_get_type(thread, argument_type);
_ = argument_types.append(llvm_arg_type);
}
const is_var_args = false;
const function_type = LLVM.getFunctionType(return_type, argument_types.pointer, argument_types.length, is_var_args);
break :b function_type.toType();
},
else => |t| @panic(@tagName(t)),
};
@ -3991,22 +4102,15 @@ fn llvm_get_function(thread: *Thread, nat_function: *Function.Declaration, overr
if (nat_function.global_symbol.value.llvm) |llvm| return llvm.toFunction() orelse unreachable else {
_ = override_extern; // autofix
const function_name = thread.identifiers.get(nat_function.global_symbol.global_declaration.declaration.name) orelse unreachable;
const return_type = llvm_get_type(thread, nat_function.return_type);
var argument_types = PinnedArray(*LLVM.Type){};
_ = &argument_types;
for (nat_function.argument_types) |argument_type| {
const llvm_arg_type = llvm_get_type(thread, argument_type);
_ = argument_types.append(llvm_arg_type);
}
const is_var_args = false;
const function_type = LLVM.getFunctionType(return_type, argument_types.pointer, argument_types.length, is_var_args);
const nat_function_type = nat_function.get_type();
const function_type = llvm_get_type(thread, &nat_function_type.type);
const is_extern_function = nat_function.global_symbol.attributes.@"extern";
const export_or_extern = nat_function.global_symbol.attributes.@"export" or is_extern_function;
const linkage: LLVM.Linkage = switch (export_or_extern) {
true => .@"extern",
false => .internal,
};
const function = thread.llvm.module.createFunction(function_type, linkage, address_space, function_name.ptr, function_name.len);
const function = thread.llvm.module.createFunction(function_type.toFunction() orelse unreachable, linkage, address_space, function_name.ptr, function_name.len);
const debug_info = false;
if (debug_info) {
@ -4155,7 +4259,6 @@ pub fn analyze_local_block(thread: *Thread, analyzer: *Analyzer, parser: *Parser
exit_with_error("Existing declaration with the same name");
}
const has_local_attributes = src[parser.i] == '[';
parser.i += @intFromBool(has_local_attributes);
@ -4230,6 +4333,7 @@ pub fn analyze_local_block(thread: *Thread, analyzer: *Analyzer, parser: *Parser
.id = .local_symbol,
},
.alignment = result.type.alignment,
.initial_value = result.initial_value,
});
if (local_symbol.type.sema.id == .pointer) {
@ -4244,6 +4348,10 @@ pub fn analyze_local_block(thread: *Thread, analyzer: *Analyzer, parser: *Parser
else => |t| @panic(@tagName(t)),
}
},
.global_symbol => {
const global_symbol = result.initial_value.get_payload(.global_symbol);
local_symbol.appointee_type = global_symbol.type;
},
else => |t| @panic(@tagName(t)),
}
}
@ -4365,8 +4473,9 @@ pub fn analyze_local_block(thread: *Thread, analyzer: *Analyzer, parser: *Parser
if (byte_equal(identifier, "return")) {
parser.skip_space(src);
if (function.declaration.return_type.sema.id != .unresolved) {
const return_type = function.declaration.return_type;
const function_type = function.declaration.get_type();
if (function_type.return_type.sema.id != .unresolved) {
const return_type = function_type.return_type;
const return_value = parser.parse_expression(analyzer, thread, file, return_type, .right);
parser.expect_character(src, ';');
@ -4691,8 +4800,8 @@ pub fn analyze_file(thread: *Thread, file_index: u32) void {
},
.alignment = global_type.alignment,
.id = .global_variable,
},
.type = global_type,
},
.initial_value = global_initial_value,
});
@ -4707,7 +4816,6 @@ pub fn analyze_file(thread: *Thread, file_index: u32) void {
const entry_block = create_basic_block(thread);
function.* = .{
.declaration = .{
.return_type = undefined,
.global_symbol = .{
.global_declaration = .{
.declaration = .{
@ -4725,6 +4833,7 @@ pub fn analyze_file(thread: *Thread, file_index: u32) void {
},
},
.id = .function_definition,
.type = undefined,
},
.file = file_index,
},
@ -4912,13 +5021,27 @@ pub fn analyze_file(thread: *Thread, file_index: u32) void {
}
}
function.declaration.argument_types = argument_types.const_slice();
parser.expect_character(src, ')');
parser.skip_space(src);
function.declaration.return_type = parser.parse_type_expression(thread, file);
const return_type = parser.parse_type_expression(thread, file);
const function_type = thread.function_types.append(.{
.type = .{
.sema = .{
.id = .function,
.resolved = true,
.thread = thread.get_index(),
},
.size = 0,
.alignment = 0,
},
.argument_types = argument_types.const_slice(),
.return_type = return_type,
});
function.declaration.global_symbol.type = &function_type.type;
parser.skip_space(src);
@ -4992,9 +5115,6 @@ pub fn analyze_file(thread: *Thread, file_index: u32) void {
analyzer.current_basic_block = current_basic_block;
}
const return_type = function.declaration.return_type;
_ = return_type;
if (!current_basic_block.is_terminated and (current_basic_block.instructions.length > 0 or current_basic_block.predecessors.length > 0)) {
unreachable;
}

View File

@ -0,0 +1,10 @@
>n: s32 = 5;
fn foo() s32 {
return n;
}
fn[cc(.c)] main[export]() s32 {
>fn_pointer = foo&;
>a = fn_pointer();
return a - n;
}