Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add countTrailingZeros #2243

Merged
merged 16 commits into from
Feb 20, 2023
29 changes: 29 additions & 0 deletions src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2996,6 +2996,35 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
// bits
Mf::CountTrailingZeros => {
match *ctx.info[arg].ty.inner_with(&self.module.types) {
crate::TypeInner::Vector { size, kind, .. } => {
let s = back::vector_size_str(size);
if let crate::ScalarKind::Uint = kind {
write!(self.out, "min(uvec{s}(findLSB(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ")), uvec{s}(32u))")?;
} else {
write!(self.out, "ivec{s}(min(uvec{s}(findLSB(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ")), uvec{s}(32u)))")?;
}
}
crate::TypeInner::Scalar { kind, .. } => {
if let crate::ScalarKind::Uint = kind {
write!(self.out, "min(uint(findLSB(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ")), 32u)")?;
} else {
write!(self.out, "int(min(uint(findLSB(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ")), 32u))")?;
}
}
_ => unreachable!(),
};
return Ok(());
}
Mf::CountLeadingZeros => {
if self.options.version.supports_integer_functions() {
match *ctx.info[arg].ty.inner_with(&self.module.types) {
Expand Down
37 changes: 37 additions & 0 deletions src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2551,6 +2551,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Unpack2x16float,
Regular(&'static str),
MissingIntOverload(&'static str),
CountTrailingZeros,
CountLeadingZeros,
}

Expand Down Expand Up @@ -2614,6 +2615,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
// bits
Mf::CountTrailingZeros => Function::CountTrailingZeros,
Mf::CountLeadingZeros => Function::CountLeadingZeros,
Mf::CountOneBits => Function::MissingIntOverload("countbits"),
Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
Expand Down Expand Up @@ -2682,6 +2684,41 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ")")?;
}
}
Function::CountTrailingZeros => {
match *func_ctx.info[arg].ty.inner_with(&module.types) {
TypeInner::Vector { size, kind, .. } => {
let s = match size {
crate::VectorSize::Bi => ".xx",
crate::VectorSize::Tri => ".xxx",
crate::VectorSize::Quad => ".xxxx",
};

if let ScalarKind::Uint = kind {
write!(self.out, "min((32u){s}, firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
write!(self.out, "asint(min((32u){s}, asuint(firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))))")?;
}
}
TypeInner::Scalar { kind, .. } => {
if let ScalarKind::Uint = kind {
write!(self.out, "min(32u, firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))")?;
} else {
write!(self.out, "asint(min(32u, asuint(firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, "))))")?;
}
}
_ => unreachable!(),
}

