diff --git a/src/back/glsl/mod.rs b/src/back/glsl/mod.rs index cba3b3ca0a..18ba9f29b8 100644 --- a/src/back/glsl/mod.rs +++ b/src/back/glsl/mod.rs @@ -2808,32 +2808,59 @@ impl<'a, W: Write> Writer<'a, W> { let extract_bits = fun == Mf::ExtractBits; let insert_bits = fun == Mf::InsertBits; - // we might need to cast to unsigned integers since - // GLSL's findLSB / findMSB always return signed integers - let need_extra_paren = { - (fun == Mf::FindLsb || fun == Mf::FindMsb || fun == Mf::CountOneBits) - && match *ctx.info[arg].ty.inner_with(&self.module.types) { - crate::TypeInner::Scalar { - kind: crate::ScalarKind::Uint, - .. - } => { - write!(self.out, "uint(")?; - true - } - crate::TypeInner::Vector { - kind: crate::ScalarKind::Uint, - size, - .. - } => { - write!(self.out, "uvec{}(", size as u8)?; - true - } - _ => false, - } + // Some GLSL functions always return signed integers (like findMSB), + // so they need to be cast to uint if the argument is also an uint. + let ret_might_need_int_to_uint = + matches!(fun, Mf::FindLsb | Mf::FindMsb | Mf::CountOneBits | Mf::Abs); + + // Some GLSL functions only accept signed integers (like abs), + // so they need their argument cast from uint to int. + let arg_might_need_uint_to_int = matches!(fun, Mf::Abs); + + // Check if the argument is an unsigned integer and return the vector size + // in case it's a vector + let maybe_uint_size = match *ctx.info[arg].ty.inner_with(&self.module.types) { + crate::TypeInner::Scalar { + kind: crate::ScalarKind::Uint, + .. + } => Some(None), + crate::TypeInner::Vector { + kind: crate::ScalarKind::Uint, + size, + .. + } => Some(Some(size)), + _ => None, }; + // Cast to uint if the function needs it + if ret_might_need_int_to_uint { + if let Some(maybe_size) = maybe_uint_size { + match maybe_size { + Some(size) => write!(self.out, "uvec{}(", size as u8)?, + None => write!(self.out, "uint(")?, + } + } + } + write!(self.out, "{}(", fun_name)?; + + // Cast to int if the function needs it + if arg_might_need_uint_to_int { + if let Some(maybe_size) = maybe_uint_size { + match maybe_size { + Some(size) => write!(self.out, "ivec{}(", size as u8)?, + None => write!(self.out, "int(")?, + } + } + } + self.write_expr(arg, ctx)?; + + // Close the cast from uint to int + if arg_might_need_uint_to_int && maybe_uint_size.is_some() { + write!(self.out, ")")? + } + if let Some(arg) = arg1 { write!(self.out, ", ")?; if extract_bits { @@ -2866,7 +2893,8 @@ impl<'a, W: Write> Writer<'a, W> { } write!(self.out, ")")?; - if need_extra_paren { + // Close the cast from int to uint + if ret_might_need_int_to_uint && maybe_uint_size.is_some() { write!(self.out, ")")? } } diff --git a/tests/in/math-functions.wgsl b/tests/in/math-functions.wgsl index 38e9acee94..a2e097a04d 100644 --- a/tests/in/math-functions.wgsl +++ b/tests/in/math-functions.wgsl @@ -7,4 +7,5 @@ fn main() { let c = degrees(v); let d = radians(v); let const_dot = dot(vec2(), vec2()); + let first_leading_bit_abs = firstLeadingBit(abs(0u)); } diff --git a/tests/out/glsl/math-functions.main.Vertex.glsl b/tests/out/glsl/math-functions.main.Vertex.glsl index 02a94ea69d..8d72868299 100644 --- a/tests/out/glsl/math-functions.main.Vertex.glsl +++ b/tests/out/glsl/math-functions.main.Vertex.glsl @@ -11,5 +11,6 @@ void main() { vec4 c = degrees(v); vec4 d = radians(v); int const_dot = ( + ivec2(0, 0).x * ivec2(0, 0).x + ivec2(0, 0).y * ivec2(0, 0).y); + uint first_leading_bit_abs = uint(findMSB(uint(abs(int(0u))))); } diff --git a/tests/out/hlsl/math-functions.hlsl b/tests/out/hlsl/math-functions.hlsl index b17a32cf41..179cfee66c 100644 --- a/tests/out/hlsl/math-functions.hlsl +++ b/tests/out/hlsl/math-functions.hlsl @@ -7,4 +7,5 @@ void main() float4 c = degrees(v); float4 d = radians(v); int const_dot = dot(int2(0, 0), int2(0, 0)); + uint first_leading_bit_abs = firstbithigh(abs(0u)); } diff --git a/tests/out/msl/math-functions.msl b/tests/out/msl/math-functions.msl index e08894037e..87de8418ce 100644 --- a/tests/out/msl/math-functions.msl +++ b/tests/out/msl/math-functions.msl @@ -14,4 +14,5 @@ vertex void main_( metal::float4 c = ((v) * 57.295779513082322865); metal::float4 d = ((v) * 0.017453292519943295474); int const_dot = ( + const_type.x * const_type.x + const_type.y * const_type.y); + uint first_leading_bit_abs = (((metal::clz(metal::abs(0u)) + 1) % 33) - 1); } diff --git a/tests/out/spv/math-functions.spvasm b/tests/out/spv/math-functions.spvasm index 03b4a04234..07a16f754c 100644 --- a/tests/out/spv/math-functions.spvasm +++ b/tests/out/spv/math-functions.spvasm @@ -1,38 +1,42 @@ ; SPIR-V ; Version: 1.1 ; Generator: rspirv -; Bound: 29 +; Bound: 33 OpCapability Shader %1 = OpExtInstImport "GLSL.std.450" OpMemoryModel Logical GLSL450 -OpEntryPoint Vertex %11 "main" +OpEntryPoint Vertex %13 "main" %2 = OpTypeVoid %4 = OpTypeFloat 32 %3 = OpConstant %4 1.0 %5 = OpConstant %4 0.0 %7 = OpTypeInt 32 1 %6 = OpConstant %7 0 -%8 = OpTypeVector %7 2 -%9 = OpConstantComposite %8 %6 %6 -%12 = OpTypeFunction %2 -%14 = OpTypeVector %4 4 -%21 = OpConstantNull %7 -%11 = OpFunction %2 None %12 -%10 = OpLabel -OpBranch %13 -%13 = OpLabel -%15 = OpCompositeConstruct %14 %5 %5 %5 %5 -%16 = OpExtInst %4 %1 Degrees %3 -%17 = OpExtInst %4 %1 Radians %3 -%18 = OpExtInst %14 %1 Degrees %15 -%19 = OpExtInst %14 %1 Radians %15 -%22 = OpCompositeExtract %7 %9 0 -%23 = OpCompositeExtract %7 %9 0 -%24 = OpIMul %7 %22 %23 -%25 = OpIAdd %7 %21 %24 -%26 = OpCompositeExtract %7 %9 1 -%27 = OpCompositeExtract %7 %9 1 -%28 = OpIMul %7 %26 %27 -%20 = OpIAdd %7 %25 %28 +%9 = OpTypeInt 32 0 +%8 = OpConstant %9 0 +%10 = OpTypeVector %7 2 +%11 = OpConstantComposite %10 %6 %6 +%14 = OpTypeFunction %2 +%16 = OpTypeVector %4 4 +%23 = OpConstantNull %7 +%13 = OpFunction %2 None %14 +%12 = OpLabel +OpBranch %15 +%15 = OpLabel +%17 = OpCompositeConstruct %16 %5 %5 %5 %5 +%18 = OpExtInst %4 %1 Degrees %3 +%19 = OpExtInst %4 %1 Radians %3 +%20 = OpExtInst %16 %1 Degrees %17 +%21 = OpExtInst %16 %1 Radians %17 +%24 = OpCompositeExtract %7 %11 0 +%25 = OpCompositeExtract %7 %11 0 +%26 = OpIMul %7 %24 %25 +%27 = OpIAdd %7 %23 %26 +%28 = OpCompositeExtract %7 %11 1 +%29 = OpCompositeExtract %7 %11 1 +%30 = OpIMul %7 %28 %29 +%22 = OpIAdd %7 %27 %30 +%31 = OpCopyObject %9 %8 +%32 = OpExtInst %9 %1 FindUMsb %31 OpReturn OpFunctionEnd \ No newline at end of file diff --git a/tests/out/wgsl/math-functions.wgsl b/tests/out/wgsl/math-functions.wgsl index 2be33d5647..9543168794 100644 --- a/tests/out/wgsl/math-functions.wgsl +++ b/tests/out/wgsl/math-functions.wgsl @@ -6,4 +6,5 @@ fn main() { let c = degrees(v); let d = radians(v); let const_dot = dot(vec2(0, 0), vec2(0, 0)); + let first_leading_bit_abs = firstLeadingBit(abs(0u)); }