From 72f887c707cbe4df384f2425de309c8e923d474f Mon Sep 17 00:00:00 2001
From: David Gonzalez Martin
Date: Thu, 17 Apr 2025 22:01:14 -0600
Subject: [PATCH] Shortcircuiting if

---
 src/bootstrap.zig            | 239 +++++++++++++++++++++++++----------
 src/compiler.bbb             |  44 ++++---
 src/main.zig                 |   1 +
 tests/shortcircuiting_if.bbb |  16 +++
 4 files changed, 221 insertions(+), 79 deletions(-)
 create mode 100644 tests/shortcircuiting_if.bbb

diff --git a/src/bootstrap.zig b/src/bootstrap.zig
index c25248b..fc9804f 100644
--- a/src/bootstrap.zig
+++ b/src/bootstrap.zig
@@ -706,6 +706,10 @@ const Binary = struct {
         @"<",
         @">=",
         @"<=",
+        @"and",
+        @"or",
+        @"and?",
+        @"or?",
 
         fn is_boolean(id: Binary.Id) bool {
             return switch (id) {
@@ -719,6 +723,10 @@ const Binary = struct {
                 else => false,
             };
         }
+
+        fn is_shortcircuiting(id: Binary.Id) bool { // true only for the `and?` / `or?` variants
+            return id == .@"and?" or id == .@"or?";
+        }
     };
 };
 
@@ -1895,6 +1903,11 @@ pub const Module = struct {
         break :blk r;
     };
 
