Skip to content

Commit

Permalink
compiler: remove destination type from cast builtins
Browse files Browse the repository at this point in the history
Resolves: #5909
  • Loading branch information
mlugg authored and andrewrk committed Jun 24, 2023
1 parent 13853be commit be0c699
Show file tree
Hide file tree
Showing 7 changed files with 680 additions and 356 deletions.
215 changes: 162 additions & 53 deletions src/AstGen.zig
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,32 @@ const ResultInfo = struct {
},
}
}

/// Find the result type for a cast builtin given the result location.
/// If the location does not have a known result type, emits an error on
/// the given node.
fn resultType(rl: Loc, gz: *GenZir, node: Ast.Node.Index, builtin_name: []const u8) !Zir.Inst.Ref {
const astgen = gz.astgen;
switch (rl) {
.discard, .none, .ref, .inferred_ptr => {},
.ty, .coerced_ty => |ty_ref| return ty_ref,
.ptr => |ptr| {
const ptr_ty = try gz.addUnNode(.typeof, ptr.inst, node);
return gz.addUnNode(.elem_type, ptr_ty, node);
},
.block_ptr => |block_scope| {
if (block_scope.rl_ty_inst != .none) return block_scope.rl_ty_inst;
if (block_scope.break_result_info.rl == .ptr) {
const ptr_ty = try gz.addUnNode(.typeof, block_scope.break_result_info.rl.ptr.inst, node);
return gz.addUnNode(.elem_type, ptr_ty, node);
}
},
}

return astgen.failNodeNotes(node, "{s} must have a known result type", .{builtin_name}, &.{
try astgen.errNoteNode(node, "use @as to provide explicit result type", .{}),
});
}
};

