Skip to content

Commit d657b6c

Browse files
committed
sema: support reinterpreting extern/packed unions at comptime via field access
My previous change for reading / writing to unions at comptime did not handle union field read/writes correctly in all cases. Previously, if a field was written to a union, it would overwrite the entire value. This is problematic when a field of a larger size is subsequently read, because the value would not be long enough, causing a panic. Additionally, the writing behaviour itself was incorrect. Writing to a field of a packed or extern union should only overwrite the bits corresponding to that field, allowing for memory reintepretation via field writes / reads. I addressed these problems as follows: Add the concept of a "backing type" for extern / packed unions (`Type.unionBackingType`). For extern unions, this is a `u8` array, for packed unions it's an integer matching the `bitSize` of the union. Whenever union memory is read at comptime, it's read as this type. When union memory is written at comptime, the tag may still be known. If so, the memory is written using the tagged type. If the tag is unknown (because this union had previously been read from memory), it's simply written back out as the backing type. I added `write_packed` to the `reinterpret` field of `ComptimePtrMutationKit`. This causes writes of the operand to be packed - which is necessary when writing to a field of a packed union. Without this, writing a value to a `u1` field would overwrite the entire byte it occupied. The final case to address was reading a different (potentially larger) field from a union when it was written with a known tag. To handle this, a new kind of bitcast was introduced (`bitCastUnionFieldVal`) which supports reading a larger field by using a backing buffer that has the unwritten bits set to undefined. The reason to support this (vs always just writing the union as it's backing type), is that no reads to larger fields ever occur at comptime, it would be strictly worse to have spent time writing the full backing type.
1 parent 53775b0 commit d657b6c

File tree

6 files changed

+271
-73
lines changed

6 files changed

+271
-73
lines changed

src/Module.zig

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6607,6 +6607,7 @@ pub fn unionFieldNormalAlignment(mod: *Module, u: InternPool.UnionType, field_in
66076607
return field_ty.abiAlignment(mod);
66086608
}
66096609

