diff --git a/tfjs-converter/docs/supported_ops.md b/tfjs-converter/docs/supported_ops.md index e070eb28458..35e5122072c 100644 --- a/tfjs-converter/docs/supported_ops.md +++ b/tfjs-converter/docs/supported_ops.md @@ -146,6 +146,7 @@ |OneHot|oneHot| |Ones|ones| |OnesLike|onesLike| +|RandomStandardNormal|RandomStandardNormal| |RandomUniform|RandomUniform| |Range|range| |TruncatedNormal|truncatedNormal| diff --git a/tfjs-converter/python/tensorflowjs/op_list/creation.json b/tfjs-converter/python/tensorflowjs/op_list/creation.json index 9c97df9d0bb..9ad1996f6bc 100644 --- a/tfjs-converter/python/tensorflowjs/op_list/creation.json +++ b/tfjs-converter/python/tensorflowjs/op_list/creation.json @@ -129,6 +129,43 @@ } ] }, + { + "tfOpName": "RandomStandardNormal", + "category": "creation", + "inputs": [ + { + "start": 0, + "name": "shape", + "type": "number[]" + } + ], + "attrs": [ + { + "tfName": "seed", + "name": "seed", + "type": "number", + "defaultValue": 0 + }, + { + "tfName": "seed2", + "name": "seed2", + "type": "number", + "defaultValue": 0, + "notSupported": true + }, + { + "tfName": "dtype", + "name": "dtype", + "type": "dtype" + }, + { + "tfName": "T", + "name": "T", + "type": "number", + "notSupported": true + } + ] + }, { "tfOpName": "RandomUniform", "category": "creation", diff --git a/tfjs-converter/src/operations/executors/creation_executor.ts b/tfjs-converter/src/operations/executors/creation_executor.ts index be6e9ed2428..596384b6f42 100644 --- a/tfjs-converter/src/operations/executors/creation_executor.ts +++ b/tfjs-converter/src/operations/executors/creation_executor.ts @@ -75,6 +75,13 @@ export const executeOp: InternalOpExecutor = return [tfOps.onesLike( getParamValue('x', node, tensorMap, context) as Tensor)]; } + case 'RandomStandardNormal': { + return [tfOps.randomStandardNormal( + getParamValue('shape', node, tensorMap, context) as number[], + getParamValue('dtype', node, tensorMap, context) as 'float32' | + 'int32', + getParamValue('seed', node, tensorMap, context) as number)]; + } case 'RandomUniform': { return [tfOps.randomUniform( // tslint:disable-next-line:no-any diff --git a/tfjs-converter/src/operations/executors/creation_executor_test.ts b/tfjs-converter/src/operations/executors/creation_executor_test.ts index cfe30817011..7e6b3d84667 100644 --- a/tfjs-converter/src/operations/executors/creation_executor_test.ts +++ b/tfjs-converter/src/operations/executors/creation_executor_test.ts @@ -173,6 +173,30 @@ describe('creation', () => { expect(validateParam(node, creation.json)).toBeTruthy(); }); }); + describe('RandomStandardNormal', () => { + it('should call tfOps.randomStandardNormal', () => { + spyOn(tfOps, 'randomStandardNormal'); + node.op = 'RandomStandardNormal'; + node.inputParams['shape'] = createNumericArrayAttrFromIndex(0); + node.inputNames = ['input1']; + node.attrParams['dtype'] = createDtypeAttr('float32'); + node.attrParams['seed'] = createNumberAttr(0); + + executeOp(node, {input1}, context); + + expect(tfOps.randomStandardNormal) + .toHaveBeenCalledWith([1, 2, 3], 'float32', 0); + }); + it('should match json def', () => { + node.op = 'RandomStandardNormal'; + node.inputParams['shape'] = createNumericArrayAttrFromIndex(0); + node.inputNames = ['input1']; + node.attrParams['dtype'] = createDtypeAttr('float32'); + node.attrParams['seed'] = createNumberAttr(0); + + expect(validateParam(node, creation.json)).toBeTruthy(); + }); + }); describe('RandomUniform', () => { it('should call tfOps.randomUniform', () => { spyOn(tfOps, 'randomUniform'); diff --git a/tfjs-core/src/ops/ops.ts b/tfjs-core/src/ops/ops.ts index faa7e609eea..2d93f720250 100644 --- a/tfjs-core/src/ops/ops.ts +++ b/tfjs-core/src/ops/ops.ts @@ -138,6 +138,7 @@ export {prod} from './prod'; export {rand} from './rand'; export {randomGamma} from './random_gamma'; export {randomNormal} from './random_normal'; +export {randomStandardNormal} from './random_standard_normal'; export {randomUniform} from './random_uniform'; export {range} from './range'; export {real} from './real'; diff --git a/tfjs-core/src/ops/random_standard_normal.ts b/tfjs-core/src/ops/random_standard_normal.ts new file mode 100644 index 00000000000..adf31bfbd52 --- /dev/null +++ b/tfjs-core/src/ops/random_standard_normal.ts @@ -0,0 +1,47 @@ +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import {Tensor} from '../tensor'; +import {DataType, Rank, ShapeMap} from '../types'; + +import {op} from './operation'; +import {randomNormal} from './random_normal'; + +/** + * Creates a `tf.Tensor` with values sampled from a normal distribution. + * + * The generated values will have mean 0 and standard deviation 1. + * + * ```js + * tf.randomStandardNormal([2, 2]).print(); + * ``` + * + * @param shape An array of integers defining the output tensor shape. + * @param dtype The data type of the output. + * @param seed The seed for the random number generator. + * + * @doc {heading: 'Tensors', subheading: 'Random'} + */ +function randomStandardNormal_( + shape: ShapeMap[R], dtype?: 'float32'|'int32', seed?: number): Tensor { + if (dtype != null && (dtype as DataType) === 'bool') { + throw new Error(`Unsupported data type ${dtype}`); + } + return randomNormal(shape, 0, 1, dtype, seed); +} + +export const randomStandardNormal = op({randomStandardNormal_}); diff --git a/tfjs-core/src/ops/random_standard_normal_test.ts b/tfjs-core/src/ops/random_standard_normal_test.ts new file mode 100644 index 00000000000..189a783f2d4 --- /dev/null +++ b/tfjs-core/src/ops/random_standard_normal_test.ts @@ -0,0 +1,146 @@ + +/** + * @license + * Copyright 2022 Google LLC. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + * ============================================================================= + */ + +import * as tf from '../index'; +import {ALL_ENVS, describeWithFlags} from '../jasmine_util'; + +import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util'; + +describeWithFlags('randomStandardNormal', ALL_ENVS, () => { + const SEED = 42; + const EPSILON = 0.05; + + it('should return a float32 1D of random standard normal values', + async () => { + const SAMPLES = 10000; + + // Ensure defaults to float32. + let result = tf.randomStandardNormal([SAMPLES], null, SEED); + expect(result.dtype).toBe('float32'); + expect(result.shape).toEqual([SAMPLES]); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + + result = tf.randomStandardNormal([SAMPLES], 'float32', SEED); + expect(result.dtype).toBe('float32'); + expect(result.shape).toEqual([SAMPLES]); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + }); + + it('should return a int32 1D of random standard normal values', async () => { + const SAMPLES = 10000; + const result = tf.randomStandardNormal([SAMPLES], 'int32', SEED); + expect(result.dtype).toBe('int32'); + expect(result.shape).toEqual([SAMPLES]); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + }); + + it('should return a float32 2D of random standard normal values', + async () => { + const SAMPLES = 100; + + // Ensure defaults to float32. + let result = tf.randomStandardNormal([SAMPLES, SAMPLES], null, SEED); + expect(result.dtype).toBe('float32'); + expect(result.shape).toEqual([SAMPLES, SAMPLES]); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + + result = tf.randomStandardNormal([SAMPLES, SAMPLES], 'float32', SEED); + expect(result.dtype).toBe('float32'); + expect(result.shape).toEqual([SAMPLES, SAMPLES]); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + }); + + it('should return a int32 2D of random standard normal values', async () => { + const SAMPLES = 100; + const result = tf.randomStandardNormal([SAMPLES, SAMPLES], 'int32', SEED); + expect(result.dtype).toBe('int32'); + expect(result.shape).toEqual([SAMPLES, SAMPLES]); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + }); + + it('should return a float32 3D of random standard normal values', + async () => { + const SAMPLES_SHAPE = [20, 20, 20]; + + // Ensure defaults to float32. + let result = tf.randomStandardNormal(SAMPLES_SHAPE, null, SEED); + expect(result.dtype).toBe('float32'); + expect(result.shape).toEqual(SAMPLES_SHAPE); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + + result = tf.randomStandardNormal(SAMPLES_SHAPE, 'float32', SEED); + expect(result.dtype).toBe('float32'); + expect(result.shape).toEqual(SAMPLES_SHAPE); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + }); + + it('should return a int32 3D of random standard normal values', async () => { + const SAMPLES_SHAPE = [20, 20, 20]; + const result = tf.randomStandardNormal(SAMPLES_SHAPE, 'int32', SEED); + expect(result.dtype).toBe('int32'); + expect(result.shape).toEqual(SAMPLES_SHAPE); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + }); + + it('should return a float32 4D of random standard normal values', + async () => { + const SAMPLES_SHAPE = [10, 10, 10, 10]; + + // Ensure defaults to float32. + let result = tf.randomStandardNormal(SAMPLES_SHAPE, null, SEED); + expect(result.dtype).toBe('float32'); + expect(result.shape).toEqual(SAMPLES_SHAPE); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + + result = tf.randomStandardNormal(SAMPLES_SHAPE, 'float32', SEED); + expect(result.dtype).toBe('float32'); + expect(result.shape).toEqual(SAMPLES_SHAPE); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + }); + + it('should return a int32 4D of random standard normal values', async () => { + const SAMPLES_SHAPE = [10, 10, 10, 10]; + + const result = tf.randomStandardNormal(SAMPLES_SHAPE, 'int32', SEED); + expect(result.dtype).toBe('int32'); + expect(result.shape).toEqual(SAMPLES_SHAPE); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + }); + + it('should return a int32 5D of random standard normal values', async () => { + const SAMPLES_SHAPE = [10, 10, 10, 10, 10]; + + const result = tf.randomStandardNormal(SAMPLES_SHAPE, 'int32', SEED); + expect(result.dtype).toBe('int32'); + expect(result.shape).toEqual(SAMPLES_SHAPE); + jarqueBeraNormalityTest(await result.data()); + expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON); + }); +});