+    const OperatorKeyword = enum {
+        @"and",
+        @"or",
+    };
+
     fn tokenize(module: *Module) Token {
         module.skip_space();
 
@@ -1912,7 +1925,24 @@ pub const Module = struct {
             'a'...'z', 'A'...'Z', '_' => blk: {
                 assert(is_identifier_start_ch(start_character));
                 const identifier = module.parse_identifier();
-                const token: Token = if (lib.string.to_enum(Value.Keyword, identifier)) |value_keyword| .{ .value_keyword = value_keyword } else .{ .identifier = identifier };
+                const token: Token = if (lib.string.to_enum(Value.Keyword, identifier)) |value_keyword|
+                    .{ .value_keyword = value_keyword }
+                else if (lib.string.to_enum(OperatorKeyword, identifier)) |operator_keyword| switch (operator_keyword) {
+                    .@"and" => switch (module.content[module.offset]) { // peek at the byte at the current offset for a '?' suffix
+                        '?' => b: {
+                            module.offset += 1;
+                            break :b .@"and?";
+                        },
+                        else => .@"and",
+                    },
+                    .@"or" => switch (module.content[module.offset]) { // same '?' suffix check as for `and`
+                        '?' => b: {
+                            module.offset += 1;
+                            break :b .@"or?";
+                        },
+                        else => .@"or",
+                    },
+                } else .{ .identifier = identifier };
                 break :blk token;
             },
             '#' => if (is_identifier_start_ch(module.content[module.offset + 1])) blk: {
@@ -2980,6 +3010,10 @@ pub const Module = struct {
                     .@"<=" => .@"<=",
                     .@">" => .@">",
                     .@"<" => .@"<",
+                    .@"and" => .@"and",
+                    .@"and?" => .@"and?",
+                    .@"or" => .@"or",
+                    .@"or?" => .@"or?",
                     else => @trap(),
                 };
 
@@ -4887,7 +4921,7 @@ pub const Module = struct {
                 }
 
                 if (function_type.return_abi.semantic_type == module.noreturn_type or global.variable.storage.?.bb.function.attributes.naked) {
-                    @trap();
+                    _ = module.llvm.builder.create_unreachable();
                 } else if (function_type.return_abi.semantic_type == module.void_type) {
                     module.llvm.builder.create_ret_void();
                 } else {
@@ -6084,70 +6118,136 @@ pub const Module = struct {
                 else => @trap(),
             },
             .binary => |binary| blk: {
-                const left = if (binary.left.llvm) |left_llvm| left_llvm else b: {
-                    module.emit_value(function, binary.left, .abi);
-                    break :b binary.left.llvm orelse unreachable;
-                };
-                const right = if (binary.right.llvm) |right_llvm| right_llvm else b: {
-                    module.emit_value(function, binary.right, .abi);
-                    break :b binary.right.llvm orelse unreachable;
-                };
-                const result = switch (value_type.bb) {
-                    .integer => |integer| switch (binary.id) {
-                        .@"+" => module.llvm.builder.create_add(left, right),
-                        .@"-" => module.llvm.builder.create_sub(left, right),
-                        .@"*" => module.llvm.builder.create_mul(left, right),
-                        .@"/" => switch (integer.signed) {
-                            true => module.llvm.builder.create_sdiv(left, right),
-                            false => module.llvm.builder.create_udiv(left, right),
+                if (binary.id.is_shortcircuiting()) { // branch on the left operand, emit the right one in its own block, merge with a phi
+                    const ShortcircuitingOperation = enum {
+                        @"and",
+                        @"or",
+                    };
+                    const op: ShortcircuitingOperation = switch (binary.id) {
+                        .@"and?" => .@"and",
+                        .@"or?" => .@"or",
+                        else => unreachable,
+                    };
+                    const left = if (binary.left.llvm) |left_llvm| left_llvm else b: {
+                        module.emit_value(function, binary.left, .abi);
+                        break :b binary.left.llvm orelse unreachable;
+                    };
+                    const left_condition = switch (binary.left.type.?.bb) {
+                        .integer => |integer| switch (integer.bit_count) {
+                            1 => left,
+                            else => @trap(),
                         },
-                        .@"%" => switch (integer.signed) {
-                            true => module.llvm.builder.create_srem(left, right),
-                            false => module.llvm.builder.create_urem(left, right),
+                        else => @trap(),
+                    };
+                    const llvm_function = function.?.variable.storage.?.llvm.?.to_function();
+                    const current_bb = module.llvm.builder.get_insert_block().?;
+                    const right_block = module.llvm.context.create_basic_block(switch (op) {
+                        inline else => |o| @tagName(o) ++ ".right",
+                    }, llvm_function);
+                    const end_block = module.llvm.context.create_basic_block(switch (op) {
+                        inline else => |o| @tagName(o) ++ ".end",
+                    }, llvm_function);
+                    _ = module.llvm.builder.create_conditional_branch(left_condition, switch (op) {
+                        .@"and" => right_block,
+                        .@"or" => end_block,
+                    }, switch (op) {
+                        .@"and" => end_block,
+                        .@"or" => right_block,
+                    });
+
+                    module.llvm.builder.position_at_end(right_block);
+                    const right = if (binary.right.llvm) |right_llvm| right_llvm else b: {
+                        module.emit_value(function, binary.right, .abi);
+                        break :b binary.right.llvm orelse unreachable;
+                    };
+                    const right_condition = switch (binary.right.type.?.bb) { // check the RIGHT operand's type here, not the left one's
+                        .integer => |integer| switch (integer.bit_count) {
+                            1 => right,
+                            else => @trap(),
                         },
-                        .@"&" => module.llvm.builder.create_and(left, right),
-                        .@"|" => module.llvm.builder.create_or(left, right),
-                        .@"^" => module.llvm.builder.create_xor(left, right),
-                        .@"<<" => module.llvm.builder.create_shl(left, right),
-                        .@">>" => switch (integer.signed) {
-                            true => module.llvm.builder.create_ashr(left, right),
-                            false => module.llvm.builder.create_lshr(left, right),
+                        else => @trap(),
+                    };
+                    // Emitting the right operand may have moved the builder into a new block (e.g. a nested and?/or?), so use the block the branch actually leaves from.
+                    const right_end_block = module.llvm.builder.get_insert_block().?;
+                    _ = module.llvm.builder.create_branch(end_block);
+                    module.llvm.builder.position_at_end(end_block);
+                    const boolean_type = module.integer_type(1, false).llvm.abi.?;
+                    const phi = module.llvm.builder.create_phi(boolean_type);
+                    phi.add_incoming(&.{ switch (op) { // the skipping edge carries a constant (0 for and, 1 for or); the other edge carries the right operand
+                        .@"and" => boolean_type.get_zero().to_value(),
+                        .@"or" => boolean_type.to_integer().get_constant(1, 0).to_value(),
+                    }, right_condition }, &.{ current_bb, right_end_block });
+                    break :blk switch (type_kind) {
+                        .abi => phi.to_value(),
+                        .memory => @trap(),
+                    };
+                } else {
+                    const left = if (binary.left.llvm) |left_llvm| left_llvm else b: {
+                        module.emit_value(function, binary.left, .abi);
+                        break :b binary.left.llvm orelse unreachable;
+                    };
+                    const right = if (binary.right.llvm) |right_llvm| right_llvm else b: {
+                        module.emit_value(function, binary.right, .abi);
+                        break :b binary.right.llvm orelse unreachable;
+                    };
+                    const result = switch (value_type.bb) {
+                        .integer => |integer| switch (binary.id) {
+                            .@"+" => module.llvm.builder.create_add(left, right),
+                            .@"-" => module.llvm.builder.create_sub(left, right),
+                            .@"*" => module.llvm.builder.create_mul(left, right),
+                            .@"/" => switch (integer.signed) {
+                                true => module.llvm.builder.create_sdiv(left, right),
+                                false => module.llvm.builder.create_udiv(left, right),
+                            },
+                            .@"%" => switch (integer.signed) {
+                                true => module.llvm.builder.create_srem(left, right),
+                                false => module.llvm.builder.create_urem(left, right),
+                            },
+                            .@"&" => module.llvm.builder.create_and(left, right),
+                            .@"|" => module.llvm.builder.create_or(left, right),
+                            .@"^" => module.llvm.builder.create_xor(left, right),
+                            .@"<<" => module.llvm.builder.create_shl(left, right),
+                            .@">>" => switch (integer.signed) {
+                                true => module.llvm.builder.create_ashr(left, right),
+                                false => module.llvm.builder.create_lshr(left, right),
+                            },
+                            .@"==" => module.llvm.builder.create_integer_compare(.eq, left, right),
+                            .@"!=" => module.llvm.builder.create_integer_compare(.ne, left, right),
+                            .@">" => switch (integer.signed) {
+                                true => module.llvm.builder.create_integer_compare(.sgt, left, right),
+                                false => module.llvm.builder.create_integer_compare(.ugt, left, right),
+                            },
+                            .@"<" => switch (integer.signed) {
+                                true => module.llvm.builder.create_integer_compare(.slt, left, right),
+                                false => module.llvm.builder.create_integer_compare(.ult, left, right),
+                            },
+                            .@">=" => switch (integer.signed) {
+                                true => module.llvm.builder.create_integer_compare(.sge, left, right),
+                                false => module.llvm.builder.create_integer_compare(.uge, left, right),
+                            },
+                            .@"<=" => switch (integer.signed) {
+                                true => module.llvm.builder.create_integer_compare(.sle, left, right),
+                                false => module.llvm.builder.create_integer_compare(.ule, left, right),
+                            },
+                            else => module.report_error(),
                         },
-                        .@"==" => module.llvm.builder.create_integer_compare(.eq, left, right),
-                        .@"!=" => module.llvm.builder.create_integer_compare(.ne, left, right),
-                        .@">" => switch (integer.signed) {
-                            true => module.llvm.builder.create_integer_compare(.sgt, left, right),
-                            false => module.llvm.builder.create_integer_compare(.ugt, left, right),
+                        .pointer => |pointer| switch (binary.id) {
+                            .@"+" => module.llvm.builder.create_gep(.{
+                                .type = pointer.type.llvm.abi.?,
+                                .aggregate = left,
+                                .indices = &.{right},
+                            }),
+                            .@"-" => module.llvm.builder.create_gep(.{
+                                .type = pointer.type.llvm.abi.?,
+                                .aggregate = left,
+                                .indices = &.{module.negate_llvm_value(right, binary.right.is_constant())},
+                            }),
+                            else => module.report_error(),
                         },
-                        .@"<" => switch (integer.signed) {
-                            true => module.llvm.builder.create_integer_compare(.slt, left, right),
-                            false => module.llvm.builder.create_integer_compare(.ult, left, right),
-                        },
-                        .@">=" => switch (integer.signed) {
-                            true => module.llvm.builder.create_integer_compare(.sge, left, right),
-                            false => module.llvm.builder.create_integer_compare(.uge, left, right),
-                        },
-                        .@"<=" => switch (integer.signed) {
-                            true => module.llvm.builder.create_integer_compare(.sle, left, right),
-                            false => module.llvm.builder.create_integer_compare(.ule, left, right),
-                        },
-                    },
-                    .pointer => |pointer| switch (binary.id) {
-                        .@"+" => module.llvm.builder.create_gep(.{
-                            .type = pointer.type.llvm.abi.?,
-                            .aggregate = left,
-                            .indices = &.{right},
-                        }),
-                        .@"-" => module.llvm.builder.create_gep(.{
-                            .type = pointer.type.llvm.abi.?,
-                            .aggregate = left,
-                            .indices = &.{module.negate_llvm_value(right, binary.right.is_constant())},
-                        }),
-                        else => module.report_error(),
-                    },
-                    else => @trap(),
-                };
-                break :blk result;
+                        else => @trap(),
+                    };
+                    break :blk result;
+                }
             },
             .variable_reference => |variable| switch (value.kind) {
                 .left => switch (variable.storage.?.type == value_type) {
@@ -6826,7 +6926,7 @@ pub const Module = struct {
                 const not_taken_block = module.llvm.context.create_basic_block("if.false", llvm_function);
                 const exit_block = module.llvm.context.create_basic_block("if.end", llvm_function);
 
-                module.analyze(function, if_statement.condition, .{}, .memory);
+                module.analyze(function, if_statement.condition, .{}, .abi);
                 const llvm_condition = switch (if_statement.condition.type.?.bb) {
                     .integer => |integer| if (integer.bit_count != 1) module.llvm.builder.create_integer_compare(.ne, if_statement.condition.llvm.?, if_statement.condition.type.?.llvm.abi.?.get_zero().to_value()) else if_statement.condition.llvm.?,
                     .pointer => module.llvm.builder.create_integer_compare(.ne, if_statement.condition.llvm.?, if_statement.condition.type.?.llvm.abi.?.get_zero().to_value()),
@@ -7211,6 +7311,19 @@ pub const Module = struct {
                                 _ = module.llvm.builder.create_memcpy(left_llvm, pointer_type.bb.pointer.alignment, variable.storage.?.llvm.?, variable.storage.?.type.?.bb.pointer.alignment, uint64.llvm.abi.?.to_integer().get_constant(value_type.get_byte_size(), @intFromBool(false)).to_value());
                             },
                         },
+                        .field_access => |field_access| {
+                            const struct_type = field_access.aggregate.type.?.bb.pointer.type;
+                            const fields = struct_type.bb.structure.fields;
+                            const field_index: u32 = for (fields, 0..) |*field, field_index| {
+                                if (lib.string.equal(field_access.field, field.name)) {
+                                    break @intCast(field_index);
+                                }
+                            } else module.report_error();
+                            module.emit_value(function, field_access.aggregate, .memory);
+                            const gep = module.llvm.builder.create_struct_gep(struct_type.llvm.abi.?.to_struct(), field_access.aggregate.llvm.?, field_index);
+                            const uint64 = module.integer_type(64, false);
+                            _ = module.llvm.builder.create_memcpy(left_llvm, pointer_type.bb.pointer.alignment, gep, value_type.get_byte_alignment(), uint64.llvm.abi.?.to_integer().get_constant(value_type.get_byte_size(), @intFromBool(false)).to_value());
+                        },
                         else => @trap(),
                     },
                     .complex => @trap(),
diff --git a/src/compiler.bbb b/src/compiler.bbb
index 0bc445d..90f25fa 100644
--- a/src/compiler.bbb
+++ b/src/compiler.bbb
@@ -1,4 +1,5 @@
 [extern] memcmp = fn [cc(c)] (a: &u8, b: &u8, byte_count: u64) s32;
+[extern] exit = fn [cc(c)] (exit_code: s32) noreturn;
 
 string_no_match = #integer_max(u64);
 
@@ -225,6 +226,11 @@ global_state_initialize = fn () void
     };
 }
 
+fail = fn () noreturn
+{
+    exit(1);
+}
+
 CompilerCommand = enum
 {
     compile,
@@ -253,6 +259,28 @@ CompileFile = struct
 
 compile_file = fn (arena: &Arena, compile: CompileFile) void
 {
+    >relative_file_path = compile.relative_file_path;
+    if (relative_file_path.length < 5)
+    {
+        fail();
+    }
+
+    >extension_start = string_last_character(relative_file_path, '.');
+    if (extension_start == string_no_match)
+    {
+        fail();
+    }
+
+    if (!string_equal(relative_file_path[extension_start..], ".bbb"))
+    {
+        fail();
+    }
+
+    >separator_index = string_last_character(relative_file_path, '/');
+    if (separator_index == string_no_match)
+    {
+        separator_index = 0;
+    }
 }
 
 [export] main = fn [cc(c)] (argument_count: u32, argv: &&u8) s32
@@ -321,22 +349,6 @@ compile_file = fn (arena: &Arena, compile: CompileFile) void
 
     >relative_file_path = c_string_to_slice(relative_file_path_pointer);
 
-    if (relative_file_path.length < 5)
-    {
-        return 1;
-    }
-
-    >extension_start = string_last_character(relative_file_path, '.');
-    if (extension_start == string_no_match)
-    {
-        return 1;
-    }
-
-    if (!string_equal(relative_file_path[extension_start..], ".bbb"))
-    {
-        return 1;
-    }
-
     compile_file(arena, {
         .relative_file_path = relative_file_path,
         .build_mode = build_mode,
diff --git a/src/main.zig b/src/main.zig
index 1ef8e2d..584190d 100644
--- a/src/main.zig
+++ b/src/main.zig
@@ -314,4 +314,5 @@ const names = &[_][]const u8{
     "empty_if",
     "else_if",
     "else_if_complicated",
+    "shortcircuiting_if",
 };
diff --git a/tests/shortcircuiting_if.bbb b/tests/shortcircuiting_if.bbb
new file mode 100644
index 0000000..8a8b1c7
--- /dev/null
+++ b/tests/shortcircuiting_if.bbb
@@ -0,0 +1,16 @@
+[export] main = fn [cc(c)] (argument_count: u32) s32
+{
+    >a: s32 = 0;
+    if (argument_count != 0 and? argument_count != 2 and? argument_count != 3 or? argument_count != 1)
+    {
+        return 0;
+    }
+    else if (argument_count == 5 or? a == 0)
+    {
+        return 45;
+    }
+    else
+    {
+        return 1;
+    }
+}