Skip to content

Commit

Permalink
Add uniforms
Browse files Browse the repository at this point in the history
  • Loading branch information
Linchenn committed Aug 23, 2023
1 parent 20ceb6f commit 645802f
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 7 deletions.
6 changes: 3 additions & 3 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -60,7 +60,7 @@ export interface WebGPUTimingInfo extends TimingInfo {
downloadWaitMs: number;
}

type ProgramUniform = Array<{type: string; data: number[]}>;
export type ProgramUniform = Array<{type: string; data: number[]}>;

// Empirically determined constant used to determine size threshold for handing
// off execution to the CPU.
Expand Down Expand Up @@ -882,8 +882,8 @@ export class WebGPUBackend extends KernelBackend {
};
});

program.shaderKey =
webgpu_program.makeShaderKey(program, inputsData, output);
program.shaderKey = webgpu_program.makeShaderKey(
program, inputsData, output, programDefinedUniform);

const parallelCompilation = env().getBool('WEBGPU_ENGINE_COMPILE_ONLY');
if (!(program.shaderKey in this.pipelineCache)) {
Expand Down
16 changes: 12 additions & 4 deletions tfjs-backend-webgpu/src/webgpu_program.ts
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

import {backend_util, DataType, DataTypeMap, env, Rank, TensorInfo, util} from '@tensorflow/tfjs-core';

import {ProgramUniform} from './backend_webgpu';
import {symbolicallyComputeStrides} from './shader_util';

export enum PixelsOpType {
Expand Down Expand Up @@ -348,8 +349,8 @@ function makeShader(
}

export function makeShaderKey<R extends Rank>(
program: WebGPUProgram, inputsData: InputInfo[],
output: TensorInfo): string {
program: WebGPUProgram, inputsData: InputInfo[], output: TensorInfo,
programDefinedUniform?: ProgramUniform): string {
let key = program.shaderKey;
if (program.pixelsOpType != null) {
return key;
Expand All @@ -361,6 +362,13 @@ export function makeShaderKey<R extends Rank>(
shapes.push(element.shape);
types.push(element.dtype);
});
const uniformLengths: number[] = [];
if (programDefinedUniform != null) {
programDefinedUniform.forEach(element => {
uniformLengths.push(element.data.length);
types.push(element.type);
});
}
shapes.push(output.shape);
types.push(output.dtype);

Expand All @@ -373,8 +381,8 @@ export function makeShaderKey<R extends Rank>(
const flatDispatchString = isFlatDispatch(program) ? 'flatDispatch' : '';

key += '_' + (program.workgroupSize ? program.workgroupSize.join(',') : '') +
shapes.map(shape => shape.length).join(',') + types.join(',') +
program.variableNames.join(',') + broadcastDimsKey +
shapes.map(shape => shape.length).join(',') + uniformLengths.join(',') +
types.join(',') + program.variableNames.join(',') + broadcastDimsKey +
inputShapesEqualsOutShape + flatDispatchString;

return key;
Expand Down

0 comments on commit 645802f

Please sign in to comment.