Skip to content

Commit

Permalink
Merge branch 'master' into V3ScriptForMapping
Browse files Browse the repository at this point in the history
  • Loading branch information
fengwuyao authored Jun 12, 2023
2 parents 4751e92 + af8fb59 commit 382a5bf
Show file tree
Hide file tree
Showing 14 changed files with 358 additions and 329 deletions.
9 changes: 4 additions & 5 deletions tfjs-backend-webgpu/src/argminmax_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -46,11 +46,10 @@ export class ArgMinMaxProgram implements WebGPUProgram {
this.dispatchLayout = flatDispatchLayout(this.outputShape);
// The shared algorithm is mainly used for large reduce size. It fully
// utilizes the threads in one workgroup to do the reduction. However,
// when the reduce size is very small or the output shape is too large, it's
// better to use the plain algorithm to reduce the number of workgroups and
// speed things up. The threshold can be further tuned.
if (util.sizeFromShape(reduceShape) < 32 ||
util.sizeFromShape(outputShape) > 1000) {
// when the reduce size is very small, it's better to use the plain
// algorithm to reduce the number of workgroups and speed things up. The
// threshold can be further tuned.
if (util.sizeFromShape(reduceShape) < 32) {
this.type = 'plain';
this.dispatch = computeDispatch(
this.dispatchLayout, this.outputShape, this.workgroupSize);
Expand Down
246 changes: 0 additions & 246 deletions tfjs-backend-webgpu/src/benchmark_ops_test.ts

This file was deleted.

13 changes: 6 additions & 7 deletions tfjs-backend-webgpu/src/binary_op_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -42,17 +42,15 @@ export enum BinaryOpType {
}

const ADD = 'let resultTemp = a + b;';
const ATAN2 = 'var resultTemp = atan2(a, b);';
const ATAN2 = 'let resultTemp = atan2(a, b);';
// (Ar + Ai)(Br + Bi) =
// ArBr + ArBi + AiBr + AiBi = ArBr - AiBi + ArBi + AiBr
// Yr = ArBr - AiBi
// Yi = ArBi + AiBr
const COMPLEX_MULTIPLY_REAL = 'let resultTemp = areal * breal - aimag * bimag;';
const COMPLEX_MULTIPLY_IMAG = 'let resultTemp = areal * bimag + aimag * breal;';
const DIV = 'let resultTemp = a / b;';
const ELU_DER = 'return select(a * (b + 1.0), a, b >= 0.);';
const ELU_DER_VEC4 =
'return select(a * (b + vec4<f32>(1.0)), a, b >= vec4<f32>(0.));';
const ELU_DER = 'let resultTemp = select(a * (b + 1.0), a, b >= b - b);';
const EQUAL = 'return f32(a == b);';
const EQUAL_VEC4 = 'return vec4<f32>(a == b);';
const GREATER = 'return f32(a > b);';
Expand Down Expand Up @@ -99,8 +97,8 @@ const LOGICAL_AND_VEC4 = `return (vec4<f32>(a >= vec4<f32>(1.0)) *
const LOGICAL_OR = 'return f32(a >= 1.0 || b >= 1.0);';
const LOGICAL_OR_VEC4 = `return min(vec4<f32>(a >= vec4<f32>(1.0)) +
vec4<f32>(b >= vec4<f32>(1.0)), vec4<f32>(1.0));`;
const MAX = 'var resultTemp = max(a, b);';
const MIN = 'var resultTemp = min(a, b);';
const MAX = 'let resultTemp = max(a, b);';
const MIN = 'let resultTemp = min(a, b);';
const MOD = `
let isNaN = b == 0.;
var resultTemp = a % b;
Expand Down Expand Up @@ -247,7 +245,8 @@ export function getBinaryOpString(
doOpSnippet = DIV;
break;
case BinaryOpType.ELU_DER:
return useVec4 ? ELU_DER_VEC4 : ELU_DER;
doOpSnippet = ELU_DER;
break;
case BinaryOpType.EQUAL:
return useVec4 ? EQUAL_VEC4 : EQUAL;
case BinaryOpType.GREATER:
Expand Down
15 changes: 5 additions & 10 deletions tfjs-backend-webgpu/src/conv2d_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,9 @@ function conv2dCommonSnippet(
const getWSnippet = (innerElementSize: number) => {
switch (innerElementSize) {
case 1:
return 'return W[row * uniforms.wShape[3] + colIn];';
return 'return W[row * uniforms.wShape[3] + col];';
case 4:
return 'return W[row * uniforms.wShape[3] / 4 + colIn];';
return 'return W[(row * uniforms.wShape[3] + col) / 4];';
default:
throw new Error(
`innerElementSize ${innerElementSize} is not supported.`);
Expand Down Expand Up @@ -101,19 +101,15 @@ function conv2dCommonSnippet(
return resData;`;

const sampleX = isChannelsLast ? (fitAOuter && fitInner ? `
let col = colIn * ${innerElementSizeX};
${readXSnippet}` :
`
let col = colIn * ${innerElementSizeX};
if (row < uniforms.dimAOuter && col < uniforms.dimInner) {
${readXSnippet}
}
return ${typeSnippet(innerElementSizeX)}(0.0);`) :
(fitInner && fitBOuter ? `
let col = colIn * ${innerElementSizeX};
${readXSnippet}` :
`
let col = colIn * ${innerElementSizeX};
if (row < uniforms.dimInner && col < uniforms.dimBOuter) {
${readXSnippet}
}
Expand All @@ -130,16 +126,15 @@ function conv2dCommonSnippet(
${
activationFnSnippet(
activation, hasPreluActivationWeights, innerElementSize === 4, 4)}
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${aType} {
fn mm_readA(batch: i32, row : i32, col : i32) -> ${aType} {
${isChannelsLast ? sampleX : sampleW}
}
fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${bType} {
fn mm_readB(batch: i32, row : i32, col : i32) -> ${bType} {
${isChannelsLast ? sampleW : sampleX}
}
fn mm_write(batch: i32, row : i32, colIn : i32, valueIn : ${resType}) {
let col = colIn * ${innerElementSize};
fn mm_write(batch: i32, row : i32, col : i32, valueIn : ${resType}) {
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter)
{
var value = valueIn;
Expand Down
12 changes: 4 additions & 8 deletions tfjs-backend-webgpu/src/conv_backprop_mm_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -71,15 +71,13 @@ function conv2dTransposeCommonSnippet(innerElementSize = 4) {
return ${typeSnippet(innerElementSize)}(0.0);`;

const userCode = `
fn mm_readA(batch: i32, row : i32, colIn : i32) -> ${
fn mm_readA(batch: i32, row : i32, col : i32) -> ${
typeSnippet(innerElementSize)} {
let col = colIn * ${innerElementSize};
${sampleA}
}
fn mm_readB(batch: i32, row : i32, colIn : i32) -> ${
fn mm_readB(batch: i32, row : i32, col : i32) -> ${
typeSnippet(innerElementSize)} {
let col = colIn * ${innerElementSize};
let coordX = uniforms.filterDims.x - 1 -
row / (uniforms.filterDims[1] * uniforms.outBackprop[3]);
let coordY = uniforms.filterDims.y - 1 -
Expand All @@ -93,11 +91,9 @@ function conv2dTransposeCommonSnippet(innerElementSize = 4) {
return ${typeSnippet(innerElementSize)}(0.0);
}
fn mm_write(batch: i32, row : i32, colIn : i32, valueInput : ${
fn mm_write(batch: i32, row : i32, col : i32, valueInput : ${
typeSnippet(innerElementSize)}) {
let col = colIn * ${innerElementSize};
if (row < uniforms.dimAOuter && (col + ${
innerElementSize - 1}) < uniforms.dimBOuter) {
if (row < uniforms.dimAOuter && col < uniforms.dimBOuter) {
var value = valueInput;
let outCoord = vec4<i32>(
batch,
Expand Down
Loading

0 comments on commit 382a5bf

Please sign in to comment.