Skip to content

Commit

Permalink
Add a shape member
Browse files Browse the repository at this point in the history
  • Loading branch information
xhcao committed May 6, 2022
1 parent cd2c998 commit d24b0cc
Showing 1 changed file with 17 additions and 12 deletions.
29 changes: 17 additions & 12 deletions tfjs-backend-webgpu/src/backend_webgpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@ type BufferInfo = {
type TensorBufferInfo = {
values: backend_util.BackendValues,
dtype: DataType,
shape: number[],
bufferInfo: BufferInfo,
refCount: number,
// For complex numbers, the real and imaginary parts are stored as their own
Expand Down Expand Up @@ -315,6 +316,7 @@ export class WebGPUBackend extends KernelBackend {

this.tensorMap.set(dataId, {
dtype,
shape,
values,
bufferInfo: {byteSize, usage: this.defaultGpuBufferUsage()},
refCount: 1
Expand All @@ -335,6 +337,7 @@ export class WebGPUBackend extends KernelBackend {

this.tensorMap.set(dataId, {
dtype,
shape,
values,
bufferInfo: {byteSize, usage: this.defaultGpuBufferUsage()},
refCount
Expand Down Expand Up @@ -480,7 +483,7 @@ export class WebGPUBackend extends KernelBackend {
*/
readToGPU(dataId: DataId, options: DataToGPUWebGPUOption = {}): GPUData {
const srcData = this.tensorMap.get(dataId);
const {values, dtype, bufferInfo} = srcData;
const {values, dtype, shape, bufferInfo} = srcData;

if (dtype === 'complex64') {
throw new Error('Does not support reading buffer for complex64 dtype.');
Expand All @@ -494,31 +497,33 @@ export class WebGPUBackend extends KernelBackend {
}
}

const size =
util.sizeFromShape(shape) * webgpu_util.GPUBytesPerElement(dtype);
if (options.customBufSize != null) {
util.assert(
options.customBufSize >= bufferInfo.byteSize,
() => 'customBufSize should be equal or larger than the buffer size');
options.customBufSize >= size,
() => `customBufSize should be equal or larger than ` +
`the source tensor size ${size} bytes.`);
}

const bufferSize = options.customBufSize != null ? options.customBufSize :
bufferInfo.byteSize;
const bufferSize =
options.customBufSize != null ? options.customBufSize : size;
const resBuffer = this.acquireBuffer(bufferSize);
this.ensureCommandEncoderReady();
this.ensureComputePassEnded();
this.currentCommandEncoder.copyBufferToBuffer(
bufferInfo.buffer, 0, resBuffer, 0, bufferInfo.byteSize);
bufferInfo.buffer, 0, resBuffer, 0, size);
this.submitQueue();

const tensorInfo = this.makeTensorInfo(
[bufferSize / webgpu_util.GPUBytesPerElement(dtype)], dtype);
const tensorInfo = this.makeTensorInfo(shape, dtype);
// Make engine track this tensor, so that we can dispose it later.
const tensorRef = engine().makeTensorFromDataId(
tensorInfo.dataId,
[bufferInfo.byteSize / webgpu_util.GPUBytesPerElement(dtype)],
tensorInfo.dtype);
const tensorRef = engine().makeTensorFromTensorInfo(tensorInfo);

const info = this.tensorMap.get(tensorInfo.dataId);
info.bufferInfo.buffer = resBuffer;
// Explicitly change the buffer size that could release the buffer
// successfully in future.
info.bufferInfo.byteSize = bufferSize;

return {tensorRef, buffer: resBuffer, bufSize: bufferSize};
}
Expand Down

0 comments on commit d24b0cc

Please sign in to comment.