Skip to content

Commit

Permalink
Use node's util.types.isUint8Array etc for isTypedArray (#7181)
Browse files Browse the repository at this point in the history
isTypedArray is implemented with `instanceof`, which does not work in jest (jestjs/jest#11864). Instead, use node's builtin `util.types.isUint8Array`, `util.types.isFloat32Array`, etc to perform this check.

Fixes #7175.
This may also address #7064, but it does not fix the root cause.
  • Loading branch information
mattsoulanille committed Dec 16, 2022
1 parent d8b08c9 commit ae902e5
Show file tree
Hide file tree
Showing 10 changed files with 145 additions and 65 deletions.
2 changes: 1 addition & 1 deletion package.json
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
"@types/js-yaml": "^4.0.5",
"@types/long": "4.0.1",
"@types/mkdirp": "^0.5.2",
"@types/node": "^12.7.5",
"@types/node": "^18.11.15",
"@types/node-fetch": "~2.1.2",
"@types/offscreencanvas": "^2019.7.0",
"@types/rollup-plugin-visualizer": "^4.2.1",
Expand Down
21 changes: 21 additions & 0 deletions tfjs-core/src/BUILD.bazel
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@ TEST_ENTRYPOINTS = [
"setup_test.ts",
"worker_test.ts",
"worker_node_test.ts",
"platforms/platform_node_test.ts",
"ops/from_pixels_worker_test.ts",
]

Expand Down Expand Up @@ -185,6 +186,26 @@ jasmine_node_test(
],
)

ts_library(
name = "platform_node_test_lib",
srcs = [
"platforms/platform_node_test.ts",
],
deps = [
":tfjs-core_lib",
":tfjs-core_src_lib",
"//tfjs-backend-cpu/src:tfjs-backend-cpu_lib",
"@npm//@types/node",
],
)

jasmine_node_test(
name = "platform_node_test",
deps = [
":platform_node_test_lib",
],
)

ts_library(
name = "worker_test_lib",
srcs = [
Expand Down
3 changes: 3 additions & 0 deletions tfjs-core/src/platforms/platform.ts
Original file line number Diff line number Diff line change
Expand Up @@ -48,4 +48,7 @@ export interface Platform {
decode(bytes: Uint8Array, encoding: string): string;

setTimeoutCustom?(functionRef: Function, delay: number): void;

isTypedArray(a: unknown): a is Float32Array|Int32Array|Uint8Array|
Uint8ClampedArray;
}
6 changes: 6 additions & 0 deletions tfjs-core/src/platforms/platform_browser.ts
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,12 @@ export class PlatformBrowser implements Platform {
}, true);
}
}

isTypedArray(a: unknown): a is Uint8Array | Float32Array | Int32Array
| Uint8ClampedArray {
return a instanceof Float32Array || a instanceof Int32Array ||
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
}
}

if (env().get('IS_BROWSER')) {
Expand Down
15 changes: 15 additions & 0 deletions tfjs-core/src/platforms/platform_browser_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -147,4 +147,19 @@ describeWithFlags('setTimeout', BROWSER_ENVS, () => {
env().platform.setTimeoutCustom(_testSetTimeoutCustom, 0);
}
});

it('isTypedArray returns false if not a typed array', () => {
const platform = new PlatformBrowser();
expect(platform.isTypedArray([1, 2, 3])).toBeFalse();
});

for (const typedArrayConstructor of [Float32Array, Int32Array, Uint8Array,
Uint8ClampedArray]) {
it(`isTypedArray returns true if it is a ${typedArrayConstructor.name}`,
() => {
const platform = new PlatformBrowser();
const array = new typedArrayConstructor([1,2,3]);
expect(platform.isTypedArray(array)).toBeTrue();
});
}
});
7 changes: 7 additions & 0 deletions tfjs-core/src/platforms/platform_node.ts
Original file line number Diff line number Diff line change
Expand Up @@ -79,6 +79,13 @@ export class PlatformNode implements Platform {
}
return new this.util.TextDecoder(encoding).decode(bytes);
}
isTypedArray(a: unknown): a is Float32Array | Int32Array | Uint8Array
| Uint8ClampedArray {
return this.util.types.isFloat32Array(a)
|| this.util.types.isInt32Array(a)
|| this.util.types.isUint8Array(a)
|| this.util.types.isUint8ClampedArray(a);
}
}

