Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Introduce Expression::Literal. #2333

Merged
merged 4 commits into from
May 11, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions src/back/dot/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,7 @@ fn write_function_expressions(
for (handle, expression) in fun.expressions.iter() {
use crate::Expression as E;
let (label, color_id) = match *expression {
E::Literal(_) => ("Literal".into(), 2),
E::Constant(_) => ("Constant".into(), 2),
E::ZeroValue(_) => ("ZeroValue".into(), 2),
E::Access { base, index } => {
Expand Down
15 changes: 15 additions & 0 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2270,6 +2270,21 @@ impl<'a, W: Write> Writer<'a, W> {
Expression::ZeroValue(ty) => {
self.write_zero_init_value(ty)?;
}
Expression::Literal(literal) => {
match literal {
// Floats are written using `Debug` instead of `Display` because it always appends the
// decimal part even it's zero which is needed for a valid glsl float constant
crate::Literal::F64(value) => write!(self.out, "{:?}LF", value)?,
crate::Literal::F32(value) => write!(self.out, "{:?}", value)?,
// Unsigned integers need a `u` at the end
//
// While `core` doesn't necessarily need it, it's allowed and since `es` needs it we
// always write it as the extra branch wouldn't have any benefit in readability
crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
crate::Literal::I32(value) => write!(self.out, "{}", value)?,
crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
}
}
// `Splat` needs to actually write down a vector, it's not always inferred in GLSL.
Expression::Splat { size: _, value } => {
let resolved = ctx.info[expr].ty.inner_with(&self.module.types);
Expand Down
9 changes: 9 additions & 0 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2058,6 +2058,15 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
match *expression {
Expression::Constant(constant) => self.write_constant(module, constant)?,
Expression::ZeroValue(ty) => self.write_default_init(module, ty)?,
Expression::Literal(literal) => match literal {
// Floats are written using `Debug` instead of `Display` because it always appends the
// decimal part even it's zero
crate::Literal::F64(value) => write!(self.out, "{value:?}L")?,
crate::Literal::F32(value) => write!(self.out, "{value:?}")?,
crate::Literal::U32(value) => write!(self.out, "{}u", value)?,
crate::Literal::I32(value) => write!(self.out, "{}", value)?,
crate::Literal::Bool(value) => write!(self.out, "{}", value)?,
},
Expression::Compose { ty, ref components } => {
match module.types[ty].inner {
TypeInner::Struct { .. } | TypeInner::Array { .. } => {
Expand Down
25 changes: 25 additions & 0 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1344,6 +1344,31 @@ impl<W: Write> Writer<W> {
};
write!(self.out, "{ty_name} {{}}")?;
}
crate::Expression::Literal(literal) => match literal {
crate::Literal::F64(_) => {
return Err(Error::CapabilityNotSupported(valid::Capabilities::FLOAT64))
}
crate::Literal::F32(value) => {
if value.is_infinite() {
let sign = if value.is_sign_negative() { "-" } else { "" };
write!(self.out, "{sign}INFINITY")?;
} else if value.is_nan() {
write!(self.out, "NAN")?;
} else {
let suffix = if value.fract() == 0.0 { ".0" } else { "" };
write!(self.out, "{value}{suffix}")?;
}
}
crate::Literal::U32(value) => {
write!(self.out, "{value}u")?;
}
crate::Literal::I32(value) => {
write!(self.out, "{value}")?;
}
crate::Literal::Bool(value) => {
write!(self.out, "{value}")?;
}
},
crate::Expression::Splat { size, value } => {
let scalar_kind = match *context.resolve_type(value) {
crate::TypeInner::Scalar { kind, .. } => kind,
Expand Down
84 changes: 36 additions & 48 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -118,8 +118,8 @@ impl Writer {
width: 4,
pointer_space: None,
}));
let value0_id = self.get_constant_scalar(crate::ScalarValue::Float(0.0), 4);
let value1_id = self.get_constant_scalar(crate::ScalarValue::Float(1.0), 4);
let zero_scalar_id = self.get_constant_scalar(crate::Literal::F32(0.0));
let one_scalar_id = self.get_constant_scalar(crate::Literal::F32(1.0));

let original_id = self.id_gen.next();
body.push(Instruction::load(
Expand All @@ -135,7 +135,7 @@ impl Writer {
spirv::GLOp::FClamp,
float_type_id,
clamp_id,
&[original_id, value0_id, value1_id],
&[original_id, zero_scalar_id, one_scalar_id],
));

body.push(Instruction::store(frag_depth_id, clamp_id, None));
Expand Down Expand Up @@ -359,6 +359,7 @@ impl<'w> BlockContext<'w> {
}
crate::Expression::Constant(handle) => self.writer.constant_ids[handle.index()],
crate::Expression::ZeroValue(_) => self.writer.write_constant_null(result_type_id),
crate::Expression::Literal(literal) => self.writer.get_constant_scalar(literal),
crate::Expression::Splat { size, value } => {
let value_id = self.cached[value];
let components = [value_id; 4];
Expand Down Expand Up @@ -705,18 +706,14 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::Scalar { width, .. } => (None, width),
ref other => unimplemented!("Unexpected saturate({:?})", other),
};

let mut arg1_id = self
.writer
.get_constant_scalar(crate::ScalarValue::Float(0.0), width);
let mut arg2_id = self
.writer
.get_constant_scalar(crate::ScalarValue::Float(1.0), width);
let kind = crate::ScalarKind::Float;
let mut arg1_id = self.writer.get_constant_scalar_with(0, kind, width)?;
let mut arg2_id = self.writer.get_constant_scalar_with(1, kind, width)?;

if let Some(size) = maybe_size {
let ty = LocalType::Value {
vector_size: Some(size),
kind: crate::ScalarKind::Float,
kind,
width,
pointer_space: None,
}
Expand Down Expand Up @@ -878,12 +875,13 @@ impl<'w> BlockContext<'w> {
arg0_id,
)),
Mf::CountTrailingZeros => {
let uint = crate::ScalarValue::Uint(32);
let kind = crate::ScalarKind::Uint;

let uint_id = match *arg_ty {
crate::TypeInner::Vector { size, width, .. } => {
let ty = LocalType::Value {
vector_size: Some(size),
kind: crate::ScalarKind::Uint,
kind,
width,
pointer_space: None,
}
Expand All @@ -892,13 +890,13 @@ impl<'w> BlockContext<'w> {
self.temp_list.clear();
self.temp_list.resize(
size as _,
self.writer.get_constant_scalar(uint, width),
self.writer.get_constant_scalar_with(32, kind, width)?,
);

self.writer.get_constant_composite(ty, &self.temp_list)
}
crate::TypeInner::Scalar { width, .. } => {
self.writer.get_constant_scalar(uint, width)
self.writer.get_constant_scalar_with(32, kind, width)?
}
_ => unreachable!(),
};
Expand All @@ -921,21 +919,23 @@ impl<'w> BlockContext<'w> {
))
}
Mf::CountLeadingZeros => {
let int = crate::ScalarValue::Sint(31);
let kind = crate::ScalarKind::Sint;

let (int_type_id, int_id) = match *arg_ty {
crate::TypeInner::Vector { size, width, .. } => {
let ty = LocalType::Value {
vector_size: Some(size),
kind: crate::ScalarKind::Sint,
kind,
width,
pointer_space: None,
}
.into();

self.temp_list.clear();
self.temp_list
.resize(size as _, self.writer.get_constant_scalar(int, width));
self.temp_list.resize(
size as _,
self.writer.get_constant_scalar_with(31, kind, width)?,
);

(
self.get_type_id(ty),
Expand All @@ -945,11 +945,11 @@ impl<'w> BlockContext<'w> {
crate::TypeInner::Scalar { width, .. } => (
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
kind: crate::ScalarKind::Sint,
kind,
width,
pointer_space: None,
})),
self.writer.get_constant_scalar(int, width),
self.writer.get_constant_scalar_with(31, kind, width)?,
),
_ => unreachable!(),
};
Expand Down Expand Up @@ -1134,15 +1134,14 @@ impl<'w> BlockContext<'w> {
(_, _, None) => Cast::Unary(spirv::Op::Bitcast),
// casting to a bool - generate `OpXxxNotEqual`
(_, Sk::Bool, Some(_)) => {
let (op, value) = match src_kind {
Sk::Sint => (spirv::Op::INotEqual, crate::ScalarValue::Sint(0)),
Sk::Uint => (spirv::Op::INotEqual, crate::ScalarValue::Uint(0)),
Sk::Float => {
(spirv::Op::FUnordNotEqual, crate::ScalarValue::Float(0.0))
}
let op = match src_kind {
Sk::Sint | Sk::Uint => spirv::Op::INotEqual,
Sk::Float => spirv::Op::FUnordNotEqual,
Sk::Bool => unreachable!(),
};
let zero_scalar_id = self.writer.get_constant_scalar(value, src_width);
let zero_scalar_id = self
.writer
.get_constant_scalar_with(0, src_kind, src_width)?;
let zero_id = match src_size {
Some(size) => {
let ty = LocalType::Value {
Expand All @@ -1165,21 +1164,10 @@ impl<'w> BlockContext<'w> {
}
// casting from a bool - generate `OpSelect`
(Sk::Bool, _, Some(dst_width)) => {
let (val0, val1) = match kind {
Sk::Sint => {
(crate::ScalarValue::Sint(0), crate::ScalarValue::Sint(1))
}
Sk::Uint => {
(crate::ScalarValue::Uint(0), crate::ScalarValue::Uint(1))
}
Sk::Float => (
crate::ScalarValue::Float(0.0),
crate::ScalarValue::Float(1.0),
),
Sk::Bool => unreachable!(),
};
let scalar0_id = self.writer.get_constant_scalar(val0, dst_width);
let scalar1_id = self.writer.get_constant_scalar(val1, dst_width);
let zero_scalar_id =
self.writer.get_constant_scalar_with(0, kind, dst_width)?;
let one_scalar_id =
self.writer.get_constant_scalar_with(1, kind, dst_width)?;
let (accept_id, reject_id) = match src_size {
Some(size) => {
let ty = LocalType::Value {
Expand All @@ -1191,19 +1179,19 @@ impl<'w> BlockContext<'w> {
.into();

self.temp_list.clear();
self.temp_list.resize(size as _, scalar0_id);
self.temp_list.resize(size as _, zero_scalar_id);

let vec0_id =
self.writer.get_constant_composite(ty, &self.temp_list);

self.temp_list.fill(scalar1_id);
self.temp_list.fill(one_scalar_id);

let vec1_id =
self.writer.get_constant_composite(ty, &self.temp_list);

(vec1_id, vec0_id)
}
None => (scalar1_id, scalar0_id),
None => (one_scalar_id, zero_scalar_id),
};

Cast::Ternary(spirv::Op::Select, accept_id, reject_id)
Expand Down Expand Up @@ -1460,8 +1448,8 @@ impl<'w> BlockContext<'w> {
BoundsCheckResult::KnownInBounds(known_index) => {
// Even if the index is known, `OpAccessIndex`
// requires expression operands, not literals.
let scalar = crate::ScalarValue::Uint(known_index as u64);
self.writer.get_constant_scalar(scalar, 4)
let scalar = crate::Literal::U32(known_index);
self.writer.get_constant_scalar(scalar)
}
BoundsCheckResult::Computed(computed_index_id) => computed_index_id,
BoundsCheckResult::Conditional(comparison_id) => {
Expand Down
4 changes: 1 addition & 3 deletions src/back/spv/image.rs
Original file line number Diff line number Diff line change
Expand Up @@ -901,9 +901,7 @@ impl<'w> BlockContext<'w> {
depth_id,
);

let zero_id = self
.writer
.get_constant_scalar(crate::ScalarValue::Float(0.0), 4);
let zero_id = self.writer.get_constant_scalar(crate::Literal::F32(0.0));

mask |= spirv::ImageOperands::LOD;
inst.add_operand(mask.bits());
Expand Down
8 changes: 8 additions & 0 deletions src/back/spv/instructions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -343,6 +343,14 @@ impl super::Instruction {
instruction
}

pub(super) fn constant_32bit(result_type_id: Word, id: Word, value: Word) -> Self {
Self::constant(result_type_id, id, &[value])
}

pub(super) fn constant_64bit(result_type_id: Word, id: Word, low: Word, high: Word) -> Self {
Self::constant(result_type_id, id, &[low, high])
}

pub(super) fn constant(result_type_id: Word, id: Word, values: &[Word]) -> Self {
let mut instruction = Self::new(Op::Constant);
instruction.set_type(result_type_id);
Expand Down
14 changes: 5 additions & 9 deletions src/back/spv/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -295,12 +295,12 @@ enum LocalType {
/// [`BindingArray`]: crate::TypeInner::BindingArray
PointerToBindingArray {
base: Handle<crate::Type>,
size: u64,
size: u32,
space: crate::AddressSpace,
},
BindingArray {
base: Handle<crate::Type>,
size: u64,
size: u32,
},
AccelerationStructure,
RayQuery,
Expand Down Expand Up @@ -454,10 +454,7 @@ impl recyclable::Recyclable for CachedExpressions {

#[derive(Eq, Hash, PartialEq)]
enum CachedConstant {
Scalar {
value: crate::ScalarValue,
width: crate::Bytes,
},
Literal(crate::Literal),
Composite {
ty: LookupType,
constituent_ids: Vec<Word>,
Expand Down Expand Up @@ -568,13 +565,12 @@ impl BlockContext<'_> {
}

fn get_index_constant(&mut self, index: Word) -> Word {
self.writer
.get_constant_scalar(crate::ScalarValue::Uint(index as _), 4)
self.writer.get_constant_scalar(crate::Literal::U32(index))
}

fn get_scope_constant(&mut self, scope: Word) -> Word {
self.writer
.get_constant_scalar(crate::ScalarValue::Sint(scope as _), 4)
.get_constant_scalar(crate::Literal::I32(scope as _))
}
}

Expand Down
9 changes: 3 additions & 6 deletions src/back/spv/ray.rs
Original file line number Diff line number Diff line change
Expand Up @@ -117,12 +117,9 @@ impl<'w> BlockContext<'w> {
) -> spirv::Word {
let width = 4;
let query_id = self.cached[query];
let intersection_id = self.writer.get_constant_scalar(
crate::ScalarValue::Uint(
spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
),
width,
);
let intersection_id = self.writer.get_constant_scalar(crate::Literal::U32(
spirv::RayQueryIntersection::RayQueryCommittedIntersectionKHR as _,
));

let flag_type_id = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
Expand Down
Loading