6610+
/// Returns the index of the active field, given the current tag value
66106611
pub fn unionTagFieldIndex(mod: *Module, u: InternPool.UnionType, enum_tag: Value) ?u32 {
66116612
const ip = &mod.intern_pool;
66126613
if (enum_tag.toIntern() == .none) return null;

src/Sema.zig

Lines changed: 85 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -27260,7 +27260,7 @@ fn unionFieldVal(
2726027260
else
2726127261
union_ty.unionFieldType(un.tag.toValue(), mod).?;
2726227262

27263-
if (try sema.bitCastVal(block, src, un.val.toValue(), old_ty, field_ty, 0)) |new_val| {
27263+
if (try sema.bitCastUnionFieldVal(block, src, un.val.toValue(), old_ty, field_ty)) |new_val| {
2726427264
return Air.internedToRef(new_val.toIntern());
2726527265
}
2726627266
}
@@ -29781,13 +29781,19 @@ fn storePtrVal(
2978129781
error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
2978229782
error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{mut_kit.ty.fmt(mod)}),
2978329783
};
29784-
operand_val.writeToMemory(operand_ty, mod, buffer[reinterpret.byte_offset..]) catch |err| switch (err) {
29785-
error.OutOfMemory => return error.OutOfMemory,
29786-
error.ReinterpretDeclRef => unreachable,
29787-
error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
29788-
error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{operand_ty.fmt(mod)}),
29789-
};
29790-
29784+
if (reinterpret.write_packed) {
29785+
operand_val.writeToPackedMemory(operand_ty, mod, buffer[reinterpret.byte_offset..], 0) catch |err| switch (err) {
29786+
error.OutOfMemory => return error.OutOfMemory,
29787+
error.ReinterpretDeclRef => unreachable,
29788+
};
29789+
} else {
29790+
operand_val.writeToMemory(operand_ty, mod, buffer[reinterpret.byte_offset..]) catch |err| switch (err) {
29791+
error.OutOfMemory => return error.OutOfMemory,
29792+
error.ReinterpretDeclRef => unreachable,
29793+
error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
29794+
error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{operand_ty.fmt(mod)}),
29795+
};
29796+
}
2979129797
const val = Value.readFromMemory(mut_kit.ty, mod, buffer, sema.arena) catch |err| switch (err) {
2979229798
error.OutOfMemory => return error.OutOfMemory,
2979329799
error.IllDefinedMemoryLayout => unreachable,
@@ -29819,6 +29825,8 @@ const ComptimePtrMutationKit = struct {
2981929825
reinterpret: struct {
2982029826
val_ptr: *Value,
2982129827
byte_offset: usize,
29828+
/// If set, write the operand to packed memory
29829+
write_packed: bool = false,
2982229830
},
2982329831
/// If the root decl could not be used as parent, this means `ty` is the type that
2982429832
/// caused that by not having a well-defined layout.
@@ -30182,21 +30190,43 @@ fn beginComptimePtrMutation(
3018230190
);
3018330191
},
3018430192
.@"union" => {
30185-
// We need to set the active field of the union.
30186-
const union_tag_ty = base_child_ty.unionTagTypeHypothetical(mod);
30187-
3018830193
const payload = &val_ptr.castTag(.@"union").?.data;
30189-
payload.tag = try mod.enumValueFieldIndex(union_tag_ty, field_index);
30194+
const layout = base_child_ty.containerLayout(mod);
3019030195

30191-
return beginComptimePtrMutationInner(
30192-
sema,
30193-
block,
30194-
src,
30195-
parent.ty.structFieldType(field_index, mod),
30196-
&payload.val,
30197-
ptr_elem_ty,
30198-
parent.mut_decl,
30199-
);
30196+
const tag_type = base_child_ty.unionTagTypeHypothetical(mod);
30197+
const hypothetical_tag = try mod.enumValueFieldIndex(tag_type, field_index);
30198+
if (layout == .Auto or (payload.tag != null and hypothetical_tag.eql(payload.tag.?, tag_type, mod))) {
30199+
// We need to set the active field of the union.
30200+
payload.tag = hypothetical_tag;
30201+
30202+
const field_ty = parent.ty.structFieldType(field_index, mod);
30203+
return beginComptimePtrMutationInner(
30204+
sema,
30205+
block,
30206+
src,
30207+
field_ty,
30208+
&payload.val,
30209+
ptr_elem_ty,
30210+
parent.mut_decl,
30211+
);
30212+
} else {
30213+
// Writing to a different field (a different or unknown tag is active) requires reinterpreting
30214+
// memory of the entire union, which requires knowing its abiSize.
30215+
try sema.resolveTypeLayout(parent.ty);
30216+
30217+
// This union value no longer has a well-defined tag type.
30218+
// The reinterpretation will read it back out as .none.
30219+
payload.val = try payload.val.unintern(sema.arena, mod);
30220+
return ComptimePtrMutationKit{
30221+
.mut_decl = parent.mut_decl,
30222+
.pointee = .{ .reinterpret = .{
30223+
.val_ptr = val_ptr,
30224+
.byte_offset = 0,
30225+
.write_packed = layout == .Packed,
30226+
} },
30227+
.ty = parent.ty,
30228+
};
30229+
}
3020030230
},
3020130231
.slice => switch (field_index) {
3020230232
Value.slice_ptr_index => return beginComptimePtrMutationInner(
@@ -30697,6 +30727,7 @@ fn bitCastVal(
3069730727
// For types with well-defined memory layouts, we serialize them a byte buffer,
3069830728
// then deserialize to the new type.
3069930729
const abi_size = try sema.usizeCast(block, src, old_ty.abiSize(mod));
30730+
3070030731
const buffer = try sema.gpa.alloc(u8, abi_size);
3070130732
defer sema.gpa.free(buffer);
3070230733
val.writeToMemory(old_ty, mod, buffer) catch |err| switch (err) {
@@ -30713,6 +30744,39 @@ fn bitCastVal(
3071330744
};
3071430745
}
3071530746

30747+
fn bitCastUnionFieldVal(
30748+
sema: *Sema,
30749+
block: *Block,
30750+
src: LazySrcLoc,
30751+
val: Value,
30752+
old_ty: Type,
30753+
field_ty: Type,
30754+
) !?Value {
30755+
const mod = sema.mod;
30756+
if (old_ty.eql(field_ty, mod)) return val;
30757+
30758+
const old_size = try sema.usizeCast(block, src, old_ty.abiSize(mod));
30759+
const field_size = try sema.usizeCast(block, src, field_ty.abiSize(mod));
30760+
30761+
const buffer = try sema.gpa.alloc(u8, @max(old_size, field_size));
30762+
defer sema.gpa.free(buffer);
30763+
val.writeToMemory(old_ty, mod, buffer) catch |err| switch (err) {
30764+
error.OutOfMemory => return error.OutOfMemory,
30765+
error.ReinterpretDeclRef => return null,
30766+
error.IllDefinedMemoryLayout => unreachable, // Sema was supposed to emit a compile error already
30767+
error.Unimplemented => return sema.fail(block, src, "TODO: implement writeToMemory for type '{}'", .{old_ty.fmt(mod)}),
30768+
};
30769+
30770+
// Reading a larger value means we need to reinterpret from undefined bytes
30771+
if (field_size > old_size) @memset(buffer[old_size..], 0xaa);
30772+
30773+
return Value.readFromMemory(field_ty, mod, buffer[0..], sema.arena) catch |err| switch (err) {
30774+
error.OutOfMemory => return error.OutOfMemory,
30775+
error.IllDefinedMemoryLayout => unreachable,
30776+
error.Unimplemented => return sema.fail(block, src, "TODO: implement readFromMemory for type '{}'", .{field_ty.fmt(mod)}),
30777+
};
30778+
}
30779+
3071630780
fn coerceArrayPtrToSlice(
3071730781
sema: *Sema,
3071830782
block: *Block,

src/TypedValue.zig

Lines changed: 17 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -84,22 +84,27 @@ pub fn print(
8484
if (level == 0) {
8585
return writer.writeAll(".{ ... }");
8686
}
87-
const union_val = val.castTag(.@"union").?.data;
87+
const payload = val.castTag(.@"union").?.data;
8888
try writer.writeAll(".{ ");
8989

90-
if (union_val.tag.toIntern() != .none) {
90+
if (payload.tag) |tag| {
9191
try print(.{
9292
.ty = ip.indexToKey(ty.toIntern()).union_type.enum_tag_ty.toType(),
93-
.val = union_val.tag,
93+
.val = tag,
9494
}, writer, level - 1, mod);
9595
try writer.writeAll(" = ");
96-
const field_ty = ty.unionFieldType(union_val.tag, mod).?;
96+
const field_ty = ty.unionFieldType(tag, mod).?;
9797
try print(.{
9898
.ty = field_ty,
99-
.val = union_val.val,
99+
.val = payload.val,
100100
}, writer, level - 1, mod);
101101
} else {
102-
return writer.writeAll("(unknown tag)");
102+
try writer.writeAll("(unknown tag) = ");
103+
const backing_ty = try ty.unionBackingType(mod);
104+
try print(.{
105+
.ty = backing_ty,
106+
.val = payload.val,
107+
}, writer, level - 1, mod);
103108
}
104109

105110
return writer.writeAll(" }");
@@ -421,7 +426,12 @@ pub fn print(
421426
.val = un.val.toValue(),
422427
}, writer, level - 1, mod);
423428
} else {
424-
try writer.writeAll("(unknown tag)");
429+
try writer.writeAll("(unknown tag) = ");
430+
const backing_ty = try ty.unionBackingType(mod);
431+
try print(.{
432+
.ty = backing_ty,
433+
.val = un.val.toValue(),
434+
}, writer, level - 1, mod);
425435
}
426436
} else try writer.writeAll("...");
427437
return writer.writeAll(" }");

src/type.zig

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1954,6 +1954,16 @@ pub const Type = struct {
19541954
return true;
19551955
}
19561956

1957+
/// Returns the type used for backing storage of this union during comptime operations.
1958+
/// Asserts the type is either an extern or packed union.
1959+
pub fn unionBackingType(ty: Type, mod: *Module) !Type {
1960+
return switch (ty.containerLayout(mod)) {
1961+
.Extern => try mod.arrayType(.{ .len = ty.abiSize(mod), .child = .u8_type }),
1962+
.Packed => try mod.intType(.unsigned, @intCast(ty.bitSize(mod))),
1963+
.Auto => unreachable,
1964+
};
1965+
}
1966+
19571967
pub fn unionGetLayout(ty: Type, mod: *Module) Module.UnionLayout {
19581968
const ip = &mod.intern_pool;
19591969
const union_type = ip.indexToKey(ty.toIntern()).union_type;

src/value.zig

Lines changed: 25 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -327,11 +327,19 @@ pub const Value = struct {
327327
},
328328
.@"union" => {
329329
const pl = val.castTag(.@"union").?.data;
330-
return mod.intern(.{ .un = .{
331-
.ty = ty.toIntern(),
332-
.tag = try pl.tag.intern(ty.unionTagTypeHypothetical(mod), mod),
333-
.val = try pl.val.intern(ty.unionFieldType(pl.tag, mod).?, mod),
334-
} });
330+
if (pl.tag) |pl_tag| {
331+
return mod.intern(.{ .un = .{
332+
.ty = ty.toIntern(),
333+
.tag = try pl_tag.intern(ty.unionTagTypeHypothetical(mod), mod),
334+
.val = try pl.val.intern(ty.unionFieldType(pl_tag, mod).?, mod),
335+
} });
336+
} else {
337+
return mod.intern(.{ .un = .{
338+
.ty = ty.toIntern(),
339+
.tag = .none,
340+
.val = try pl.val.intern(try ty.unionBackingType(mod), mod),
341+
} });
342+
}
335343
},
336344
}
337345
}
@@ -399,10 +407,7 @@ pub const Value = struct {
399407

400408
.un => |un| Tag.@"union".create(arena, .{
401409
// toValue asserts that the value cannot be .none which is valid on unions.
402-
.tag = .{
403-
.ip_index = un.tag,
404-
.legacy = undefined,
405-
},
410+
.tag = if (un.tag == .none) null else un.tag.toValue(),
406411
.val = un.val.toValue(),
407412
}),
408413

@@ -709,21 +714,22 @@ pub const Value = struct {
709714
.Union => switch (ty.containerLayout(mod)) {
710715
.Auto => return error.IllDefinedMemoryLayout, // Sema is supposed to have emitted a compile error already
711716
.Extern => {
712-
const union_obj = mod.typeToUnion(ty).?;
713717
if (val.unionTag(mod)) |union_tag| {
718+
const union_obj = mod.typeToUnion(ty).?;
714719
const field_index = mod.unionTagFieldIndex(union_obj, union_tag).?;
715720
const field_type = union_obj.field_types.get(&mod.intern_pool)[field_index].toType();
716721
const field_val = try val.fieldValue(mod, field_index);
717722
const byte_count = @as(usize, @intCast(field_type.abiSize(mod)));
718723
return writeToMemory(field_val, field_type, mod, buffer[0..byte_count]);
719724
} else {
720-
const union_size = ty.abiSize(mod);
721-
const array_type = try mod.arrayType(.{ .len = union_size, .child = .u8_type });
722-
return writeToMemory(val.unionValue(mod), array_type, mod, buffer[0..@as(usize, @intCast(union_size))]);
725+
const backing_ty = try ty.unionBackingType(mod);
726+
const byte_count: usize = @intCast(backing_ty.abiSize(mod));
727+
return writeToMemory(val.unionValue(mod), backing_ty, mod, buffer[0..byte_count]);
723728
}
724729
},
725730
.Packed => {
726-
const byte_count = (@as(usize, @intCast(ty.bitSize(mod))) + 7) / 8;
731+
const backing_ty = try ty.unionBackingType(mod);
732+
const byte_count: usize = @intCast(backing_ty.abiSize(mod));
727733
return writeToPackedMemory(val, ty, mod, buffer[0..byte_count], 0);
728734
},
729735
},
@@ -842,9 +848,8 @@ pub const Value = struct {
842848
const field_val = try val.fieldValue(mod, field_index);
843849
return field_val.writeToPackedMemory(field_type, mod, buffer, bit_offset);
844850
} else {
845-
const union_bits: u16 = @intCast(ty.bitSize(mod));
846-
const int_ty = try mod.intType(.unsigned, union_bits);
847-
return val.unionValue(mod).writeToPackedMemory(int_ty, mod, buffer, bit_offset);
851+
const backing_ty = try ty.unionBackingType(mod);
852+
return val.unionValue(mod).writeToPackedMemory(backing_ty, mod, buffer, bit_offset);
848853
}
849854
},
850855
}
@@ -1146,10 +1151,8 @@ pub const Value = struct {
11461151
.Union => switch (ty.containerLayout(mod)) {
11471152
.Auto, .Extern => unreachable, // Handled by non-packed readFromMemory
11481153
.Packed => {
1149-
const union_bits: u16 = @intCast(ty.bitSize(mod));
1150-
assert(union_bits != 0);
1151-
const int_ty = try mod.intType(.unsigned, union_bits);
1152-
const val = (try readFromPackedMemory(int_ty, mod, buffer, bit_offset, arena)).toIntern();
1154+
const backing_ty = try ty.unionBackingType(mod);
1155+
const val = (try readFromPackedMemory(backing_ty, mod, buffer, bit_offset, arena)).toIntern();
11531156
return (try mod.intern(.{ .un = .{
11541157
.ty = ty.toIntern(),
11551158
.tag = .none,
@@ -4017,7 +4020,7 @@ pub const Value = struct {
40174020
data: Data,
40184021

40194022
pub const Data = struct {
4020-
tag: Value,
4023+
tag: ?Value,
40214024
val: Value,
40224025
};
40234026
};

0 commit comments

Comments
 (0)