Skip to content

Commit ae902e5

Browse files
Use node's util.types.isUint8Array etc for isTypedArray (#7181)
isTypedArray is implemented with `instanceof`, which does not work in jest (jestjs/jest#11864). Instead, use node's builtin `util.types.isUint8Array`, `util.types.isFloat32Array`, etc to perform this check. Fixes #7175. This may also address #7064, but it does not fix the root cause.
1 parent d8b08c9 commit ae902e5

File tree

10 files changed

+145
-65
lines changed

10 files changed

+145
-65
lines changed

package.json

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,7 +19,7 @@
1919
"@types/js-yaml": "^4.0.5",
2020
"@types/long": "4.0.1",
2121
"@types/mkdirp": "^0.5.2",
22-
"@types/node": "^12.7.5",
22+
"@types/node": "^18.11.15",
2323
"@types/node-fetch": "~2.1.2",
2424
"@types/offscreencanvas": "^2019.7.0",
2525
"@types/rollup-plugin-visualizer": "^4.2.1",

tfjs-core/src/BUILD.bazel

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@ TEST_ENTRYPOINTS = [
3131
"setup_test.ts",
3232
"worker_test.ts",
3333
"worker_node_test.ts",
34+
"platforms/platform_node_test.ts",
3435
"ops/from_pixels_worker_test.ts",
3536
]
3637

@@ -185,6 +186,26 @@ jasmine_node_test(
185186
],
186187
)
187188

189+
ts_library(
190+
name = "platform_node_test_lib",
191+
srcs = [
192+
"platforms/platform_node_test.ts",
193+
],
194+
deps = [
195+
":tfjs-core_lib",
196+
":tfjs-core_src_lib",
197+
"//tfjs-backend-cpu/src:tfjs-backend-cpu_lib",
198+
"@npm//@types/node",
199+
],
200+
)
201+
202+
jasmine_node_test(
203+
name = "platform_node_test",
204+
deps = [
205+
":platform_node_test_lib",
206+
],
207+
)
208+
188209
ts_library(
189210
name = "worker_test_lib",
190211
srcs = [

tfjs-core/src/platforms/platform.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,7 @@ export interface Platform {
4848
decode(bytes: Uint8Array, encoding: string): string;
4949

5050
setTimeoutCustom?(functionRef: Function, delay: number): void;
51+
52+
isTypedArray(a: unknown): a is Float32Array|Int32Array|Uint8Array|
53+
Uint8ClampedArray;
5154
}

tfjs-core/src/platforms/platform_browser.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -90,6 +90,12 @@ export class PlatformBrowser implements Platform {
9090
}, true);
9191
}
9292
}
93+
94+
isTypedArray(a: unknown): a is Uint8Array | Float32Array | Int32Array
95+
| Uint8ClampedArray {
96+
return a instanceof Float32Array || a instanceof Int32Array ||
97+
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
98+
}
9399
}
94100

95101
if (env().get('IS_BROWSER')) {

tfjs-core/src/platforms/platform_browser_test.ts

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -147,4 +147,19 @@ describeWithFlags('setTimeout', BROWSER_ENVS, () => {
147147
env().platform.setTimeoutCustom(_testSetTimeoutCustom, 0);
148148
}
149149
});
150+
151+
it('isTypedArray returns false if not a typed array', () => {
152+
const platform = new PlatformBrowser();
153+
expect(platform.isTypedArray([1, 2, 3])).toBeFalse();
154+
});
155+
156+
for (const typedArrayConstructor of [Float32Array, Int32Array, Uint8Array,
157+
Uint8ClampedArray]) {
158+
it(`isTypedArray returns true if it is a ${typedArrayConstructor.name}`,
159+
() => {
160+
const platform = new PlatformBrowser();
161+
const array = new typedArrayConstructor([1,2,3]);
162+
expect(platform.isTypedArray(array)).toBeTrue();
163+
});
164+
}
150165
});

tfjs-core/src/platforms/platform_node.ts

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -79,6 +79,13 @@ export class PlatformNode implements Platform {
7979
}
8080
return new this.util.TextDecoder(encoding).decode(bytes);
8181
}
82+
isTypedArray(a: unknown): a is Float32Array | Int32Array | Uint8Array
83+
| Uint8ClampedArray {
84+
return this.util.types.isFloat32Array(a)
85+
|| this.util.types.isInt32Array(a)
86+
|| this.util.types.isUint8Array(a)
87+
|| this.util.types.isUint8ClampedArray(a);
88+
}
8289
}
8390

