diff --git a/src/LLVM.zig b/src/LLVM.zig index d346965..54c6448 100644 --- a/src/LLVM.zig +++ b/src/LLVM.zig @@ -982,6 +982,8 @@ pub const Builder = opaque { pub fn create_not(builder: *Builder, value: *Value) *Value { return api.LLVMBuildNot(builder, value, ""); } + + pub const create_switch = api.LLVMBuildSwitch; }; pub const GlobalValue = opaque { @@ -1195,6 +1197,12 @@ pub const Instruction = opaque { return api.LLVMAddIncoming(phi, values.ptr, basic_blocks.ptr, @intCast(values.len)); } }; + + pub const Switch = opaque { + pub fn add_case(switchi: *Switch, case_value: *Value, case_block: *BasicBlock) void { + return api.LLVMAddCase(switchi, case_value, case_block); + } + }; }; pub const DI = struct { diff --git a/src/bootstrap.zig b/src/bootstrap.zig index ff6bab8..a83c9e0 100644 --- a/src/bootstrap.zig +++ b/src/bootstrap.zig @@ -450,6 +450,7 @@ pub const Type = struct { return switch (ty.bb) { .integer => |integer| integer.signed, .bits => |bits| bits.backing_type.is_signed(), + .enumerator => |enumerator| enumerator.backing_type.is_signed(), else => @trap(), }; } @@ -470,6 +471,7 @@ pub const Type = struct { else => true, }, .bits => |bits| bits.backing_type.is_arbitrary_bit_integer(), + .enumerator => |enumerator| enumerator.backing_type.is_arbitrary_bit_integer(), else => false, }; } @@ -574,15 +576,9 @@ pub const Statement = struct { @"return": ?*Value, assignment: Assignment, expression: *Value, - @"if": struct { - condition: *Value, - if_block: *LexicalBlock, - else_block: ?*LexicalBlock, - }, - @"while": struct { - condition: *Value, - block: *LexicalBlock, - }, + @"if": If, + @"while": While, + @"switch": Switch, }, line: u32, column: u32, @@ -606,6 +602,28 @@ pub const Statement = struct { @"^=", }; }; + + const If = struct { + condition: *Value, + if_block: *LexicalBlock, + else_block: ?*LexicalBlock, + }; + + const While = struct { + condition: *Value, + block: *LexicalBlock, + }; + + const Switch = struct { + discriminant: *Value, + clauses: []Clause, + + const Clause = struct { + values: []const *Value, + block: *LexicalBlock, + basic_block: *llvm.BasicBlock = undefined, + }; + }; }; const Unary = struct { @@ -788,6 +806,7 @@ pub const Value = struct { }, .undefined => true, .call => false, + .enum_literal => true, else => @trap(), }; } @@ -1561,6 +1580,7 @@ pub const Module = struct { @"if", // TODO: make `unreachable` a statement start keyword? @"while", + @"switch", }; const rules = blk: { @@ -2245,77 +2265,146 @@ pub const Module = struct { 'A'...'Z', 'a'...'z' => blk: { const statement_start_identifier = module.parse_identifier(); - if (lib.string.to_enum(StatementStartKeyword, statement_start_identifier)) |statement_start_keyword| { - switch (statement_start_keyword) { - ._ => @trap(), - .@"return" => break :blk .{ - .@"return" = module.parse_value(function, .{}), - }, - .@"if" => { + if (lib.string.to_enum(StatementStartKeyword, statement_start_identifier)) |statement_start_keyword| switch (statement_start_keyword) { + ._ => @trap(), + .@"return" => break :blk .{ + .@"return" = module.parse_value(function, .{}), + }, + .@"if" => { + module.skip_space(); + + module.expect_character(left_parenthesis); + module.skip_space(); + + const condition = module.parse_value(function, .{}); + + module.skip_space(); + module.expect_character(right_parenthesis); + + module.skip_space(); + + const if_block = module.parse_block(function); + + module.skip_space(); + + var is_else = false; + if (is_identifier_start_ch(module.content[module.offset])) { + const identifier = module.parse_identifier(); + is_else = lib.string.equal(identifier, "else"); + if (!is_else) { + module.offset -= identifier.len; + } else { + module.skip_space(); + } + } + + const else_block = if (is_else) module.parse_block(function) else null; + + require_semicolon = false; + + break :blk .{ + .@"if" = .{ + .condition = condition, + .if_block = if_block, + .else_block = else_block, + }, + }; + }, + .@"while" => { + module.skip_space(); + + module.expect_character(left_parenthesis); + module.skip_space(); + + const condition = module.parse_value(function, .{}); + + module.skip_space(); + module.expect_character(right_parenthesis); + + module.skip_space(); + + const while_block = module.parse_block(function); + + require_semicolon = false; + + break :blk .{ + .@"while" = .{ + .condition = condition, + .block = while_block, + }, + }; + }, + .@"switch" => { + module.skip_space(); + module.expect_character(left_parenthesis); + module.skip_space(); + + const discriminant = module.parse_value(function, .{}); + + module.skip_space(); + module.expect_character(right_parenthesis); + + module.skip_space(); + module.expect_character(left_brace); + + var clause_buffer: [64]Statement.Switch.Clause = undefined; + var clause_count: u64 = 0; + + while (true) { module.skip_space(); - module.expect_character(left_parenthesis); - module.skip_space(); + var case_buffer: [64]*Value = undefined; + var case_count: u64 = 0; - const condition = module.parse_value(function, .{}); + while (true) { + const case_value = module.parse_value(function, .{}); + case_buffer[case_count] = case_value; + case_count += 1; - module.skip_space(); - module.expect_character(right_parenthesis); + _ = module.consume_character_if_match(','); - module.skip_space(); + module.skip_space(); - const if_block = module.parse_block(function); - - module.skip_space(); - - var is_else = false; - if (is_identifier_start_ch(module.content[module.offset])) { - const identifier = module.parse_identifier(); - is_else = lib.string.equal(identifier, "else"); - if (!is_else) { - module.offset -= identifier.len; - } else { - module.skip_space(); + if (module.consume_character_if_match('=')) { + module.expect_character('>'); + break; } } - const else_block = if (is_else) module.parse_block(function) else null; + module.skip_space(); - require_semicolon = false; + const clause_block = module.parse_block(function); - break :blk .{ - .@"if" = .{ - .condition = condition, - .if_block = if_block, - .else_block = else_block, - }, + const clause_values = module.arena.allocate(*Value, case_count); + @memcpy(clause_values, case_buffer[0..case_count]); + + clause_buffer[clause_count] = .{ + .values = clause_values, + .block = clause_block, }; - }, - .@"while" => { - module.skip_space(); + clause_count += 1; - module.expect_character(left_parenthesis); - module.skip_space(); - - const condition = module.parse_value(function, .{}); + _ = module.consume_character_if_match(','); module.skip_space(); - module.expect_character(right_parenthesis); + + if (module.consume_character_if_match(right_brace)) { + break; + } + } - module.skip_space(); + const clauses = module.arena.allocate(Statement.Switch.Clause, clause_count); + @memcpy(clauses, clause_buffer[0..clause_count]); - const while_block = module.parse_block(function); + require_semicolon = false; - require_semicolon = false; - - break :blk .{ - .@"while" = .{ - .condition = condition, - .block = while_block, - }, - }; - }, - } + break :blk .{ + .@"switch" = .{ + .discriminant = discriminant, + .clauses = clauses, + }, + }; + }, } else { module.offset -= statement_start_identifier.len; @@ -3356,14 +3445,12 @@ pub const Module = struct { const field_index = field_count; const field_name = module.parse_identifier(); module.skip_space(); - - const field_value = if (module.consume_character_if_match('=')) blk: { + const has_explicit_value = module.consume_character_if_match('='); + const field_value = if (has_explicit_value) blk: { module.skip_space(); const field_value = module.parse_integer_value(false); break :blk field_value; - } else { - @trap(); - }; + } else field_index; field_buffer[field_index] = .{ .name = field_name, @@ -6448,6 +6535,69 @@ pub const Module = struct { module.llvm.builder.position_at_end(loop_end_block); }, + .@"switch" => |switch_statement| { + const previous_exit_block = current_function.exit_block; + defer current_function.exit_block = previous_exit_block; + + const exit_block = module.llvm.context.create_basic_block("exit_block", null); + current_function.exit_block = exit_block; + + module.analyze(function, switch_statement.discriminant, .{}); + const switch_discriminant_type = switch_statement.discriminant.type.?; + + switch (switch_discriminant_type.bb) { + .enumerator => |enumerator| { + _ = enumerator; + var else_clause_index: ?usize = null; + var total_discriminant_cases: u32 = 0; + for (switch_statement.clauses, 0..) |*clause, clause_index| { + clause.basic_block = module.llvm.context.create_basic_block("case_block", llvm_function); + total_discriminant_cases += @intCast(clause.values.len); + if (clause.values.len == 0) { + if (else_clause_index != null) { + module.report_error(); + } + else_clause_index = clause_index; + } else { + for (clause.values) |v| { + module.analyze(function, v, .{ .type = switch_discriminant_type }); + if (!v.is_constant()) { + module.report_error(); + } + } + } + } + + const else_block = if (else_clause_index) |i| switch_statement.clauses[i].basic_block else module.llvm.context.create_basic_block("else_case_block", llvm_function); + const switch_instruction = module.llvm.builder.create_switch(switch_statement.discriminant.llvm.?, else_block, total_discriminant_cases); + for (switch_statement.clauses) |clause| { + for (clause.values) |v| { + switch_instruction.add_case(v.llvm.?, clause.basic_block); + } + + current_function.exit_block = exit_block; + module.llvm.builder.position_at_end(clause.basic_block); + module.analyze_block(function, clause.block); + if (module.llvm.builder.get_insert_block() != null) { + _ = module.llvm.builder.create_branch(exit_block); + module.llvm.builder.clear_insertion_position(); + } + } + + current_function.exit_block = exit_block; + + if (else_clause_index) |i| { + _ = i; + @trap(); + } else { + module.llvm.builder.position_at_end(else_block); + _ = module.llvm.builder.create_unreachable(); + module.llvm.builder.clear_insertion_position(); + } + }, + else => @trap(), + } + }, } } } @@ -6688,7 +6838,7 @@ pub const Module = struct { } pub fn align_integer_type(module: *Module, ty: *Type) *Type { - assert(ty.bb == .integer); + assert(ty.bb == .integer or ty.bb == .enumerator); const bit_count = ty.get_bit_size(); const abi_bit_count: u32 = @intCast(@max(8, lib.next_power_of_two(bit_count))); if (bit_count != abi_bit_count) { diff --git a/src/llvm_api.zig b/src/llvm_api.zig index 6ebb571..7b8d0a8 100644 --- a/src/llvm_api.zig +++ b/src/llvm_api.zig @@ -99,6 +99,8 @@ pub extern fn LLVMAddIncoming(phi: *llvm.Instruction.Phi, incoming_value_pointer pub extern fn LLVMBuildSelect(builder: *llvm.Builder, condition: *llvm.Value, true_value: *llvm.Value, false_value: *llvm.Value, name: [*:0]const u8) *llvm.Value; pub extern fn LLVMBuildVAArg(builder: *llvm.Builder, va_list: *llvm.Value, arg_type: *llvm.Type, name: [*:0]const u8) *llvm.Value; +pub extern fn LLVMBuildSwitch(builder: *llvm.Builder, discriminant: *llvm.Value, else_basic_block: *llvm.BasicBlock, case_count: c_uint) *llvm.Instruction.Switch; +pub extern fn LLVMAddCase(switchi: *llvm.Instruction.Switch, case_value: *llvm.Value, case_block: *llvm.BasicBlock) void; // Casts pub extern fn LLVMBuildZExt(builder: *llvm.Builder, value: *llvm.Value, destination_type: *llvm.Type, name: [*:0]const u8) *llvm.Value; diff --git a/tests/basic_switch.bbb b/tests/basic_switch.bbb new file mode 100644 index 0000000..075439e --- /dev/null +++ b/tests/basic_switch.bbb @@ -0,0 +1,26 @@ +E = enum +{ + a, + b, + c, +} + +[export] main = fn [cc(c)] () s32 +{ + >some_enum: E = .a; + switch (some_enum) + { + .a => + { + return 0; + }, + .b => + { + return 1; + }, + .c => + { + return 1; + }, + } +}