diff --git a/tfjs-backend-webgpu/src/binary_op_util.ts b/tfjs-backend-webgpu/src/binary_op_util.ts index d03f6de416c..c41303cab70 100644 --- a/tfjs-backend-webgpu/src/binary_op_util.ts +++ b/tfjs-backend-webgpu/src/binary_op_util.ts @@ -41,15 +41,15 @@ export enum BinaryOpType { SUB } -const ADD = 'return a + b;'; +const ADD = 'let resultTemp = a + b;'; const ATAN2 = 'var resultTemp = atan2(a, b);'; // (Ar + Ai)(Br + Bi) = // ArBr + ArBi + AiBr + AiBi = ArBr - AB + ArBi + AiBr // Yr = ArBr - AB // Yi = ArBi + AiBr -const COMPLEX_MULTIPLY_REAL = 'return areal * breal - aimag * bimag;'; -const COMPLEX_MULTIPLY_IMAG = 'return areal * bimag + aimag * breal;'; -const DIV = 'return a / b;'; +const COMPLEX_MULTIPLY_REAL = 'let resultTemp = areal * breal - aimag * bimag;'; +const COMPLEX_MULTIPLY_IMAG = 'let resultTemp = areal * bimag + aimag * breal;'; +const DIV = 'let resultTemp = a / b;'; const ELU_DER = 'return select(a * (b + 1.0), a, b >= 0.);'; const ELU_DER_VEC4 = 'return select(a * (b + vec4(1.0)), a, b >= vec4(0.));'; @@ -123,7 +123,7 @@ const MOD_VEC4 = ` resultTemp[3] = (resultTemp[3] + b[3]) % b[3]; } `; -const MUL = 'return a * b;'; +const MUL = 'let resultTemp = a * b;'; const NOT_EQUAL = ` var resultTemp = f32(a != b); let valueForNaN = 1.0; @@ -169,8 +169,8 @@ const PRELU_VEC4 = ` let aLessThanZero = vec4(a < vec4(0.0)); return (aLessThanZero * (b * a)) + ((vec4(1.0) - aLessThanZero) * a); `; -const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);'; -const SUB = 'return a - b;'; +const SQUARED_DIFFERENCE = 'let resultTemp = (a - b) * (a - b);'; +const SUB = 'let resultTemp = a - b;'; export function getBinaryOpString( type: BinaryOpType, useVec4?: boolean): string { @@ -235,13 +235,17 @@ export function getBinaryOpString( // Ops without NaN check switch (type) { case BinaryOpType.ADD: - return ADD; + doOpSnippet = ADD; + break; case BinaryOpType.COMPLEX_MULTIPLY_IMAG: - return COMPLEX_MULTIPLY_IMAG; + doOpSnippet = COMPLEX_MULTIPLY_IMAG; + break; case BinaryOpType.COMPLEX_MULTIPLY_REAL: - return COMPLEX_MULTIPLY_REAL; + doOpSnippet = COMPLEX_MULTIPLY_REAL; + break; case BinaryOpType.DIV: - return DIV; + doOpSnippet = DIV; + break; case BinaryOpType.ELU_DER: return useVec4 ? ELU_DER_VEC4 : ELU_DER; case BinaryOpType.EQUAL: @@ -261,13 +265,16 @@ export function getBinaryOpString( case BinaryOpType.LOGICAL_OR: return useVec4 ? LOGICAL_OR_VEC4 : LOGICAL_OR; case BinaryOpType.MUL: - return MUL; + doOpSnippet = MUL; + break; case BinaryOpType.PRELU: return useVec4 ? PRELU_VEC4 : PRELU; case BinaryOpType.SQUARED_DIFFERENCE: - return SQUARED_DIFFERENCE; + doOpSnippet = SQUARED_DIFFERENCE; + break; case BinaryOpType.SUB: - return SUB; + doOpSnippet = SUB; + break; default: // throw new Error(`BinaryType ${type} is not implemented!`); }