From f4271a5549a1e9bd1388aa67358ff7229f649b23 Mon Sep 17 00:00:00 2001 From: Ping Yu <4018+pyu10055@users.noreply.github.com> Date: Mon, 28 Aug 2023 11:49:15 -0700 Subject: [PATCH] [tfjs-node] fixed range issue for int32 tensor with size larger than 2 ^ 24 (#7931) * fixed range issue when creating int32 tensor with larger than 2 ^ 24 size * reduce the size of the test tensor * remove use of max op * fixed test * fixed the test --- tfjs-core/src/ops/range_test.ts | 9 +++++++++ tfjs-node/src/kernels/Range.ts | 10 ++++------ 2 files changed, 13 insertions(+), 6 deletions(-) diff --git a/tfjs-core/src/ops/range_test.ts b/tfjs-core/src/ops/range_test.ts index dc59533a1cb..3fed0223c82 100644 --- a/tfjs-core/src/ops/range_test.ts +++ b/tfjs-core/src/ops/range_test.ts @@ -148,4 +148,13 @@ describeWithFlags('range', ALL_ENVS, () => { expect(a.dtype).toEqual('int32'); expect(a.shape).toEqual([3]); }); + + it('should support large number for int32 dtype', async () => { + const length = Math.pow(2, 24) + 10; + const a = tf.range(1, length, 1, 'int32'); + const data = await a.data(); + expect(data[length - 2]).toEqual(length - 1); + expect(a.dtype).toEqual('int32'); + expect(a.shape).toEqual([length - 1]); + }); }); diff --git a/tfjs-node/src/kernels/Range.ts b/tfjs-node/src/kernels/Range.ts index b91780b85f0..134d0f1f281 100644 --- a/tfjs-node/src/kernels/Range.ts +++ b/tfjs-node/src/kernels/Range.ts @@ -44,18 +44,16 @@ export const rangeConfig: KernelConfig = { } const opAttrs = [createTensorsTypeOpAttr('Tidx', dtype)]; - const startTensor = scalar(start); - const stopTensor = scalar(stop); - const stepTensor = scalar(step); + const startTensor = scalar(start, dtype); + const stopTensor = scalar(stop, dtype); + const stepTensor = scalar(step, dtype); const res = backend.executeSingleOutput( Range, opAttrs, [startTensor, stopTensor, stepTensor]); - const castedRes = res.cast(dtype); startTensor.dispose(); stopTensor.dispose(); stepTensor.dispose(); - res.dispose(); - return castedRes; + return res; } };