Skip to content
This repository has been archived by the owner on Jan 29, 2025. It is now read-only.

Add countTrailingZeros #2243

Merged
merged 16 commits into from
Feb 20, 2023
75 changes: 74 additions & 1 deletion src/back/glsl/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1149,7 +1149,8 @@ impl<'a, W: Write> Writer<'a, W> {
}
}
}
crate::MathFunction::CountLeadingZeros => {
crate::MathFunction::CountTrailingZeros
| crate::MathFunction::CountLeadingZeros => {
if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
self.need_bake_expressions.insert(arg);
}
Expand Down Expand Up @@ -2960,6 +2961,78 @@ impl<'a, W: Write> Writer<'a, W> {
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
// bits
Mf::CountTrailingZeros => {
if self.options.version.supports_integer_functions() {
match *ctx.info[arg].ty.inner_with(&self.module.types) {
crate::TypeInner::Vector { size, kind, .. } => {
let s = back::vector_size_str(size);

if let crate::ScalarKind::Uint = kind {
write!(self.out, "uvec{s}(findLSB(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ") - ivec{s}(1))")?;
} else {
write!(self.out, "mix(findLSB(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ") - ivec{s}(1), ivec{s}(0), lessThan(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ", ivec{s}(0)))")?;
}
}
crate::TypeInner::Scalar { kind, .. } => {
if let crate::ScalarKind::Uint = kind {
write!(self.out, "uint(findLSB(")?;
} else {
write!(self.out, "(")?;
self.write_expr(arg, ctx)?;
write!(self.out, " == 0 ? -1 : findLSB(")?;
}

self.write_expr(arg, ctx)?;
write!(self.out, ") - 1)")?;
}
_ => unreachable!(),
};
} else {
match *ctx.info[arg].ty.inner_with(&self.module.types) {
crate::TypeInner::Vector { size, kind, .. } => {
let s = back::vector_size_str(size);

if let crate::ScalarKind::Uint = kind {
write!(self.out, "uvec{s}(")?;
write!(self.out, "floor(log2(vec{s}(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ") + 0.5)) - vec{s}(1.0))")?;
} else {
write!(self.out, "ivec{s}(")?;
write!(self.out, "mix(floor(log2(vec{s}(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ") + 0.5)) - vec{s}(1.0), ")?;
write!(self.out, "vec{s}(0.0), lessThan(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ", ivec{s}(0u))))")?;
}
}
crate::TypeInner::Scalar { kind, .. } => {
if let crate::ScalarKind::Uint = kind {
write!(self.out, "uint(floor(log2(float(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ") + 0.5)) - 1.0)")?;
} else {
write!(self.out, "(")?;
self.write_expr(arg, ctx)?;
write!(self.out, " == 0 ? -1 : int(")?;
write!(self.out, "floor(log2(float(")?;
self.write_expr(arg, ctx)?;
write!(self.out, ") + 0.5))) - 1.0)")?;
}
}
_ => unreachable!(),
};
}

