From 81fd0aebfd1865578ada88ed72934f2268bac957 Mon Sep 17 00:00:00 2001 From: David Gonzalez Martin Date: Fri, 13 Jun 2025 22:06:04 -0600 Subject: [PATCH] Switch statement --- src/compiler.bbb | 336 ++++++++++++++++++++++++++++++++++++++++++++++- src/emitter.cpp | 30 +++++ src/parser.cpp | 10 +- 3 files changed, 370 insertions(+), 6 deletions(-) diff --git a/src/compiler.bbb b/src/compiler.bbb index 28f7798..447cc4c 100644 --- a/src/compiler.bbb +++ b/src/compiler.bbb @@ -2498,6 +2498,9 @@ llvm_create_global_variable = fn (module: &LLVMModule, type: &LLVMType, is_const [extern] LLVMBuildBr = fn [cc(c)] (builder: &LLVMBuilder, target_block: &LLVMBasicBlock) &LLVMValue; [extern] LLVMBuildCondBr = fn [cc(c)] (builder: &LLVMBuilder, condition: &LLVMValue, taken_block: &LLVMBasicBlock, not_taken_block: &LLVMBasicBlock) &LLVMValue; +[extern] LLVMBuildSwitch = fn [cc(c)] (builder: &LLVMBuilder, value: &LLVMValue, else: &LLVMBasicBlock, case_count: u32) &LLVMValue; +[extern] LLVMAddCase = fn [cc(c)] (switch_instruction: &LLVMValue, value: &LLVMValue, block: &LLVMBasicBlock) void; + [extern] LLVMBuildPhi = fn [cc(c)] (builder: &LLVMBuilder, type: &LLVMType, name: &u8) &LLVMValue; [extern] LLVMAddIncoming = fn [cc(c)] (phi: &LLVMValue, values: &&LLVMValue, blocks: &&LLVMBasicBlock, count: u32) void; @@ -2738,10 +2741,29 @@ StatementWhile = struct block: &Block, } +StatementSwitchDiscriminantId = enum +{ + single, + range, +} + +StatementSwitchDiscriminantContent = union +{ + single: &Value, + range: [2]&Value, +} + +StatementSwitchDiscriminant = struct +{ + content: StatementSwitchDiscriminantContent, + id: StatementSwitchDiscriminantId, +} + StatementSwitchClause = struct { - values: []&Value, + values: []StatementSwitchDiscriminant, block: &Block, + basic_block: &LLVMBasicBlock, } StatementSwitch = struct @@ -6014,7 +6036,149 @@ parse_statement = fn (module: &Module, scope: &Scope) &Statement }, .switch => { - #trap(); + skip_space(module); + expect_character(module, 
left_parenthesis); + skip_space(module); + + >discriminant = parse_value(module, scope, zero); + + skip_space(module); + expect_character(module, right_parenthesis); + + skip_space(module); + expect_character(module, left_brace); + + >clause_buffer: [64]StatementSwitchClause = undefined; + >clause_count: u64 = 0; + + while (1) + { + skip_space(module); + + >is_else: u1 = 0; + if (is_identifier_start(module.content[module.offset])) + { + >else_checkpoint = get_checkpoint(module); + >i = parse_identifier(module); + is_else = string_equal(i, "else"); + + if (!is_else) + { + set_checkpoint(module, else_checkpoint); + } + } + + >clause_values: []StatementSwitchDiscriminant = zero; + + if (is_else) + { + skip_space(module); + expect_character(module, '='); + expect_character(module, '>'); + } + else + { + >case_buffer: [64]StatementSwitchDiscriminant = undefined; + >case_count: u64 = 0; + + while (1) + { + >first_case_value = parse_value(module, scope, zero); + + skip_space(module); + + >checkpoint = get_checkpoint(module); + >token = tokenize(module); + + >clause_discriminant: StatementSwitchDiscriminant = undefined; + + switch (token.id) + { + .triple_dot => + { + >last_case_value = parse_value(module, scope, zero); + clause_discriminant = { + .content = { + .range = [ first_case_value, last_case_value ], + }, + .id = .range, + }; + }, + else => + { + if (token.id != .comma) + { + set_checkpoint(module, checkpoint); + } + + clause_discriminant = { + .content = { + .single = first_case_value, + }, + .id = .single, + }; + }, + } + + switch (clause_discriminant.id) + { + .single => { assert(clause_discriminant.content.single != zero); }, + .range => + { + assert(clause_discriminant.content.range[0] != zero); + assert(clause_discriminant.content.range[1] != zero); + }, + } + + case_buffer[case_count] = clause_discriminant; + case_count += 1; + + skip_space(module); + + if (consume_character_if_match(module, '=')) + { + expect_character(module, '>'); + break; + } + } + + 
clause_values = arena_allocate_slice[StatementSwitchDiscriminant](module.arena, case_count); + memcpy(#pointer_cast(clause_values.pointer), #pointer_cast(&case_buffer), case_count * #byte_size(StatementSwitchDiscriminant)); + } + + skip_space(module); + + >clause_block = parse_block(module, scope); + + clause_buffer[clause_count] = { + .values = clause_values, + .block = clause_block, + zero, + }; + clause_count += 1; + + consume_character_if_match(module, ','); + + skip_space(module); + + if (consume_character_if_match(module, right_brace)) + { + break; + } + } + + >clauses = arena_allocate_slice[StatementSwitchClause](module.arena, clause_count); + memcpy(#pointer_cast(clauses.pointer), #pointer_cast(&clause_buffer), clause_count * #byte_size(StatementSwitchClause)); + + require_semicolon = 0; + + statement.content = { + .switch = { + .discriminant = discriminant, + .clauses = clauses, + }, + }; + statement.id = .switch; }, .break => { @@ -13013,6 +13177,173 @@ analyze_statement = fn (module: &Module, scope: &Scope, statement: &Statement) v }, } }, + .switch => + { + >discriminant = statement.content.switch.discriminant; + >clauses = statement.content.switch.clauses; + + >exit_block = LLVMAppendBasicBlockInContext(module.llvm.context, llvm_function, "switch.exit"); + + analyze_value(module, discriminant, zero, .abi, 0); + + >discriminant_type = discriminant.type; + + >invalid_clause_index: u64 = ~0; + >else_clause_index = invalid_clause_index; + >discriminant_case_count: u64 = 0; + + // TODO: more analysis + switch (discriminant_type.id) + { + .enum, .integer => {}, + else => { report_error(); }, + } + + for (i: 0..clauses.length) + { + >clause = &clauses[i]; + clause.basic_block = LLVMAppendBasicBlockInContext(module.llvm.context, llvm_function, #select(clause.values.length == 0, "switch.else_case_block", "switch.case_block")); + discriminant_case_count += clause.values.length; + + if (clause.values.length == 0) + { + if (else_clause_index != 
invalid_clause_index) + { + // Double else + report_error(); + } + + else_clause_index = i; + } + else + { + for (&value: clause.values) + { + switch (value.id) + { + .single => + { + >v = value.content.single; + assert(v != zero); + analyze_value(module, v, discriminant_type, .abi, 1); + }, + .range => + { + >start = value.content.range[0]; + >end = value.content.range[1]; // range end is element 1; reading element 0 here made every range clause fail the start >= end check below + + for (v: value.content.range) + { + analyze_value(module, v, discriminant_type, .abi, 1); + } + + if (start.id != end.id) + { + report_error(); + } + + switch (start.id) + { + .constant_integer => + { + if (start.content.constant_integer.value >= end.content.constant_integer.value) + { + report_error(); + } + }, + else => { report_error(); }, + } + }, + } + } + } + } + + >else_block: &LLVMBasicBlock = undefined; + if (else_clause_index != invalid_clause_index) + { + else_block = clauses[else_clause_index].basic_block; + } + else + { + else_block = LLVMAppendBasicBlockInContext(module.llvm.context, llvm_function, "switch.else_case_block"); + } + + >switch_instruction = LLVMBuildSwitch(module.llvm.builder, discriminant.llvm, else_block, #truncate(discriminant_case_count)); + >all_blocks_terminated: u1 = 1; + + for (&clause: clauses) + { + for (&value: clause.values) + { + switch (value.id) + { + .single => + { + LLVMAddCase(switch_instruction, value.content.single.llvm, clause.basic_block); + }, + .range => + { + >start = value.content.range[0]; + >end = value.content.range[1]; + + LLVMAddCase(switch_instruction, start.llvm, clause.basic_block); + + assert(start.id == end.id); + + switch (start.id) + { + .constant_integer => + { + >start_value = start.content.constant_integer.value; + >end_value = end.content.constant_integer.value; + + for (i: start_value + 1 .. 
end_value) + { + LLVMAddCase(switch_instruction, LLVMConstInt(start.type.llvm.abi, i, 0), clause.basic_block); + } + }, + else => { unreachable; }, + } + + LLVMAddCase(switch_instruction, end.llvm, clause.basic_block); + }, + } + } + + LLVMPositionBuilderAtEnd(module.llvm.builder, clause.basic_block); + + analyze_block(module, clause.block); + + if (LLVMGetInsertBlock(module.llvm.builder)) + { + all_blocks_terminated = 0; + LLVMBuildBr(module.llvm.builder, exit_block); + LLVMClearInsertionPosition(module.llvm.builder); + } + } + + if (else_clause_index == invalid_clause_index) + { + LLVMPositionBuilderAtEnd(module.llvm.builder, else_block); + + if (module.has_debug_info and !build_mode_is_optimized(module.build_mode)) + { + emit_intrinsic_call(module, ."llvm.trap", zero, zero); + } + + LLVMBuildUnreachable(module.llvm.builder); + LLVMClearInsertionPosition(module.llvm.builder); + } + + LLVMPositionBuilderAtEnd(module.llvm.builder, exit_block); + + if (all_blocks_terminated) + { + LLVMBuildUnreachable(module.llvm.builder); + LLVMClearInsertionPosition(module.llvm.builder); + } + }, else => { #trap(); @@ -14457,6 +14788,7 @@ names: [_][]u8 = "string_to_enum", "empty_if", "else_if", + "else_if_complicated", ]; [export] main = fn [cc(c)] (argument_count: u32, argv: &&u8, envp: &&u8) s32 diff --git a/src/emitter.cpp b/src/emitter.cpp index 32fa326..d00f223 100644 --- a/src/emitter.cpp +++ b/src/emitter.cpp @@ -5209,6 +5209,22 @@ fn void invalidate_analysis(Module* module, Value* value) { invalidate_analysis(module, value->unary.value); } break; + case ValueId::slice_expression: + { + invalidate_analysis(module, value->slice_expression.array_like); + auto start = value->slice_expression.start; + auto end = value->slice_expression.end; + + if (start) + { + invalidate_analysis(module, start); + } + + if (end) + { + invalidate_analysis(module, end); + } + } break; default: trap(); } @@ -8606,6 +8622,10 @@ fn void analyze_statement(Module* module, Scope* scope, Statement* 
statement, u3 if (else_clause_index == invalid_clause_index) { LLVMPositionBuilderAtEnd(module->llvm.builder, else_block); + if (module->has_debug_info && !build_mode_is_optimized(module->build_mode)) + { + emit_intrinsic_call(module, IntrinsicIndex::trap, {}, {}); + } LLVMBuildUnreachable(module->llvm.builder); LLVMClearInsertionPosition(module->llvm.builder); } @@ -8669,6 +8689,16 @@ fn void analyze_statement(Module* module, Scope* scope, Statement* statement, u3 Type* aggregate_type = 0; + if (right->kind == ValueKind::left && right->type->id != TypeId::pointer) + { + if (!type_is_slice(right->type)) + { + report_error(); + } + + right->kind = ValueKind::right; + } + switch (right->kind) { case ValueKind::right: diff --git a/src/parser.cpp b/src/parser.cpp index a8655b8..020e74e 100644 --- a/src/parser.cpp +++ b/src/parser.cpp @@ -2823,7 +2823,7 @@ fn Statement* parse_statement(Module* module, Scope* scope) Value* right_value_buffer[64]; u64 right_value_count = 0; - right_value_buffer[right_value_count] = parse_value(module, scope, { .kind = ValueKind::left }); + right_value_buffer[right_value_count] = parse_value(module, scope, {}); right_value_count += 1; skip_space(module); @@ -2840,15 +2840,17 @@ fn Statement* parse_statement(Module* module, Scope* scope) report_error(); } - right_value_buffer[0]->kind = ValueKind::right; - right_value_buffer[right_value_count] = parse_value(module, scope, {}); right_value_count += 1; expect_character(module, right_parenthesis); kind = ForEachKind::range; } break; - case TokenId::right_parenthesis: kind = ForEachKind::slice; break; + case TokenId::right_parenthesis: + { + right_value_buffer[0]->kind = ValueKind::left; + kind = ForEachKind::slice; + } break; default: report_error(); }