From ba7130f4605aa80c6f9380595b9d6f2bf695e581 Mon Sep 17 00:00:00 2001 From: Jiajie Hu Date: Sat, 10 Jun 2023 11:59:45 +0800 Subject: [PATCH] [webgpu] Update EQUAL,GREATER,GREATER_EQUAL,LESS,LESS_EQUAL --- tfjs-backend-webgpu/src/binary_op_util.ts | 50 ++++++++++++++++------- 1 file changed, 35 insertions(+), 15 deletions(-) diff --git a/tfjs-backend-webgpu/src/binary_op_util.ts b/tfjs-backend-webgpu/src/binary_op_util.ts index 31a4e502697..9cae252514d 100644 --- a/tfjs-backend-webgpu/src/binary_op_util.ts +++ b/tfjs-backend-webgpu/src/binary_op_util.ts @@ -51,12 +51,21 @@ const COMPLEX_MULTIPLY_REAL = 'let resultTemp = areal * breal - aimag * bimag;'; const COMPLEX_MULTIPLY_IMAG = 'let resultTemp = areal * bimag + aimag * breal;'; const DIV = 'let resultTemp = a / b;'; const ELU_DER = 'let resultTemp = select(a * (b + 1.0), a, b >= b - b);'; -const EQUAL = 'return f32(a == b);'; -const EQUAL_VEC4 = 'return vec4(a == b);'; -const GREATER = 'return f32(a > b);'; -const GREATER_VEC4 = 'return vec4(a > b);'; -const GREATER_EQUAL = 'return f32(a >= b);'; -const GREATER_EQUAL_VEC4 = 'return vec4(a >= b);'; +const EQUAL = ` + let zero = sign(a) * 0 + 0; + let one = sign(b) * 0 + 1; + let resultTemp = select(zero, one, a == b); +`; +const GREATER = ` + let zero = sign(a) * 0 + 0; + let one = sign(b) * 0 + 1; + let resultTemp = select(zero, one, a > b); +`; +const GREATER_EQUAL = ` + let zero = sign(a) * 0 + 0; + let one = sign(b) * 0 + 1; + let resultTemp = select(zero, one, a >= b); +`; const INT_DIV = ` let s = sign(a) * sign(b); @@ -87,10 +96,16 @@ const INT_DIV_VEC4 = ` return vec4(resultTemp); `; -const LESS = 'return f32(a < b);'; -const LESS_VEC4 = 'return vec4(a < b);'; -const LESS_EQUAL = 'return f32(a <= b);'; -const LESS_EQUAL_VEC4 = 'return vec4(a <= b);'; +const LESS = ` + let zero = sign(a) * 0 + 0; + let one = sign(b) * 0 + 1; + let resultTemp = select(zero, one, a < b); +`; +const LESS_EQUAL = ` + let zero = sign(a) * 0 + 0; + let one = sign(b) * 0 + 1; + let resultTemp = select(zero, one, a <= b); +`; const LOGICAL_AND = 'return f32(a >= 1.0 && b >= 1.0);'; const LOGICAL_AND_VEC4 = `return (vec4(a >= vec4(1.0)) * vec4(b >= vec4(1.0)));`; @@ -248,17 +263,22 @@ export function getBinaryOpString( doOpSnippet = ELU_DER; break; case BinaryOpType.EQUAL: - return useVec4 ? EQUAL_VEC4 : EQUAL; + doOpSnippet = EQUAL; + break; case BinaryOpType.GREATER: - return useVec4 ? GREATER_VEC4 : GREATER; + doOpSnippet = GREATER; + break; case BinaryOpType.GREATER_EQUAL: - return useVec4 ? GREATER_EQUAL_VEC4 : GREATER_EQUAL; + doOpSnippet = GREATER_EQUAL; + break; case BinaryOpType.INT_DIV: return useVec4 ? INT_DIV_VEC4 : INT_DIV; case BinaryOpType.LESS: - return useVec4 ? LESS_VEC4 : LESS; + doOpSnippet = LESS; + break; case BinaryOpType.LESS_EQUAL: - return useVec4 ? LESS_EQUAL_VEC4 : LESS_EQUAL; + doOpSnippet = LESS_EQUAL; + break; case BinaryOpType.LOGICAL_AND: return useVec4 ? LOGICAL_AND_VEC4 : LOGICAL_AND; case BinaryOpType.LOGICAL_OR: