diff --git a/lib/std/zig/AstGen.zig b/lib/std/zig/AstGen.zig index 9c27d1e036bb..601a5393e756 100644 --- a/lib/std/zig/AstGen.zig +++ b/lib/std/zig/AstGen.zig @@ -2914,6 +2914,8 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As .validate_array_init_result_ty, .validate_ptr_array_init, .validate_ref_ty, + .try_operand_ty, + .try_ref_operand_ty, => break :b true, .@"defer" => unreachable, @@ -5887,9 +5889,18 @@ fn tryExpr( } const try_lc = LineColumn{ astgen.source_line - parent_gz.decl_line, astgen.source_column }; - const operand_ri: ResultInfo = switch (ri.rl) { - .ref, .ref_coerced_ty => .{ .rl = .ref, .ctx = .error_handling_expr }, - else => .{ .rl = .none, .ctx = .error_handling_expr }, + const operand_ri: ResultInfo = .{ + .rl = switch (ri.rl) { + .ref => .ref, + .ref_coerced_ty => |payload_ptr_ty| .{ + .ref_coerced_ty = try parent_gz.addUnNode(.try_ref_operand_ty, payload_ptr_ty, node), + }, + else => if (try ri.rl.resultType(parent_gz, node)) |payload_ty| .{ + // `coerced_ty` is OK due to the `rvalue` call below + .coerced_ty = try parent_gz.addUnNode(.try_operand_ty, payload_ty, node), + } else .none, + }, + .ctx = .error_handling_expr, }; // This could be a pointer or value depending on the `ri` parameter. const operand = try reachableExpr(parent_gz, scope, operand_ri, operand_node, node); diff --git a/lib/std/zig/Zir.zig b/lib/std/zig/Zir.zig index af4ddaad6afc..abc36b0b604e 100644 --- a/lib/std/zig/Zir.zig +++ b/lib/std/zig/Zir.zig @@ -684,6 +684,14 @@ pub const Inst = struct { /// operator. Emit a compile error if not. /// Uses the `un_tok` union field. Token is the `&` operator. Operand is the type. validate_ref_ty, + /// Given a type `T`, construct the type `E!T`, where `E` is this function's error set, to be used + /// as the result type of a `try` operand. Generic poison is propagated. + /// Uses the `un_node` union field. Node is the `try` expression. Operand is the type `T`. + try_operand_ty, + /// Given a type `*T`, construct the type `*E!T`, where `E` is this function's error set, to be used + /// as the result type of a `try` operand whose address is taken with `&`. Generic poison is propagated. + /// Uses the `un_node` union field. Node is the `try` expression. Operand is the type `*T`. + try_ref_operand_ty, // The following tags all relate to struct initialization expressions. @@ -1254,6 +1262,8 @@ pub const Inst = struct { .array_init_elem_type, .array_init_elem_ptr, .validate_ref_ty, + .try_operand_ty, + .try_ref_operand_ty, .restore_err_ret_index_unconditional, .restore_err_ret_index_fn_entry, => false, @@ -1324,6 +1334,8 @@ pub const Inst = struct { .validate_array_init_result_ty, .validate_ptr_array_init, .validate_ref_ty, + .try_operand_ty, + .try_ref_operand_ty, => true, .param, @@ -1698,6 +1710,8 @@ pub const Inst = struct { .opt_eu_base_ptr_init = .un_node, .coerce_ptr_elem_ty = .pl_node, .validate_ref_ty = .un_tok, + .try_operand_ty = .un_node, + .try_ref_operand_ty = .un_node, .int_from_ptr = .un_node, .compile_error = .un_node, @@ -3834,6 +3848,8 @@ fn findDeclsInner( .opt_eu_base_ptr_init, .coerce_ptr_elem_ty, .validate_ref_ty, + .try_operand_ty, + .try_ref_operand_ty, .struct_init_empty, .struct_init_empty_result, .struct_init_empty_ref_result, diff --git a/src/Sema.zig b/src/Sema.zig index a72c749f2ed1..4cd02db0e190 100644 --- a/src/Sema.zig +++ b/src/Sema.zig @@ -1177,6 +1177,8 @@ fn analyzeBodyInner( .validate_array_init_ref_ty => try sema.zirValidateArrayInitRefTy(block, inst), .opt_eu_base_ptr_init => try sema.zirOptEuBasePtrInit(block, inst), .coerce_ptr_elem_ty => try sema.zirCoercePtrElemTy(block, inst), + .try_operand_ty => try sema.zirTryOperandTy(block, inst, false), + .try_ref_operand_ty => try sema.zirTryOperandTy(block, inst, true), .clz => try sema.zirBitCount(block, inst, .clz, Value.clz), .ctz => try sema.zirBitCount(block, inst, .ctz, Value.ctz), @@ -2024,6 +2026,22 @@ fn genericPoisonReason(sema: *Sema, block: *Block, ref: Zir.Inst.Ref) GenericPoi const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node; cur = un_node.operand; }, + .try_operand_ty => { + // Either the input type was itself poison, or it was a slice, which we cannot translate + // to an overall result type. + const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node; + const operand_ref = sema.resolveInst(un_node.operand) catch |err| switch (err) { + error.GenericPoison => unreachable, // this is a type, not a value + }; + if (operand_ref == .generic_poison_type) { + // The input was poison -- keep looking. + cur = un_node.operand; + continue; + } + // We got a poison because the result type was a slice. This is a tricky case -- let's just + // not bother explaining it to the user for now... + return .unknown; + }, .struct_init_field_type => { const pl_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].pl_node; const extra = sema.code.extraData(Zir.Inst.FieldType, pl_node.payload_index).data; @@ -4423,6 +4441,59 @@ fn zirCoercePtrElemTy(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileE } } +fn zirTryOperandTy(sema: *Sema, block: *Block, inst: Zir.Inst.Index, is_ref: bool) CompileError!Air.Inst.Ref { + const pt = sema.pt; + const zcu = pt.zcu; + const un_node = sema.code.instructions.items(.data)[@intFromEnum(inst)].un_node; + const src = block.nodeOffset(un_node.src_node); + + const operand_ty = sema.resolveType(block, src, un_node.operand) catch |err| switch (err) { + error.GenericPoison => return .generic_poison_type, + else => |e| return e, + }; + + const payload_ty = if (is_ref) ty: { + if (!operand_ty.isSinglePointer(zcu)) { + return .generic_poison_type; // we can't get a meaningful result type here, since it will be `*E![n]T`, and we don't know `n`. + } + break :ty operand_ty.childType(zcu); + } else operand_ty; + + const err_set_ty = err_set: { + // There are awkward cases, like `?E`. Our strategy is to repeatedly unwrap optionals + // until we hit an error union or set. + var cur_ty = sema.fn_ret_ty; + while (true) { + switch (cur_ty.zigTypeTag(zcu)) { + .error_set => break :err_set cur_ty, + .error_union => break :err_set cur_ty.errorUnionSet(zcu), + .optional => cur_ty = cur_ty.optionalChild(zcu), + else => return sema.failWithOwnedErrorMsg(block, msg: { + const msg = try sema.errMsg(src, "expected '{}', found error set", .{sema.fn_ret_ty.fmt(pt)}); + errdefer msg.destroy(sema.gpa); + const ret_ty_src: LazySrcLoc = .{ + .base_node_inst = sema.getOwnerFuncDeclInst(), + .offset = .{ .node_offset_fn_type_ret_ty = 0 }, + }; + try sema.errNote(ret_ty_src, msg, "function cannot return an error", .{}); + break :msg msg; + }), + } + } + }; + + const eu_ty = try pt.errorUnionType(err_set_ty, payload_ty); + + if (is_ref) { + var ptr_info = operand_ty.ptrInfo(zcu); + ptr_info.child = eu_ty.toIntern(); + const eu_ptr_ty = try pt.ptrTypeSema(ptr_info); + return Air.internedToRef(eu_ptr_ty.toIntern()); + } else { + return Air.internedToRef(eu_ty.toIntern()); + } +} + fn zirValidateRefTy(sema: *Sema, block: *Block, inst: Zir.Inst.Index) CompileError!void { const pt = sema.pt; const zcu = pt.zcu; diff --git a/src/print_zir.zig b/src/print_zir.zig index 8d70af5f3cd1..3e61953cf52f 100644 --- a/src/print_zir.zig +++ b/src/print_zir.zig @@ -277,6 +277,8 @@ const Writer = struct { .opt_eu_base_ptr_init, .restore_err_ret_index_unconditional, .restore_err_ret_index_fn_entry, + .try_operand_ty, + .try_ref_operand_ty, => try self.writeUnNode(stream, inst), .ref, diff --git a/test/behavior/try.zig b/test/behavior/try.zig index 53fdc4893457..f17133fabee3 100644 --- a/test/behavior/try.zig +++ b/test/behavior/try.zig @@ -67,3 +67,22 @@ test "`try`ing an if/else expression" { try std.testing.expectError(error.Test, S.getError2()); } + +test "try forwards result location" { + if (builtin.zig_backend == .stage2_x86) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_aarch64) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_arm) return error.SkipZigTest; // TODO + if (builtin.zig_backend == .stage2_spirv64) return error.SkipZigTest; + if (builtin.zig_backend == .stage2_riscv64) return error.SkipZigTest; + + const S = struct { + fn foo(err: bool) error{Foo}!u32 { + const result: error{ Foo, Bar }!u32 = if (err) error.Foo else 123; + const res_int: u32 = try @errorCast(result); + return res_int; + } + }; + + try expect((S.foo(false) catch return error.TestUnexpectedResult) == 123); + try std.testing.expectError(error.Foo, S.foo(true)); +}