if (env().get('IS_NODE') && !env().get('IS_BROWSER')) {
Expand Down
33 changes: 31 additions & 2 deletions tfjs-core/src/platforms/platform_node_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,11 +16,11 @@
*/

import * as tf from '../index';
import {describeWithFlags, NODE_ENVS} from '../jasmine_util';
import * as platform_node from './platform_node';
import {PlatformNode} from './platform_node';
import * as vm from 'node:vm';

describeWithFlags('PlatformNode', NODE_ENVS, () => {
describe('PlatformNode', () => {
it('fetch should use global.fetch if defined', async () => {
const globalFetch = tf.env().global.fetch;

Expand Down Expand Up @@ -125,4 +125,33 @@ describeWithFlags('PlatformNode', NODE_ENVS, () => {
expect(s.length).toBe(6);
expect(s).toEqual('Здраво');
});

describe('isTypedArray', () => {
let platform: PlatformNode;
beforeEach(() => {
platform = new PlatformNode();
});

it('returns false if not a typed array', () => {
expect(platform.isTypedArray([1, 2, 3])).toBeFalse();
});

for (const typedArrayConstructor of [Float32Array, Int32Array, Uint8Array,
Uint8ClampedArray]) {
it(`returns true if it is a ${typedArrayConstructor.name}`,
() => {
const array = new typedArrayConstructor([1,2,3]);
expect(platform.isTypedArray(array)).toBeTrue();
});
}

it('works on values created in a new node context', async () => {
const array = await new Promise((resolve) => {
const code = `resolve(new Uint8Array([1, 2, 3]));`;
vm.runInNewContext(code, {resolve});
});

expect(platform.isTypedArray(array)).toBeTrue();
});
});
});
58 changes: 56 additions & 2 deletions tfjs-core/src/util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/

import {env} from './environment';
import {BackendValues, DataType, TensorLike, TypedArray} from './types';
import {BackendValues, DataType, RecursiveArray, TensorLike, TypedArray} from './types';
import * as base from './util_base';
export * from './util_base';
export * from './hash_util';
Expand Down Expand Up @@ -44,7 +44,7 @@ export function toTypedArray(a: TensorLike, dtype: DataType): TypedArray {
throw new Error('Cannot convert a string[] to a TypedArray');
}
if (Array.isArray(a)) {
a = base.flatten(a);
a = flatten(a);
}

if (env().getBool('DEBUG')) {
Expand Down Expand Up @@ -131,3 +131,57 @@ export function decodeString(bytes: Uint8Array, encoding = 'utf-8'): string {
encoding = encoding || 'utf-8';
return env().platform.decode(bytes, encoding);
}

export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array|
Uint8ClampedArray {
return env().platform.isTypedArray(a);
}

// NOTE: We explicitly type out what T extends instead of any so that
// util.flatten on a nested array of number doesn't try to infer T as a
// number[][], causing us to explicitly type util.flatten<number>().
/**
* Flattens an arbitrarily nested array.
*
* ```js
* const a = [[1, 2], [3, 4], [5, [6, [7]]]];
* const flat = tf.util.flatten(a);
* console.log(flat);
* ```
*
* @param arr The nested array to flatten.
* @param result The destination array which holds the elements.
* @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
* to false.
*
* @doc {heading: 'Util', namespace: 'util'}
*/
export function
flatten<T extends number|boolean|string|Promise<number>|TypedArray>(
arr: T|RecursiveArray<T>, result: T[] = [], skipTypedArray = false): T[] {
if (result == null) {
result = [];
}
if (typeof arr === 'boolean' || typeof arr === 'number' ||
typeof arr === 'string' || base.isPromise(arr) || arr == null ||
isTypedArray(arr) && skipTypedArray) {
result.push(arr as T);
} else if (Array.isArray(arr) || isTypedArray(arr)) {
for (let i = 0; i < arr.length; ++i) {
flatten(arr[i], result, skipTypedArray);
}
} else {
let maxIndex = -1;
for (const key of Object.keys(arr)) {
// 0 or positive integer.
if (/^([1-9]+[0-9]*|0)$/.test(key)) {
maxIndex = Math.max(maxIndex, Number(key));
}
}
for (let i = 0; i <= maxIndex; i++) {
// tslint:disable-next-line: no-unnecessary-type-assertion
flatten((arr as RecursiveArray<T>)[i], result, skipTypedArray);
}
}
return result;
}
57 changes: 1 addition & 56 deletions tfjs-core/src/util_base.ts
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
* =============================================================================
*/

import {BackendValues, DataType, DataTypeMap, FlatVector, NumericDataType, RecursiveArray, TensorLike, TypedArray, WebGLData, WebGPUData} from './types';
import {BackendValues, DataType, DataTypeMap, FlatVector, NumericDataType, TensorLike, TypedArray, WebGLData, WebGPUData} from './types';

/**
* Shuffles the array in-place using Fisher-Yates algorithm.
Expand Down Expand Up @@ -167,55 +167,6 @@ export function assertNonNull(a: TensorLike): void {
() => `The input to the tensor constructor must be a non-null value.`);
}

// NOTE: We explicitly type out what T extends instead of any so that
// util.flatten on a nested array of number doesn't try to infer T as a
// number[][], causing us to explicitly type util.flatten<number>().
/**
* Flattens an arbitrarily nested array.
*
* ```js
* const a = [[1, 2], [3, 4], [5, [6, [7]]]];
* const flat = tf.util.flatten(a);
* console.log(flat);
* ```
*
* @param arr The nested array to flatten.
* @param result The destination array which holds the elements.
* @param skipTypedArray If true, avoids flattening the typed arrays. Defaults
* to false.
*
* @doc {heading: 'Util', namespace: 'util'}
*/
export function
flatten<T extends number|boolean|string|Promise<number>|TypedArray>(
arr: T|RecursiveArray<T>, result: T[] = [], skipTypedArray = false): T[] {
if (result == null) {
result = [];
}
if (typeof arr === 'boolean' || typeof arr === 'number' ||
typeof arr === 'string' || isPromise(arr) || arr == null ||
isTypedArray(arr) && skipTypedArray) {
result.push(arr as T);
} else if (Array.isArray(arr) || isTypedArray(arr)) {
for (let i = 0; i < arr.length; ++i) {
flatten(arr[i], result, skipTypedArray);
}
} else {
let maxIndex = -1;
for (const key of Object.keys(arr)) {
// 0 or positive integer.
if (/^([1-9]+[0-9]*|0)$/.test(key)) {
maxIndex = Math.max(maxIndex, Number(key));
}
}
for (let i = 0; i <= maxIndex; i++) {
// tslint:disable-next-line: no-unnecessary-type-assertion
flatten((arr as RecursiveArray<T>)[i], result, skipTypedArray);
}
}
return result;
}

/**
* Returns the size (number of elements) of the tensor given its shape.
*
Expand Down Expand Up @@ -527,12 +478,6 @@ export function hasEncodingLoss(oldType: DataType, newType: DataType): boolean {
return true;
}

export function isTypedArray(a: {}): a is Float32Array|Int32Array|Uint8Array|
Uint8ClampedArray {
return a instanceof Float32Array || a instanceof Int32Array ||
a instanceof Uint8Array || a instanceof Uint8ClampedArray;
}

export function bytesPerElement(dtype: DataType): number {
if (dtype === 'float32' || dtype === 'int32') {
return 4;
Expand Down
8 changes: 4 additions & 4 deletions yarn.lock
Original file line number Diff line number Diff line change
Expand Up @@ -384,10 +384,10 @@
resolved "https://registry.yarnpkg.com/@types/node/-/node-10.17.60.tgz#35f3d6213daed95da7f0f73e75bcc6980e90597b"
integrity sha512-F0KIgDJfy2nA3zMLmWGKxcH2ZVEtCZXHHdOQs2gSaQ27+lNeEfGxzkIw90aXswATX7AZ33tahPbzy6KAfUreVw==

"@types/node@^12.7.5":
version "12.20.28"
resolved "https://registry.yarnpkg.com/@types/node/-/node-12.20.28.tgz#4b20048c6052b5f51a8d5e0d2acbf63d5a17e1e2"
integrity sha512-cBw8gzxUPYX+/5lugXIPksioBSbE42k0fZ39p+4yRzfYjN6++eq9kAPdlY9qm+MXyfbk9EmvCYAYRn380sF46w==
"@types/node@^18.11.15":
version "18.11.15"
resolved "https://registry.yarnpkg.com/@types/node/-/node-18.11.15.tgz#de0e1fbd2b22b962d45971431e2ae696643d3f5d"
integrity sha512-VkhBbVo2+2oozlkdHXLrb3zjsRkpdnaU2bXmX8Wgle3PUi569eLRaHGlgETQHR7lLL1w7GiG3h9SnePhxNDecw==

"@types/offscreencanvas@^2019.7.0":
version "2019.7.0"
Expand Down

0 comments on commit ae902e5

Please sign in to comment.