return Ok(());
}
Mf::CountLeadingZeros => {
if self.options.version.supports_integer_functions() {
match *ctx.info[arg].ty.inner_with(&self.module.types) {
Expand Down
44 changes: 43 additions & 1 deletion src/back/hlsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
| crate::MathFunction::Unpack2x16float => {
self.need_bake_expressions.insert(arg);
}
crate::MathFunction::CountLeadingZeros => {
crate::MathFunction::CountTrailingZeros
| crate::MathFunction::CountLeadingZeros => {
let inner = info[fun_handle].ty.inner_with(&module.types);
if let Some(crate::ScalarKind::Sint) = inner.scalar_kind() {
self.need_bake_expressions.insert(arg);
Expand Down Expand Up @@ -2551,6 +2552,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Unpack2x16float,
Regular(&'static str),
MissingIntOverload(&'static str),
CountTrailingZeros,
CountLeadingZeros,
}

Expand Down Expand Up @@ -2614,6 +2616,7 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
// bits
Mf::CountTrailingZeros => Function::CountTrailingZeros,
Mf::CountLeadingZeros => Function::CountLeadingZeros,
Mf::CountOneBits => Function::MissingIntOverload("countbits"),
Mf::ReverseBits => Function::MissingIntOverload("reversebits"),
Expand Down Expand Up @@ -2682,6 +2685,45 @@ impl<'a, W: fmt::Write> super::Writer<'a, W> {
write!(self.out, ")")?;
}
}
Function::CountTrailingZeros => {
match *func_ctx.info[arg].ty.inner_with(&module.types) {
TypeInner::Vector { size, kind, .. } => {
let s = match size {
crate::VectorSize::Bi => ".xx",
crate::VectorSize::Tri => ".xxx",
crate::VectorSize::Quad => ".xxxx",
};

if let ScalarKind::Uint = kind {
write!(self.out, "asuint(firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ") - (1){s})")?;
} else {
write!(self.out, "(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " == (0){s} ? (-1){s} : firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ") - (1){s})")?;
}
}
TypeInner::Scalar { kind, .. } => {
if let ScalarKind::Uint = kind {
write!(self.out, "asuint(firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ") - 1)")?;
} else {
write!(self.out, "(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, " == 0 ? -1 : firstbitlow(")?;
self.write_expr(module, arg, func_ctx)?;
write!(self.out, ") - 1)")?;
}
}
_ => unreachable!(),
}

return Ok(());
}
Function::CountLeadingZeros => {
match *func_ctx.info[arg].ty.inner_with(&module.types) {
TypeInner::Vector { size, kind, .. } => {
Expand Down
1 change: 1 addition & 0 deletions src/back/msl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1689,6 +1689,7 @@ impl<W: Write> Writer<W> {
Mf::Transpose => "transpose",
Mf::Determinant => "determinant",
// bits
Mf::CountTrailingZeros => "ctz",
Mf::CountLeadingZeros => "clz",
Mf::CountOneBits => "popcount",
Mf::ReverseBits => "reverse_bits",
Expand Down
65 changes: 65 additions & 0 deletions src/back/spv/block.rs
Original file line number Diff line number Diff line change
Expand Up @@ -888,6 +888,71 @@ impl<'w> BlockContext<'w> {
id,
arg0_id,
)),
Mf::CountTrailingZeros => {
let int = crate::ScalarValue::Sint(1);

let (int_type_id, int_id) = match *arg_ty {
crate::TypeInner::Vector { size, width, .. } => {
let ty = self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: Some(size),
kind: crate::ScalarKind::Sint,
width,
pointer_space: None,
}));

self.temp_list.clear();
self.temp_list
.resize(size as _, self.writer.get_constant_scalar(int, width));

let id = self.gen_id();
block.body.push(Instruction::constant_composite(
ty,
id,
&self.temp_list,
));

(ty, id)
}
crate::TypeInner::Scalar { width, .. } => (
self.get_type_id(LookupType::Local(LocalType::Value {
vector_size: None,
kind: crate::ScalarKind::Sint,
width,
pointer_space: None,
})),
self.writer.get_constant_scalar(int, width),
),
_ => unreachable!(),
};

block.body.push(Instruction::ext_inst(
self.writer.gl450_ext_inst_id,
spirv::GLOp::FindILsb,
int_type_id,
id,
&[arg0_id],
));

let sub_id = self.gen_id();
block.body.push(Instruction::binary(
spirv::Op::ISub,
int_type_id,
sub_id,
id,
int_id,
));

if let Some(crate::ScalarKind::Uint) = arg_scalar_kind {
block.body.push(Instruction::unary(
spirv::Op::Bitcast,
result_type_id,
self.gen_id(),
sub_id,
));
}

return Ok(());
}
Mf::CountLeadingZeros => {
let int = crate::ScalarValue::Sint(31);

Expand Down
1 change: 1 addition & 0 deletions src/back/wgsl/writer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1578,6 +1578,7 @@ impl<W: Write> Writer<W> {
Mf::Transpose => Function::Regular("transpose"),
Mf::Determinant => Function::Regular("determinant"),
// bits
Mf::CountTrailingZeros => Function::Regular("countTrailingZeros"),
Mf::CountLeadingZeros => Function::Regular("countLeadingZeros"),
Mf::CountOneBits => Function::Regular("countOneBits"),
Mf::ReverseBits => Function::Regular("reverseBits"),
Expand Down
1 change: 1 addition & 0 deletions src/front/wgsl/parse/conv.rs
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,7 @@ pub fn map_standard_fun(word: &str) -> Option<crate::MathFunction> {
"transpose" => Mf::Transpose,
"determinant" => Mf::Determinant,
// bits
"countTrailingZeros" => Mf::CountTrailingZeros,
"countLeadingZeros" => Mf::CountLeadingZeros,
"countOneBits" => Mf::CountOneBits,
"reverseBits" => Mf::ReverseBits,
Expand Down
1 change: 1 addition & 0 deletions src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1066,6 +1066,7 @@ pub enum MathFunction {
Transpose,
Determinant,
// bits
CountTrailingZeros,
CountLeadingZeros,
CountOneBits,
ReverseBits,
Expand Down
1 change: 1 addition & 0 deletions src/proc/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -279,6 +279,7 @@ impl super::MathFunction {
Self::Transpose => 1,
Self::Determinant => 1,
// bits
Self::CountTrailingZeros => 1,
Self::CountLeadingZeros => 1,
Self::CountOneBits => 1,
Self::ReverseBits => 1,
Expand Down
1 change: 1 addition & 0 deletions src/proc/typifier.rs
Original file line number Diff line number Diff line change
Expand Up @@ -793,6 +793,7 @@ impl<'a> ResolveContext<'a> {
)),
},
// bits
Mf::CountTrailingZeros |
Mf::CountLeadingZeros |
Mf::CountOneBits |
Mf::ReverseBits |
Expand Down
3 changes: 2 additions & 1 deletion src/valid/expression.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1223,7 +1223,8 @@ impl super::Validator {
_ => return Err(ExpressionError::InvalidArgumentType(fun, 0, arg)),
}
}
Mf::CountLeadingZeros
Mf::CountTrailingZeros
| Mf::CountLeadingZeros
| Mf::CountOneBits
| Mf::ReverseBits
| Mf::FindLsb
Expand Down
6 changes: 6 additions & 0 deletions tests/in/math-functions.wgsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,12 @@ fn main() {
let g = refract(v, v, f);
let const_dot = dot(vec2<i32>(), vec2<i32>());
let first_leading_bit_abs = firstLeadingBit(abs(0u));
let ctz_a = countTrailingZeros(-1);
let ctz_b = countTrailingZeros(1u);
let ctz_c = countTrailingZeros(vec2(-1));
let ctz_d = countTrailingZeros(vec2(1u));
let ctz_e = countTrailingZeros(0);
let ctz_f = countTrailingZeros(0u);
let clz_a = countLeadingZeros(-1);
let clz_b = countLeadingZeros(1u);
let clz_c = countLeadingZeros(vec2(-1));
Expand Down
11 changes: 9 additions & 2 deletions tests/out/glsl/math-functions.main.Vertex.glsl
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,17 @@ void main() {
vec4 g = refract(v, v, 1.0);
int const_dot = ( + ivec2(0, 0).x * ivec2(0, 0).x + ivec2(0, 0).y * ivec2(0, 0).y);
uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u)))));
int ctz_a = (-1 == 0 ? -1 : findLSB(-1) - 1);
uint ctz_b = uint(findLSB(1u) - 1);
ivec2 _e20 = ivec2(-1);
ivec2 ctz_c = mix(findLSB(_e20) - ivec2(1), ivec2(0), lessThan(_e20, ivec2(0)));
uvec2 ctz_d = uvec2(findLSB(uvec2(1u)) - ivec2(1));
int ctz_e = (0 == 0 ? -1 : findLSB(0) - 1);
uint ctz_f = uint(findLSB(0u) - 1);
int clz_a = (-1 < 0 ? 0 : 31 - findMSB(-1));
uint clz_b = uint(31 - findMSB(1u));
ivec2 _e20 = ivec2(-1);
ivec2 clz_c = mix(ivec2(31) - findMSB(_e20), ivec2(0), lessThan(_e20, ivec2(0)));
ivec2 _e34 = ivec2(-1);
ivec2 clz_c = mix(ivec2(31) - findMSB(_e34), ivec2(0), lessThan(_e34, ivec2(0)));
uvec2 clz_d = uvec2(ivec2(31) - findMSB(uvec2(1u)));
}

11 changes: 9 additions & 2 deletions tests/out/hlsl/math-functions.hlsl
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,16 @@ void main()
float4 g = refract(v, v, 1.0);
int const_dot = dot(int2(0, 0), int2(0, 0));
uint first_leading_bit_abs = firstbithigh(abs(0u));
int ctz_a = (-1 == 0 ? -1 : firstbitlow(-1) - 1);
uint ctz_b = asuint(firstbitlow(1u) - 1);
int2 _expr20 = (-1).xx;
int2 ctz_c = (_expr20 == (0).xx ? (-1).xx : firstbitlow(_expr20) - (1).xx);
uint2 ctz_d = asuint(firstbitlow((1u).xx) - (1).xx);
int ctz_e = (0 == 0 ? -1 : firstbitlow(0) - 1);
uint ctz_f = asuint(firstbitlow(0u) - 1);
int clz_a = (-1 < 0 ? 0 : 31 - firstbithigh(-1));
uint clz_b = asuint(31 - firstbithigh(1u));
int2 _expr20 = (-1).xx;
int2 clz_c = (_expr20 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr20));
int2 _expr34 = (-1).xx;
int2 clz_c = (_expr34 < (0).xx ? (0).xx : (31).xx - firstbithigh(_expr34));
uint2 clz_d = asuint((31).xx - firstbithigh((1u).xx));
}
6 changes: 6 additions & 0 deletions tests/out/msl/math-functions.msl
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,12 @@ vertex void main_(
int const_dot = ( + const_type_1_.x * const_type_1_.x + const_type_1_.y * const_type_1_.y);
uint _e13 = metal::abs(0u);
uint first_leading_bit_abs = metal::select(31 - metal::clz(_e13), uint(-1), _e13 == 0 || _e13 == -1);
int ctz_a = metal::ctz(-1);
uint ctz_b = metal::ctz(1u);
metal::int2 ctz_c = metal::ctz(metal::int2(-1));
metal::uint2 ctz_d = metal::ctz(metal::uint2(1u));
int ctz_e = metal::ctz(0);
uint ctz_f = metal::ctz(0u);
int clz_a = metal::clz(-1);
uint clz_b = metal::clz(1u);
metal::int2 clz_c = metal::clz(metal::int2(-1));
Expand Down
Loading