Skip to content

Commit

Permalink
spirv: make constant handle float, errorset, errorunion
Browse files Browse the repository at this point in the history
This is in preparation of removing indirect lowering again. Also
modifies constant() to accept a repr so that both direct as well
as indirect representations can be generated. Indirect is not yet
used, but will be used for globals.
  • Loading branch information
Snektron committed May 20, 2023
1 parent 65157d3 commit c92cc57
Show file tree
Hide file tree
Showing 2 changed files with 76 additions and 13 deletions.
79 changes: 66 additions & 13 deletions src/codegen/spirv.zig
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ pub const DeclGen = struct {
return self.spv.declPtr(spv_decl_index).result_id;
}

return try self.constant(ty, val);
return try self.constant(ty, val, .direct);
}
const index = Air.refToIndex(inst).?;
return self.inst_results.get(index).?; // Assertion means instruction does not dominate usage.
Expand Down Expand Up @@ -1021,14 +1021,16 @@ pub const DeclGen = struct {
/// the constant is more complicated however, it needs to be lowered to an indirect constant, which
/// is then loaded using OpLoad. Such values are loaded into the UniformConstant storage class by default.
/// This function should only be called during function code generation.
fn constant(self: *DeclGen, ty: Type, val: Value) !IdRef {
fn constant(self: *DeclGen, ty: Type, val: Value, repr: Repr) !IdRef {
const target = self.getTarget();
const section = &self.spv.sections.types_globals_constants;
const result_ty_ref = try self.resolveType(ty, .direct);
const result_ty_ref = try self.resolveType(ty, repr);
const result_ty_id = self.typeId(result_ty_ref);
const result_id = self.spv.allocId();

log.debug("constant: ty = {}, val = {}", .{ ty.fmt(self.module), val.fmtValue(ty, self.module) });

if (val.isUndef()) {
const result_id = self.spv.allocId();
try section.emit(self.spv.gpa, .OpUndef, .{
.id_result_type = result_ty_id,
.id_result = result_id,
Expand All @@ -1039,24 +1041,76 @@ pub const DeclGen = struct {
switch (ty.zigTypeTag()) {
.Int => {
if (ty.isSignedInt()) {
try self.genConstInt(result_ty_ref, result_id, val.toSignedInt(target));
return try self.constInt(result_ty_ref, val.toSignedInt(target));
} else {
try self.genConstInt(result_ty_ref, result_id, val.toUnsignedInt(target));
return try self.constInt(result_ty_ref, val.toUnsignedInt(target));
}
},
.Bool => switch (repr) {
.direct => {
const result_id = self.spv.allocId();
const operands = .{ .id_result_type = result_ty_id, .id_result = result_id };
if (val.toBool()) {
try section.emit(self.spv.gpa, .OpConstantTrue, operands);
} else {
try section.emit(self.spv.gpa, .OpConstantFalse, operands);
}
return result_id;
},
.indirect => return try self.constInt(result_ty_ref, @boolToInt(val.toBool())),
},
.Float => {
const result_id = self.spv.allocId();
switch (ty.floatBits(target)) {
16 => try self.spv.emitConstant(result_ty_id, result_id, .{ .float32 = val.toFloat(f16) }),
32 => try self.spv.emitConstant(result_ty_id, result_id, .{ .float32 = val.toFloat(f32) }),
64 => try self.spv.emitConstant(result_ty_id, result_id, .{ .float64 = val.toFloat(f64) }),
80, 128 => unreachable, // TODO
else => unreachable,
}
return result_id;
},
.Bool => {
const operands = .{ .id_result_type = result_ty_id, .id_result = result_id };
if (val.toBool()) {
try section.emit(self.spv.gpa, .OpConstantTrue, operands);
.ErrorSet => {
const value = switch (val.tag()) {
.@"error" => blk: {
const err_name = val.castTag(.@"error").?.data.name;
const kv = try self.module.getErrorValue(err_name);
break :blk @intCast(u16, kv.value);
},
.zero => 0,
else => unreachable,
};

return try self.constInt(result_ty_ref, value);
},
.ErrorUnion => {
const payload_ty = ty.errorUnionPayload();
const is_pl = val.errorUnionIsPayload();
const error_val = if (!is_pl) val else Value.initTag(.zero);

const eu_layout = self.errorUnionLayout(payload_ty);
if (!eu_layout.payload_has_bits) {
return try self.constant(Type.anyerror, error_val, repr);
}

const payload_val = if (val.castTag(.eu_payload)) |pl| pl.data else Value.initTag(.undef);

var members: [2]IdRef = undefined;
if (eu_layout.error_first) {
members[0] = try self.constant(Type.anyerror, error_val, .indirect);
members[1] = try self.constant(payload_ty, payload_val, .indirect);
} else {
try section.emit(self.spv.gpa, .OpConstantFalse, operands);
members[0] = try self.constant(payload_ty, payload_val, .indirect);
members[1] = try self.constant(Type.anyerror, error_val, .indirect);
}
return try self.spv.constComposite(result_ty_ref, &members);
},
// TODO: We can handle most pointers here (decl refs etc), because now they emit an extra
// OpVariable that is not really required.
else => {
// The value cannot be generated directly, so generate it as an indirect constant,
// and then perform an OpLoad.
const result_id = self.spv.allocId();
const alignment = ty.abiAlignment(target);
const spv_decl_index = try self.spv.allocDecl(.global);

Expand All @@ -1078,10 +1132,9 @@ pub const DeclGen = struct {
});
// TODO: Convert bools? This logic should hook into `load`. It should be a dead
// path though considering .Bool is handled above.
return result_id;
},
}

return result_id;
}

/// Turn a Zig type into a SPIR-V Type, and return its type result-id.
Expand Down
10 changes: 10 additions & 0 deletions src/codegen/spirv/Module.zig
Original file line number Diff line number Diff line change
Expand Up @@ -774,6 +774,16 @@ pub fn changePtrStorageClass(self: *Module, ptr_ty_ref: Type.Ref, new_storage_cl
return try self.resolveType(Type.initPayload(&payload.base));
}

pub fn constComposite(self: *Module, ty_ref: Type.Ref, members: []const IdRef) !IdRef {
const result_id = self.allocId();
try self.sections.types_globals_constants.emit(self.gpa, .OpSpecConstantComposite, .{
.id_result_type = self.typeId(ty_ref),
.id_result = result_id,
.constituents = members,
});
return result_id;
}

pub fn emitConstant(
self: *Module,
ty_id: IdRef,
Expand Down

0 comments on commit c92cc57

Please sign in to comment.