const Context = enum {
Expand Down Expand Up @@ -2521,6 +2547,7 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
.array_type,
.array_type_sentinel,
.elem_type_index,
.elem_type,
.vector_type,
.indexable_ptr_len,
.anyframe_type,
Expand Down Expand Up @@ -2662,7 +2689,6 @@ fn addEnsureResult(gz: *GenZir, maybe_unused_result: Zir.Inst.Ref, statement: As
.int_cast,
.ptr_cast,
.truncate,
.align_cast,
.has_decl,
.has_field,
.clz,
Expand Down Expand Up @@ -7924,18 +7950,127 @@ fn bitCast(
scope: *Scope,
ri: ResultInfo,
node: Ast.Node.Index,
lhs: Ast.Node.Index,
rhs: Ast.Node.Index,
operand_node: Ast.Node.Index,
) InnerError!Zir.Inst.Ref {
const dest_type = try reachableTypeExpr(gz, scope, lhs, node);
const operand = try reachableExpr(gz, scope, .{ .rl = .none }, rhs, node);
const dest_type = try ri.rl.resultType(gz, node, "@bitCast");
const operand = try reachableExpr(gz, scope, .{ .rl = .none }, operand_node, node);
const result = try gz.addPlNode(.bitcast, node, Zir.Inst.Bin{
.lhs = dest_type,
.rhs = operand,
});
return rvalue(gz, ri, result, node);
}

/// Handle one or more nested pointer cast builtins:
/// * @ptrCast
/// * @alignCast
/// * @addrSpaceCast
/// * @constCast
/// * @volatileCast
/// Any sequence of such builtins is treated as a single operation. This allowed
/// for sequences like `@ptrCast(@alignCast(ptr))` to work correctly despite the
/// intermediate result type being unknown.
fn ptrCast(
gz: *GenZir,
scope: *Scope,
ri: ResultInfo,
root_node: Ast.Node.Index,
) InnerError!Zir.Inst.Ref {
const astgen = gz.astgen;
const tree = astgen.tree;
const main_tokens = tree.nodes.items(.main_token);
const node_datas = tree.nodes.items(.data);
const node_tags = tree.nodes.items(.tag);

var flags: Zir.Inst.FullPtrCastFlags = .{};

// Note that all pointer cast builtins have one parameter, so we only need
// to handle `builtin_call_two`.
var node = root_node;
while (true) {
switch (node_tags[node]) {
.builtin_call_two, .builtin_call_two_comma => {},
.grouped_expression => {
// Handle the chaining even with redundant parentheses
node = node_datas[node].lhs;
continue;
},
else => break,
}

if (node_datas[node].lhs == 0) break; // 0 args
if (node_datas[node].rhs != 0) break; // 2 args

const builtin_token = main_tokens[node];
const builtin_name = tree.tokenSlice(builtin_token);
const info = BuiltinFn.list.get(builtin_name) orelse break;
if (info.param_count != 1) break;

switch (info.tag) {
else => break,
inline .ptr_cast,
.align_cast,
.addrspace_cast,
.const_cast,
.volatile_cast,
=> |tag| {
if (@field(flags, @tagName(tag))) {
return astgen.failNode(node, "redundant {s}", .{builtin_name});
}
@field(flags, @tagName(tag)) = true;
},
}

node = node_datas[node].lhs;
}

const flags_i = @bitCast(u5, flags);
assert(flags_i != 0);

const ptr_only: Zir.Inst.FullPtrCastFlags = .{ .ptr_cast = true };
if (flags_i == @bitCast(u5, ptr_only)) {
// Special case: simpler representation
return typeCast(gz, scope, ri, root_node, node, .ptr_cast, "@ptrCast");
}

const no_result_ty_flags: Zir.Inst.FullPtrCastFlags = .{
.const_cast = true,
.volatile_cast = true,
};
if ((flags_i & ~@bitCast(u5, no_result_ty_flags)) == 0) {
// Result type not needed
const cursor = maybeAdvanceSourceCursorToMainToken(gz, root_node);
const operand = try expr(gz, scope, .{ .rl = .none }, node);
try emitDbgStmt(gz, cursor);
const result = try gz.addExtendedPayloadSmall(.ptr_cast_no_dest, flags_i, Zir.Inst.UnNode{
.node = gz.nodeIndexToRelative(root_node),
.operand = operand,
});
return rvalue(gz, ri, result, root_node);
}

// Full cast including result type
const need_result_type_builtin = if (flags.ptr_cast)
"@ptrCast"
else if (flags.align_cast)
"@alignCast"
else if (flags.addrspace_cast)
"@addrSpaceCast"
else
unreachable;

const cursor = maybeAdvanceSourceCursorToMainToken(gz, root_node);
const result_type = try ri.rl.resultType(gz, root_node, need_result_type_builtin);
const operand = try expr(gz, scope, .{ .rl = .none }, node);
try emitDbgStmt(gz, cursor);
const result = try gz.addExtendedPayloadSmall(.ptr_cast_full, flags_i, Zir.Inst.BinNode{
.node = gz.nodeIndexToRelative(root_node),
.lhs = result_type,
.rhs = operand,
});
return rvalue(gz, ri, result, root_node);
}

fn typeOf(
gz: *GenZir,
scope: *Scope,
Expand Down Expand Up @@ -8123,7 +8258,7 @@ fn builtinCall(

// zig fmt: off
.as => return as( gz, scope, ri, node, params[0], params[1]),
.bit_cast => return bitCast( gz, scope, ri, node, params[0], params[1]),
.bit_cast => return bitCast( gz, scope, ri, node, params[0]),
.TypeOf => return typeOf( gz, scope, ri, node, params),
.union_init => return unionInit(gz, scope, ri, node, params),
.c_import => return cImport( gz, scope, node, params[0]),
Expand Down Expand Up @@ -8308,14 +8443,13 @@ fn builtinCall(
.Frame => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .frame_type),
.frame_size => return simpleUnOp(gz, scope, ri, node, .{ .rl = .none }, params[0], .frame_size),

.int_from_float => return typeCast(gz, scope, ri, node, params[0], params[1], .int_from_float),
.float_from_int => return typeCast(gz, scope, ri, node, params[0], params[1], .float_from_int),
.ptr_from_int => return typeCast(gz, scope, ri, node, params[0], params[1], .ptr_from_int),
.enum_from_int => return typeCast(gz, scope, ri, node, params[0], params[1], .enum_from_int),
.float_cast => return typeCast(gz, scope, ri, node, params[0], params[1], .float_cast),
.int_cast => return typeCast(gz, scope, ri, node, params[0], params[1], .int_cast),
.ptr_cast => return typeCast(gz, scope, ri, node, params[0], params[1], .ptr_cast),
.truncate => return typeCast(gz, scope, ri, node, params[0], params[1], .truncate),
.int_from_float => return typeCast(gz, scope, ri, node, params[0], .int_from_float, builtin_name),
.float_from_int => return typeCast(gz, scope, ri, node, params[0], .float_from_int, builtin_name),
.ptr_from_int => return typeCast(gz, scope, ri, node, params[0], .ptr_from_int, builtin_name),
.enum_from_int => return typeCast(gz, scope, ri, node, params[0], .enum_from_int, builtin_name),
.float_cast => return typeCast(gz, scope, ri, node, params[0], .float_cast, builtin_name),
.int_cast => return typeCast(gz, scope, ri, node, params[0], .int_cast, builtin_name),
.truncate => return typeCast(gz, scope, ri, node, params[0], .truncate, builtin_name),
// zig fmt: on

.Type => {
Expand Down Expand Up @@ -8368,49 +8502,22 @@ fn builtinCall(
});
return rvalue(gz, ri, result, node);
},
.align_cast => {
const dest_align = try comptimeExpr(gz, scope, align_ri, params[0]);
const rhs = try expr(gz, scope, .{ .rl = .none }, params[1]);
const result = try gz.addPlNode(.align_cast, node, Zir.Inst.Bin{
.lhs = dest_align,
.rhs = rhs,
});
return rvalue(gz, ri, result, node);
},
.err_set_cast => {
try emitDbgNode(gz, node);

const result = try gz.addExtendedPayload(.err_set_cast, Zir.Inst.BinNode{
.lhs = try typeExpr(gz, scope, params[0]),
.rhs = try expr(gz, scope, .{ .rl = .none }, params[1]),
.lhs = try ri.rl.resultType(gz, node, "@errSetCast"),
.rhs = try expr(gz, scope, .{ .rl = .none }, params[0]),
.node = gz.nodeIndexToRelative(node),
});
return rvalue(gz, ri, result, node);
},
.addrspace_cast => {
const result = try gz.addExtendedPayload(.addrspace_cast, Zir.Inst.BinNode{
.lhs = try comptimeExpr(gz, scope, .{ .rl = .{ .ty = .address_space_type } }, params[0]),
.rhs = try expr(gz, scope, .{ .rl = .none }, params[1]),
.node = gz.nodeIndexToRelative(node),
});
return rvalue(gz, ri, result, node);
},
.const_cast => {
const operand = try expr(gz, scope, .{ .rl = .none }, params[0]);
const result = try gz.addExtendedPayload(.const_cast, Zir.Inst.UnNode{
.node = gz.nodeIndexToRelative(node),
.operand = operand,
});
return rvalue(gz, ri, result, node);
},
.volatile_cast => {
const operand = try expr(gz, scope, .{ .rl = .none }, params[0]);
const result = try gz.addExtendedPayload(.volatile_cast, Zir.Inst.UnNode{
.node = gz.nodeIndexToRelative(node),
.operand = operand,
});
return rvalue(gz, ri, result, node);
},
.ptr_cast,
.align_cast,
.addrspace_cast,
.const_cast,
.volatile_cast,
=> return ptrCast(gz, scope, ri, node),

// zig fmt: off
.has_decl => return hasDeclOrField(gz, scope, ri, node, params[0], params[1], .has_decl),
Expand Down Expand Up @@ -8725,13 +8832,13 @@ fn typeCast(
scope: *Scope,
ri: ResultInfo,
node: Ast.Node.Index,
lhs_node: Ast.Node.Index,
rhs_node: Ast.Node.Index,
operand_node: Ast.Node.Index,
tag: Zir.Inst.Tag,
builtin_name: []const u8,
) InnerError!Zir.Inst.Ref {
const cursor = maybeAdvanceSourceCursorToMainToken(gz, node);
const result_type = try typeExpr(gz, scope, lhs_node);
const operand = try expr(gz, scope, .{ .rl = .none }, rhs_node);
const result_type = try ri.rl.resultType(gz, node, builtin_name);
const operand = try expr(gz, scope, .{ .rl = .none }, operand_node);

try emitDbgStmt(gz, cursor);
const result = try gz.addPlNode(tag, node, Zir.Inst.Bin{
Expand Down Expand Up @@ -9432,6 +9539,7 @@ fn nodeMayNeedMemoryLocation(tree: *const Ast, start_node: Ast.Node.Index, have_
switch (builtin_info.needs_mem_loc) {
.never => return false,
.always => return true,
.forward0 => node = node_datas[node].lhs,
.forward1 => node = node_datas[node].rhs,
}
// Missing builtin arg is not a parsing error, expect an error later.
Expand All @@ -9448,6 +9556,7 @@ fn nodeMayNeedMemoryLocation(tree: *const Ast, start_node: Ast.Node.Index, have_
switch (builtin_info.needs_mem_loc) {
.never => return false,
.always => return true,
.forward0 => node = params[0],
.forward1 => node = params[1],
}
// Missing builtin arg is not a parsing error, expect an error later.
Expand Down
3 changes: 0 additions & 3 deletions src/Autodoc.zig
Original file line number Diff line number Diff line change
Expand Up @@ -1529,7 +1529,6 @@ fn walkInstruction(
.int_cast,
.ptr_cast,
.truncate,
.align_cast,
.has_decl,
.has_field,
.div_exact,
Expand Down Expand Up @@ -3024,8 +3023,6 @@ fn walkInstruction(
.int_from_error,
.error_from_int,
.reify,
.const_cast,
.volatile_cast,
=> {
const extra = file.zir.extraData(Zir.Inst.UnNode, extended.operand).data;
const bin_index = self.exprs.items.len;
Expand Down
Loading

0 comments on commit be0c699

Please sign in to comment.