return Ok(());
}
Function::CountLeadingZeros => {
match *func_ctx.info[arg].ty.inner_with(&module.types) {
TypeInner::Vector { size, kind, .. } => {
Expand Down
1 change: 1 addition & 0 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1686,6 +1686,7 @@ impl<W: Write> Writer<W> {
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
// bits
Mf::CountTrailingZeros => "ctz",
Mf::CountLeadingZeros => "clz",
Mf::CountOneBits => "popcount",
Mf::ReverseBits => "reverse_bits",
Expand Down
43 changes: 43 additions & 0 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -874,6 +874,49 @@ impl<'w> BlockContext<'w> {
id,
arg0_id,
)),
Mf::CountTrailingZeros => {
let uint = crate::ScalarValue::Uint(32);
let uint_id = match *arg_ty {
crate::TypeInner::Vector { size, width, .. } => {
let ty = LocalType::Value {
vector_size: Some(size),
kind: crate::ScalarKind::Uint,
width,
pointer_space: None,
}
.into();

self.temp_list.clear();
self.temp_list.resize(
size as _,
self.writer.get_constant_scalar(uint, width),
);

self.writer.get_constant_composite(ty, &self.temp_list)
}
crate::TypeInner::Scalar { width, .. } => {
self.writer.get_constant_scalar(uint, width)
}
_ => unreachable!(),
};

let lsb_id = self.gen_id();
block.body.push(Instruction::ext_inst(
self.writer.gl450_ext_inst_id,
spirv::GLOp::FindILsb,
result_type_id,
lsb_id,
&[arg0_id],
));

MathOp::Custom(Instruction::ext_inst(
self.writer.gl450_ext_inst_id,
spirv::GLOp::UMin,
result_type_id,
id,
&[uint_id, lsb_id],
))
}
Mf::CountLeadingZeros => {
let int = crate::ScalarValue::Sint(31);

Expand Down
1 change: 1 addition & 0 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,7 @@ impl<W: Write> Writer<W> {
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
// bits
Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"),
Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"),
Mf::CountOneBits => Function::Regular("countOneBits"),
Mf::ReverseBits => Function::Regular("reverseBits"),
Expand Down
1 change: 1 addition & 0 deletions src/front/wgsl/parse/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
"transpose" => Mf::Transpose,
"determinant" => Mf::Determinant,
// bits
"countTrailingZeros" => Mf::CountTrailingZeros,
"countLeadingZeros" => Mf::CountLeadingZeros,
"countOneBits" => Mf::CountOneBits,
"reverseBits" => Mf::ReverseBits,
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,7 @@ pub enum MathFunction {
Transpose,
Determinant,
// bits
CountTrailingZeros,
CountLeadingZeros,
CountOneBits,
ReverseBits,
Expand Down
1 change: 1 addition & 0 deletions src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ impl super::MathFunction {
Self::Transpose => 1,
Self::Determinant => 1,
// bits
Self::CountTrailingZeros => 1,
Self::CountLeadingZeros => 1,
Self::CountOneBits => 1,
Self::ReverseBits => 1,
Expand Down
1 change: 1 addition & 0 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,7 @@ impl<'a> ResolveContext<'a> {
)),
},
// bits
Mf::CountTrailingZeros |
Mf::CountLeadingZeros |
Mf::CountOneBits |
Mf::ReverseBits |
Expand Down
3 changes: 2 additions & 1 deletion src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,8 @@ impl super::Validator {
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::CountLeadingZeros
Mf::CountTrailingZeros
| Mf::CountLeadingZeros
| Mf::CountOneBits
| Mf::ReverseBits
| Mf::FindLsb
Expand Down
8 changes: 8 additions & 0 deletions tests/in/math-functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ fn main() {
let g = refract(v, v, f);
let const_dot = dot(vec2<i32>(), vec2<i32>());
let first_leading_bit_abs = firstLeadingBit(abs(0u));
let ctz_a = countTrailingZeros(0u);
let ctz_b = countTrailingZeros(0);
let ctz_c = countTrailingZeros(0xFFFFFFFFu);
let ctz_d = countTrailingZeros(-1);
let ctz_e = countTrailingZeros(vec2(0u));
let ctz_f = countTrailingZeros(vec2(0));
let ctz_g = countTrailingZeros(vec2(1u));
let ctz_h = countTrailingZeros(vec2(1));
let clz_a = countLeadingZeros(-1);
let clz_b = countLeadingZeros(1u);
let clz_c = countLeadingZeros(vec2(-1));
Expand Down
12 changes: 10 additions & 2 deletions tests/out/glsl/math-functions.main.Vertex.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,18 @@ void main() {
vec4 g = refract(v, v, 1.0);
int const_dot = ( + ivec2(0, 0).x * ivec2(0, 0).x + ivec2(0, 0).y * ivec2(0, 0).y);
uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u)))));
uint ctz_a = min(uint(findLSB(0u)), 32u);
int ctz_b = int(min(uint(findLSB(0)), 32u));
uint ctz_c = min(uint(findLSB(4294967295u)), 32u);
int ctz_d = int(min(uint(findLSB(-1)), 32u));
uvec2 ctz_e = min(uvec2(findLSB(uvec2(0u))), uvec2(32u));
ivec2 ctz_f = ivec2(min(uvec2(findLSB(ivec2(0))), uvec2(32u)));
uvec2 ctz_g = min(uvec2(findLSB(uvec2(1u))), uvec2(32u));
ivec2 ctz_h = ivec2(min(uvec2(findLSB(ivec2(1))), uvec2(32u)));
int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1));
uint clz_b = uint(31 - findMSB(1u));
ivec2 _e20 = ivec2(-1);
ivec2 clz_c = mix(ivec2(31) - findMSB(_e20), ivec2(0), lessThan(_e20, ivec2(0)));
ivec2 _e40 = ivec2(-1);
ivec2 clz_c = mix(ivec2(31) - findMSB(_e40), ivec2(0), lessThan(_e40, ivec2(0)));
uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u)));
}

12 changes: 10 additions & 2 deletions tests/out/hlsl/math-functions.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,17 @@ void main()
float4 g = refract(v, v, 1.0);
int const_dot = dot(int2(0, 0), int2(0, 0));
uint first_leading_bit_abs = firstbithigh(abs(0u));
uint ctz_a = min(32u, firstbitlow(0u));
int ctz_b = asint(min(32u, asuint(firstbitlow(0))));
uint ctz_c = min(32u, firstbitlow(4294967295u));
int ctz_d = asint(min(32u, asuint(firstbitlow(-1))));
uint2 ctz_e = min((32u).xx, firstbitlow((0u).xx));
int2 ctz_f = asint(min((32u).xx, asuint(firstbitlow((0).xx))));
uint2 ctz_g = min((32u).xx, firstbitlow((1u).xx));
int2 ctz_h = asint(min((32u).xx, asuint(firstbitlow((1).xx))));
int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1));
uint clz_b = asuint(31 - firstbithigh(1u));
int2 _expr20 = (-1).xx;
int2 clz_c = (_expr20 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr20));
int2 _expr40 = (-1).xx;
int2 clz_c = (_expr40 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr40));
uint2 clz_d = asuint((31).xx - firstbithigh((1u).xx));
}
8 changes: 8 additions & 0 deletions tests/out/msl/math-functions.msl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,14 @@ vertex void main_(
int const_dot = ( + const_type_1_.x * const_type_1_.x + const_type_1_.y * const_type_1_.y);
uint _e13 = metal::abs(0u);
uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1);
uint ctz_a = metal::ctz(0u);
int ctz_b = metal::ctz(0);
uint ctz_c = metal::ctz(4294967295u);
int ctz_d = metal::ctz(-1);
metal::uint2 ctz_e = metal::ctz(metal::uint2(0u));
metal::int2 ctz_f = metal::ctz(metal::int2(0));
metal::uint2 ctz_g = metal::ctz(metal::uint2(1u));
metal::int2 ctz_h = metal::ctz(metal::int2(1));
int clz_a = metal::clz(-1);
uint clz_b = metal::clz(1u);
metal::int2 clz_c = metal::clz(metal::int2(-1));
Expand Down
Loading