8491
if (env().get('IS_NODE') && !env().get('IS_BROWSER')) {

tfjs-core/src/platforms/platform_node_test.ts

Lines changed: 31 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,11 +16,11 @@
1616
*/
1717

1818
import * as tf from '../index';
19-
import {describeWithFlags, NODE_ENVS} from '../jasmine_util';
2019
import * as platform_node from './platform_node';
2120
import {PlatformNode} from './platform_node';
21+
import * as vm from 'node:vm';
2222

23-
describeWithFlags('PlatformNode', NODE_ENVS, () => {
23+
describe('PlatformNode', () => {
2424
it('fetch should use global.fetch if defined', async () => {
2525
const globalFetch = tf.env().global.fetch;
2626

@@ -125,4 +125,33 @@ describeWithFlags('PlatformNode', NODE_ENVS, () => {
125125
expect(s.length).toBe(6);
126126
expect(s).toEqual('Здраво');
127127
});
128+
129+
describe('isTypedArray', () => {
130+
let platform: PlatformNode;
131+
beforeEach(() => {
132+
platform = new PlatformNode();
133+
});
134+
135+
it('returns false if not a typed array', () => {
136+
expect(platform.isTypedArray([1, 2, 3])).toBeFalse();
137+
});
138+
139+
for (const typedArrayConstructor of [Float32Array, Int32Array, Uint8Array,
140+
Uint8ClampedArray]) {
141+
it(`returns true if it is a ${typedArrayConstructor.name}`,
142+
() => {
143+
const array = new typedArrayConstructor([1,2,3]);
144+
expect(platform.isTypedArray(array)).toBeTrue();
145+
});
146+
}
147+
148+
it('works on values created in a new node context', async () => {
149+
const array = await new Promise((resolve) => {
150+
const code = `resolve(new Uint8Array([1, 2, 3]));`;
151+
vm.runInNewContext(code, {resolve});
152+
});
153+
154+
expect(platform.isTypedArray(array)).toBeTrue();
155+
});
156+
});
128157
});

tfjs-core/src/util.ts

Lines changed: 56 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717

1818
import {env} from './environment';
19-
import {BackendValues, DataType, TensorLike, TypedArray} from './types';
19+
import {BackendValues, DataType, RecursiveArray, TensorLike, TypedArray} from './types';
2020
import * as base from './util_base';
2121
export * from './util_base';
2222
export * from './hash_util';
@@ -44,7 +44,7 @@ export function toTypedArray(a: TensorLike, dtype: DataType): TypedArray {
4444
throw new Error('Cannot convert a string[] to a TypedArray');
4545
}
4646
if (Array.isArray(a)) {
47-
a = base.flatten(a);
47+
a = flatten(a);
4848
}
4949

5050
if (env().getBool('DEBUG')) {
@@ -131,3 +131,57 @@ export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string {
131131
encoding = encoding || 'utf-8';
132132
return env().platform.decode(bytes, encoding);
133133
}
134+
135+
export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array|
136+
Uint8ClampedArray {
137+
return env().platform.isTypedArray(a);
138+
}
139+
140+
// NOTE: We explicitly type out what T extends instead of any so that
141+
// util.flatten on a nested array of number doesn't try to infer T as a
142+
// number[][], causing us to explicitly type util.flatten<number>().
143+
/**
144+
* Flattens an arbitrarily nested array.
145+
*
146+
* ```js
147+
* const a = [[1, 2], [3, 4], [5, [6, [7]]]];
148+
* const flat = tf.util.flatten(a);
149+
* console.log(flat);
150+
* ```
151+
*
152+
* @param arr The nested array to flatten.
153+
* @param result The destination array which holds the elements.
154+
* @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
155+
* to false.
156+
*
157+
* @doc {heading: 'Util', namespace: 'util'}
158+
*/
159+
export function
160+
flatten<T extends number|boolean|string|Promise<number>|TypedArray>(
161+
arr: T|RecursiveArray<T>, result: T[] = [], skipTypedArray = false): T[] {
162+
if (result == null) {
163+
result = [];
164+
}
165+
if (typeof arr === 'boolean' || typeof arr === 'number' ||
166+
typeof arr === 'string' || base.isPromise(arr) || arr == null ||
167+
isTypedArray(arr) && skipTypedArray) {
168+
result.push(arr as T);
169+
} else if (Array.isArray(arr) || isTypedArray(arr)) {
170+
for (let i = 0; i < arr.length; ++i) {
171+
flatten(arr[i], result, skipTypedArray);
172+
}
173+
} else {
174+
let maxIndex = -1;
175+
for (const key of Object.keys(arr)) {
176+
// 0 or positive integer.
177+
if (/^([1-9]+[0-9]*|0)$/.test(key)) {
178+
maxIndex = Math.max(maxIndex, Number(key));
179+
}
180+
}
181+
for (let i = 0; i <= maxIndex; i++) {
182+
// tslint:disable-next-line: no-unnecessary-type-assertion
183+
flatten((arr as RecursiveArray<T>)[i], result, skipTypedArray);
184+
}
185+
}
186+
return result;
187+
}

tfjs-core/src/util_base.ts

Lines changed: 1 addition & 56 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
* =============================================================================
1616
*/
1717

18-
import {BackendValues, DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray, WebGLData, WebGPUData} from './types';
18+
import {BackendValues, DataType, DataTypeMap, FlatVector, NumericDataType, TensorLike, TypedArray, WebGLData, WebGPUData} from './types';
1919

2020
/**
2121
* Shuffles the array in-place using Fisher-Yates algorithm.
@@ -167,55 +167,6 @@ export function assertNonNull(a: TensorLike): void {
167167
() => `The input to the tensor constructor must be a non-null value.`);
168168
}
169169

170-
// NOTE: We explicitly type out what T extends instead of any so that
171-
// util.flatten on a nested array of number doesn't try to infer T as a
172-
// number[][], causing us to explicitly type util.flatten<number>().
173-
/**
174-
* Flattens an arbitrarily nested array.
175-
*
176-
* ```js
177-
* const a = [[1, 2], [3, 4], [5, [6, [7]]]];
178-
* const flat = tf.util.flatten(a);
179-
* console.log(flat);
180-
* ```
181-
*
182-
* @param arr The nested array to flatten.
183-
* @param result The destination array which holds the elements.
184-
* @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
185-
* to false.
186-
*
187-
* @doc {heading: 'Util', namespace: 'util'}
188-
*/
189-
export function
190-
flatten<T extends number|boolean|string|Promise<number>|TypedArray>(
191-
arr: T|RecursiveArray<T>, result: T[] = [], skipTypedArray = false): T[] {
192-
if (result == null) {
193-
result = [];
194-
}
195-
if (typeof arr === 'boolean' || typeof arr === 'number' ||
196-
typeof arr === 'string' || isPromise(arr) || arr == null ||
197-
isTypedArray(arr) && skipTypedArray) {
198-
result.push(arr as T);
199-
} else if (Array.isArray(arr) || isTypedArray(arr)) {
200-
for (let i = 0; i < arr.length; ++i) {
201-
flatten(arr[i], result, skipTypedArray);
202-
}
203-
} else {
204-
let maxIndex = -1;
205-
for (const key of Object.keys(arr)) {
206-
// 0 or positive integer.
207-
if (/^([1-9]+[0-9]*|0)$/.test(key)) {
208-
maxIndex = Math.max(maxIndex, Number(key));
209-
}
210-
}
211-
for (let i = 0; i <= maxIndex; i++) {
212-
// tslint:disable-next-line: no-unnecessary-type-assertion
213-
flatten((arr as RecursiveArray<T>)[i], result, skipTypedArray);
214-
}
215-
}
216-
return result;
217-
}
218-
219170
/**
220171
* Returns the size (number of elements) of the tensor given its shape.
221172
*
@@ -527,12 +478,6 @@ export function hasEncodingLoss(oldType: DataType, newType: DataType): boolean {
527478
return true;
528479
}
529480

530-
export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array|
531-
Uint8ClampedArray {
532-
return a instanceof Float32Array || a instanceof Int32Array ||
533-
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
534-
}
535-
536481
export function bytesPerElement(dtype: DataType): number {
537482
if (dtype === 'float32' || dtype === 'int32') {
538483
return 4;

yarn.lock

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -384,10 +384,10 @@
384384
resolved "https://registry.yarnpkg.com/@types/node/-/node-10.17.60.tgz#35f3d6213daed95da7f0f73e75bcc6980e90597b"
385385
integrity sha512-F0KIgDJfy2nA3zMLmWGKxcH2ZVEtCZXHHdOQs2gSaQ27+lNeEfGxzkIw90aXswATX7AZ33tahPbzy6KAfUreVw==
386386

387-
"@types/node@^12.7.5":
388-
version "12.20.28"
389-
resolved "https://registry.yarnpkg.com/@types/node/-/node-12.20.28.tgz#4b20048c6052b5f51a8d5e0d2acbf63d5a17e1e2"
390-
integrity sha512-cBw8gzxUPYX+/5lugXIPksioBSbE42k0fZ39p+4yRzfYjN6++eq9kAPdlY9qm+MXyfbk9EmvCYAYRn380sF46w==
387+
"@types/node@^18.11.15":
388+
version "18.11.15"
389+
resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.15.tgz#de0e1fbd2b22b962d45971431e2ae696643d3f5d"
390+
integrity sha512-VkhBbVo2+2oozlkdHXLrb3zjsRkpdnaU2bXmX8Wgle3PUi569eLRaHGlgETQHR7lLL1w7GiG3h9SnePhxNDecw==
391391

392392
"@types/offscreencanvas@^2019.7.0":
393393
version "2019.7.0"

0 commit comments

Comments
 (0)