Skip to content

Commit

Permalink
[tfjs-node] fixed range issue for int32 tensor with size larger than …
Browse files Browse the repository at this point in the history
…2 ^ 24 (#7931)

* fixed range issue when creating int32 tensor with larger than 2 ^ 24 size

* reduce the size of the test tensor

* remove use of max op

* fixed test

* fixed the test
  • Loading branch information
pyu10055 authored Aug 28, 2023
1 parent a0115ea commit f4271a5
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 6 deletions.
9 changes: 9 additions & 0 deletions tfjs-core/src/ops/range_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -148,4 +148,13 @@ describeWithFlags('range', ALL_ENVS, () => {
expect(a.dtype).toEqual('int32');
expect(a.shape).toEqual([3]);
});

it('should support large number for int32 dtype', async () => {
const length = Math.pow(2, 24) + 10;
const a = tf.range(1, length, 1, 'int32');
const data = await a.data();
expect(data[length - 2]).toEqual(length - 1);
expect(a.dtype).toEqual('int32');
expect(a.shape).toEqual([length - 1]);
});
});
10 changes: 4 additions & 6 deletions tfjs-node/src/kernels/Range.ts
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,16 @@ export const rangeConfig: KernelConfig = {
}

const opAttrs = [createTensorsTypeOpAttr('Tidx', dtype)];
const startTensor = scalar(start);
const stopTensor = scalar(stop);
const stepTensor = scalar(step);
const startTensor = scalar(start, dtype);
const stopTensor = scalar(stop, dtype);
const stepTensor = scalar(step, dtype);
const res = backend.executeSingleOutput(
Range, opAttrs, [startTensor, stopTensor, stepTensor]);
const castedRes = res.cast(dtype);

startTensor.dispose();
stopTensor.dispose();
stepTensor.dispose();
res.dispose();

return castedRes;
return res;
}
};

0 comments on commit f4271a5

Please sign in to comment.