Pass 'basic_union'

This commit is contained in:
David Gonzalez Martin 2025-06-16 15:30:48 -06:00
parent 06395cc20e
commit 6837d85273
2 changed files with 278 additions and 7 deletions

View File

@ -1407,6 +1407,22 @@ TypeAlias = struct
line: u32,
}
UnionField = struct
{
type: &Type,
name: []u8,
line: u32,
}
TypeUnion = struct
{
fields: []UnionField,
byte_size: u64,
biggest_field: u64,
byte_alignment: u32,
line: u32,
}
TypeContent = union
{
integer: TypeInteger,
@ -1418,6 +1434,7 @@ TypeContent = union
struct: TypeStruct,
bits: TypeBits,
alias: TypeAlias,
union: TypeUnion,
}
TypeLLVM = struct
@ -1604,6 +1621,11 @@ get_byte_size = fn (type: &Type) u64
>result = get_byte_size(type.content.bits.backing_type);
return result;
},
.union =>
{
>result = type.content.union.byte_size;
return result;
},
else =>
{
#trap();
@ -1643,6 +1665,10 @@ get_byte_alignment = fn (type: &Type) u32
>alignment = get_byte_alignment(backing_type);
return alignment;
},
.union =>
{
return type.content.union.byte_alignment;
},
else =>
{
#trap();
@ -2415,6 +2441,7 @@ LLVMICmpPredicate = enum u32
[extern] LLVMDIBuilderCreateMemberType = fn [cc(c)] (di_builder: &LLVMDIBuilder, scope: &LLVMMetadata, name_pointer: &u8, name_length: u64, file: &LLVMMetadata, line: u32, bit_size: u64, bit_alignment: u32, bit_offset: u64, flags: LLVMDIFlags, type: &LLVMMetadata) &LLVMMetadata;
[extern] LLVMDIBuilderCreateBitFieldMemberType = fn [cc(c)] (di_builder: &LLVMDIBuilder, scope: &LLVMMetadata, name_pointer: &u8, name_length: u64, file: &LLVMMetadata, line: u32, bit_size: u64, bit_offset: u64, storage_bit_offset: u64, flags: LLVMDIFlags, type: &LLVMMetadata) &LLVMMetadata;
[extern] LLVMDIBuilderCreateStructType = fn [cc(c)] (di_builder: &LLVMDIBuilder, scope: &LLVMMetadata, name_pointer: &u8, name_length: u64, file: &LLVMMetadata, line: u32, bit_size: u64, bit_alignment: u32, flags: LLVMDIFlags, derived_from: &LLVMMetadata, element_pointer: &&LLVMMetadata, element_count: u32, runtime_language: u32, vtable_holder: &LLVMMetadata, unique_identifier_pointer: &u8, unique_identifier_length: u64) &LLVMMetadata;
[extern] LLVMDIBuilderCreateUnionType = fn [cc(c)] (di_builder: &LLVMDIBuilder, scope: &LLVMMetadata, name_pointer: &u8, name_length: u64, file: &LLVMMetadata, line: u32, bit_size: u64, bit_alignment: u32, flags: LLVMDIFlags, element_pointer: &&LLVMMetadata, element_count: u32, runtime_language: u32, unique_identifier_pointer: &u8, unique_identifier_length: u64) &LLVMMetadata;
[extern] LLVMDIBuilderCreateFunction = fn [cc(c)] (di_builder: &LLVMDIBuilder, scope: &LLVMMetadata, name_pointer: &u8, name_length: u64, linkage_name_pointer: &u8, linkage_name_length: u64, file: &LLVMMetadata, line: u32, type: &LLVMMetadata, is_local_to_unit: s32, is_definition: s32, scope_line: u32, flags: LLVMDIFlags, is_optimized: s32) &LLVMMetadata;
[extern] LLVMDIBuilderFinalizeSubprogram = fn [cc(c)] (di_builder: &LLVMDIBuilder, subprogram: &LLVMMetadata) void;
@ -7115,7 +7142,88 @@ parse = fn (module: &Module) void
},
.union =>
{
#trap();
skip_space(module);
expect_character(module, left_brace);
>union_type: &Type = undefined;
if (type_forward_declaration)
{
union_type = type_forward_declaration;
}
else
{
union_type = new_type(module, {
.id = .forward_declaration,
.name = global_name,
.scope = &module.scope,
zero,
});
}
>field_count: u64 = 0;
>biggest_field: u64 = 0;
>alignment: u32 = 1;
>byte_size: u64 = 0;
>field_buffer: [64]UnionField = undefined;
while (1)
{
skip_space(module);
if (consume_character_if_match(module, right_brace))
{
break;
}
>field_index = field_count;
field_count += 1;
>field_line = get_line(module);
>field_name = parse_identifier(module);
skip_space(module);
expect_character(module, ':');
skip_space(module);
>field_type = parse_type(module, scope);
>field_alignment = get_byte_alignment(field_type);
>field_size = get_byte_size(field_type);
field_buffer[field_index] = {
.type = field_type,
.name = field_name,
.line = field_line,
};
biggest_field = #select(field_size > byte_size, field_index, biggest_field);
alignment = #max(alignment, field_alignment);
byte_size = #max(byte_size, field_size);
skip_space(module);
consume_character_if_match(module, ',');
}
skip_space(module);
consume_character_if_match(module, ';');
>fields = arena_allocate_slice[UnionField](module.arena, field_count);
memcpy(#pointer_cast(fields.pointer), #pointer_cast(&field_buffer), field_count * #byte_size(UnionField));
>biggest_size = get_byte_size(fields[biggest_field].type);
assert(biggest_size == byte_size);
union_type.content.union = {
.fields = fields,
.byte_size = byte_size,
.byte_alignment = alignment,
.line = global_line,
.biggest_field = biggest_field,
};
union_type.id = .union;
},
}
}
@ -7223,6 +7331,16 @@ resolve_type_in_place_abi = fn (module: &Module, type: &Type) void
>size = get_byte_size(type);
assert(llvm_size == size);
},
.union =>
{
>biggest_type = type.content.union.fields[type.content.union.biggest_field].type;
resolve_type_in_place_memory(module, biggest_type);
result = LLVMStructTypeInContext(module.llvm.context, &biggest_type.llvm.memory, 1, 0);
>llvm_size = LLVMStoreSizeOfType(module.llvm.target_data_layout, result);
>size = get_byte_size(type);
assert(llvm_size == size);
},
else =>
{
#trap();
@ -7273,6 +7391,16 @@ resolve_type_in_place_memory = fn (module: &Module, type: &Type) void
resolve_type_in_place_memory(module, backing_type);
result = backing_type.llvm.memory;
},
.union =>
{
>biggest_type = type.content.union.fields[type.content.union.biggest_field].type;
resolve_type_in_place_memory(module, biggest_type);
result = LLVMStructTypeInContext(module.llvm.context, &biggest_type.llvm.memory, 1, 0);
>llvm_size = LLVMStoreSizeOfType(module.llvm.target_data_layout, result);
>size = get_byte_size(type);
assert(llvm_size == size);
},
else =>
{
#trap();
@ -7424,6 +7552,40 @@ resolve_type_in_place_debug = fn (module: &Module, type: &Type) void
>struct_type = LLVMDIBuilderCreateStructType(module.llvm.di_builder, module.scope.llvm, type.name.pointer, type.name.length, module.llvm.file, type.content.bits.line, size, alignment, flags, zero, &llvm_type_buffer[0], #truncate(fields.length), runtime_language, vtable_holder, type.name.pointer, type.name.length);
result = struct_type;
},
.union =>
{
>flags: LLVMDIFlags = zero;
>runtime_language: u32 = 0;
>byte_size = get_byte_size(type);
>alignment = get_byte_alignment(type);
>forward_declaration = LLVMDIBuilderCreateReplaceableCompositeType(module.llvm.di_builder, module.llvm.debug_tag, type.name.pointer, type.name.length, module.scope.llvm, module.llvm.file, type.content.union.line, runtime_language, byte_size * 8, alignment * 8, flags, type.name.pointer, type.name.length);
module.llvm.debug_tag += 1;
type.llvm.debug = forward_declaration;
>llvm_type_buffer: [64]&LLVMMetadata = undefined;
>fields = type.content.union.fields;
for (i: 0..fields.length)
{
>field = &fields[i];
>field_type = field.type;
resolve_type_in_place_debug(module, field_type);
>offset: u64 = 0;
>member_type = LLVMDIBuilderCreateMemberType(module.llvm.di_builder, module.scope.llvm, field.name.pointer, field.name.length, module.llvm.file, field.line, get_byte_size(field_type) * 8, get_byte_alignment(field_type) * 8, offset, flags, field_type.llvm.debug);
llvm_type_buffer[i] = member_type;
}
>runtime_language: u32 = 0;
>union_type = LLVMDIBuilderCreateUnionType(module.llvm.di_builder, module.scope.llvm, type.name.pointer, type.name.length, module.llvm.file, type.content.union.line, byte_size * 8, alignment * 8, flags, &llvm_type_buffer[0], #truncate(fields.length), runtime_language, type.name.pointer, type.name.length);
LLVMMetadataReplaceAllUsesWith(forward_declaration, union_type);
result = union_type;
},
else =>
{
#trap();
@ -9681,7 +9843,30 @@ analyze_type = fn (module: &Module, value: &Value, expected_type: &Type, analysi
},
.union =>
{
#trap();
>fields = resolved_aggregate_type.content.union.fields;
>union_field: &UnionField = zero;
for (&field: fields)
{
if (string_equal(field_name, field.name))
{
union_field = field;
break;
}
}
if (!union_field)
{
report_error();
}
>field_type = union_field.type;
value_type = field_type;
if (value.kind == .left)
{
value_type = get_pointer_type(module, value_type);
}
},
.bits =>
{
@ -10039,7 +10224,32 @@ analyze_type = fn (module: &Module, value: &Value, expected_type: &Type, analysi
},
.union =>
{
#trap();
if (elements.length != 1)
{
report_error();
}
>initialization_value = elements[0].value;
>initialization_name = elements[0].name;
>result_field: &UnionField = zero;
>fields = aggregate_type.content.union.fields;
for (&field: fields)
{
if (string_equal(initialization_name, field.name))
{
result_field = field;
break;
}
}
if (!result_field)
{
report_error();
}
analyze_type(module, initialization_value, result_field.type, { .must_be_constant = analysis.must_be_constant, zero });
},
.enum_array =>
{
@ -11663,7 +11873,29 @@ emit_field_access = fn (module: &Module, value: &Value, left_llvm: &LLVMValue, l
},
.union =>
{
#trap();
>fields = resolved_aggregate_type.content.union.fields;
>union_field: &UnionField = zero;
for (&field: fields)
{
if (string_equal(field_name, field.name))
{
union_field = field;
break;
}
}
assert(union_field != zero);
>field_type = union_field.type;
resolve_type_in_place(module, field_type);
>struct_type = LLVMStructTypeInContext(module.llvm.context, &field_type.llvm.memory, 1, 0);
field_access = {
.type = field_type,
.field_index = 0,
.struct_type = struct_type,
};
},
else => { unreachable; },
}
@ -13348,7 +13580,45 @@ emit_assignment = fn (module: &Module, left_llvm: &LLVMValue, left_type: &Type,
},
.union =>
{
#trap();
assert(elements.length == 1);
>fields = resolved_value_type.content.union.fields;
>biggest_field_index = resolved_value_type.content.union.biggest_field;
>biggest_field = &fields[biggest_field_index];
>biggest_field_type = biggest_field.type;
>value = elements[0].value;
>field_value_type = value.type;
>field_type_size = get_byte_size(field_value_type);
>union_size = get_byte_size(resolved_value_type);
if (field_type_size < union_size)
{
>u8_type = uint8(module);
resolve_type_in_place(module, u8_type);
LLVMBuildMemSet(module.llvm.builder, left_llvm, LLVMConstNull(u8_type.llvm.memory), LLVMConstInt(u64_type.llvm.memory, union_size, 0), alignment);
}
else if (field_type_size > union_size)
{
unreachable;
}
>struct_type: &LLVMType = zero;
if (type_is_abi_equal(module, field_value_type, biggest_field_type))
{
struct_type = resolved_value_type.llvm.memory;
}
else
{
struct_type = LLVMStructTypeInContext(module.llvm.context, &field_value_type.llvm.memory, 1, 0);
}
assert(struct_type != zero);
>destination_pointer = LLVMBuildStructGEP2(module.llvm.builder, struct_type, left_llvm, 0, "");
>field_pointer_type = get_pointer_type(module, field_value_type);
emit_assignment(module, destination_pointer, field_pointer_type, value);
},
else =>
{
@ -15903,6 +16173,7 @@ names: [_][]u8 =
"integer_formats",
"for_each_int",
"bool_array",
"basic_union",
];
[export] main = fn [cc(c)] (argument_count: u32, argv: &&u8, envp: &&u8) s32

View File

@ -6591,7 +6591,7 @@ fn void emit_assignment(Module* module, LLVMValueRef left_llvm, Type* left_type,
auto fields = resolved_value_type->union_type.fields;
auto biggest_field_index = resolved_value_type->union_type.biggest_field;
auto& biggest_field = fields[biggest_field_index];
auto biggest_field_type = fields[biggest_field_index].type;
auto biggest_field_type = biggest_field.type;
auto value = elements[0].value;
auto field_value_type = value->type;
auto field_type_size = get_byte_size(field_value_type);
@ -6621,7 +6621,7 @@ fn void emit_assignment(Module* module, LLVMValueRef left_llvm, Type* left_type,
auto destination_pointer = LLVMBuildStructGEP2(module->llvm.builder, struct_type, left_llvm, 0, "");
auto field_pointer_type = get_pointer_type(module, field_value_type);
unused(biggest_field);
emit_assignment(module, destination_pointer, field_pointer_type, value);
} break;
default: unreachable();