Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[webgpu] Update ADD,COMPLEX_MULTIPLY_*,DIV,MUL,SQUARED_DIFFERENCE,SUB #7737

Merged
merged 1 commit into from
Jun 6, 2023
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 21 additions & 14 deletions tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,15 +41,15 @@ export enum BinaryOpType {
SUB
}

const ADD = 'return a + b;';
const ADD = 'let resultTemp = a + b;';
const ATAN2 = 'var resultTemp = atan2(a, b);';
// (Ar + i*Ai)(Br + i*Bi) =
// ArBr + i*ArBi + i*AiBr + i^2*AiBi = (ArBr - AiBi) + i*(ArBi + AiBr)
// Yr = ArBr - AiBi
// Yi = ArBi + AiBr
const COMPLEX_MULTIPLY_REAL = 'return areal * breal - aimag * bimag;';
const COMPLEX_MULTIPLY_IMAG = 'return areal * bimag + aimag * breal;';
const DIV = 'return a / b;';
const COMPLEX_MULTIPLY_REAL = 'let resultTemp = areal * breal - aimag * bimag;';
const COMPLEX_MULTIPLY_IMAG = 'let resultTemp = areal * bimag + aimag * breal;';
const DIV = 'let resultTemp = a / b;';
const ELU_DER = 'return select(a * (b + 1.0), a, b >= 0.);';
const ELU_DER_VEC4 =
'return select(a * (b + vec4<f32>(1.0)), a, b >= vec4<f32>(0.));';
Expand Down Expand Up @@ -123,7 +123,7 @@ const MOD_VEC4 = `
resultTemp[3] = (resultTemp[3] + b[3]) % b[3];
}
`;
const MUL = 'return a * b;';
const MUL = 'let resultTemp = a * b;';
const NOT_EQUAL = `
var resultTemp = f32(a != b);
let valueForNaN = 1.0;
Expand Down Expand Up @@ -169,8 +169,8 @@ const PRELU_VEC4 = `
let aLessThanZero = vec4<f32>(a < vec4<f32>(0.0));
return (aLessThanZero * (b * a)) + ((vec4<f32>(1.0) - aLessThanZero) * a);
`;
const SQUARED_DIFFERENCE = 'return (a - b) * (a - b);';
const SUB = 'return a - b;';
const SQUARED_DIFFERENCE = 'let resultTemp = (a - b) * (a - b);';
const SUB = 'let resultTemp = a - b;';

export function getBinaryOpString(
type: BinaryOpType, useVec4?: boolean): string {
Expand Down Expand Up @@ -235,13 +235,17 @@ export function getBinaryOpString(
// Ops without NaN check
switch (type) {
case BinaryOpType.ADD:
return ADD;
doOpSnippet = ADD;
break;
case BinaryOpType.COMPLEX_MULTIPLY_IMAG:
return COMPLEX_MULTIPLY_IMAG;
doOpSnippet = COMPLEX_MULTIPLY_IMAG;
break;
case BinaryOpType.COMPLEX_MULTIPLY_REAL:
return COMPLEX_MULTIPLY_REAL;
doOpSnippet = COMPLEX_MULTIPLY_REAL;
break;
case BinaryOpType.DIV:
return DIV;
doOpSnippet = DIV;
break;
case BinaryOpType.ELU_DER:
return useVec4 ? ELU_DER_VEC4 : ELU_DER;
case BinaryOpType.EQUAL:
Expand All @@ -261,13 +265,16 @@ export function getBinaryOpString(
case BinaryOpType.LOGICAL_OR:
return useVec4 ? LOGICAL_OR_VEC4 : LOGICAL_OR;
case BinaryOpType.MUL:
return MUL;
doOpSnippet = MUL;
break;
case BinaryOpType.PRELU:
return useVec4 ? PRELU_VEC4 : PRELU;
case BinaryOpType.SQUARED_DIFFERENCE:
return SQUARED_DIFFERENCE;
doOpSnippet = SQUARED_DIFFERENCE;
break;
case BinaryOpType.SUB:
return SUB;
doOpSnippet = SUB;
break;
default:
// throw new Error(`BinaryType ${type} is not implemented!`);
}
Expand Down