diff --git a/tfjs-backend-webgpu/src/backend_webgpu.ts b/tfjs-backend-webgpu/src/backend_webgpu.ts index ba3ffecd2a6..1b8da1e7ee7 100644 --- a/tfjs-backend-webgpu/src/backend_webgpu.ts +++ b/tfjs-backend-webgpu/src/backend_webgpu.ts @@ -60,7 +60,7 @@ export interface WebGPUTimingInfo extends TimingInfo { downloadWaitMs: number; } -type ProgramUniform = Array<{type: string; data: number[]}>; +export type ProgramUniform = Array<{type: string; data: number[]}>; // Empirically determined constant used to determine size threshold for handing // off execution to the CPU. @@ -882,8 +882,8 @@ export class WebGPUBackend extends KernelBackend { }; }); - program.shaderKey = - webgpu_program.makeShaderKey(program, inputsData, output); + program.shaderKey = webgpu_program.makeShaderKey( + program, inputsData, output, programDefinedUniform); const parallelCompilation = env().getBool('WEBGPU_ENGINE_COMPILE_ONLY'); if (!(program.shaderKey in this.pipelineCache)) { diff --git a/tfjs-backend-webgpu/src/webgpu_program.ts b/tfjs-backend-webgpu/src/webgpu_program.ts index ddaeddc6b9a..f2fe599c344 100644 --- a/tfjs-backend-webgpu/src/webgpu_program.ts +++ b/tfjs-backend-webgpu/src/webgpu_program.ts @@ -17,6 +17,7 @@ import {backend_util, DataType, DataTypeMap, env, Rank, TensorInfo, util} from '@tensorflow/tfjs-core'; +import {ProgramUniform} from './backend_webgpu'; import {symbolicallyComputeStrides} from './shader_util'; export enum PixelsOpType { @@ -348,8 +349,8 @@ function makeShader( } export function makeShaderKey( - program: WebGPUProgram, inputsData: InputInfo[], - output: TensorInfo): string { + program: WebGPUProgram, inputsData: InputInfo[], output: TensorInfo, + programDefinedUniform?: ProgramUniform): string { let key = program.shaderKey; if (program.pixelsOpType != null) { return key; @@ -361,6 +362,13 @@ export function makeShaderKey( shapes.push(element.shape); types.push(element.dtype); }); + const uniformLengths: number[] = []; + if (programDefinedUniform != null) { + programDefinedUniform.forEach(element => { + uniformLengths.push(element.data.length); + types.push(element.type); + }); + } shapes.push(output.shape); types.push(output.dtype); @@ -373,8 +381,8 @@ export function makeShaderKey( const flatDispatchString = isFlatDispatch(program) ? 'flatDispatch' : ''; key += '_' + (program.workgroupSize ? program.workgroupSize.join(',') : '') + - shapes.map(shape => shape.length).join(',') + types.join(',') + - program.variableNames.join(',') + broadcastDimsKey + + shapes.map(shape => shape.length).join(',') + uniformLengths.join(',') + + types.join(',') + program.variableNames.join(',') + broadcastDimsKey + inputShapesEqualsOutShape + flatDispatchString; return key;