Skip to content

Commit

Permalink
Add randomStandardNormal op (#6533)
Browse files Browse the repository at this point in the history
Fixes #4156

Co-authored-by: Matthew Soulanille <msoulanille@google.com>
  • Loading branch information
kon72 and mattsoulanille authored Jun 21, 2022
1 parent 2d5755c commit 9206578
Show file tree
Hide file tree
Showing 7 changed files with 263 additions and 0 deletions.
1 change: 1 addition & 0 deletions tfjs-converter/docs/supported_ops.md
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,7 @@
|OneHot|oneHot|
|Ones|ones|
|OnesLike|onesLike|
|RandomStandardNormal|RandomStandardNormal|
|RandomUniform|RandomUniform|
|Range|range|
|TruncatedNormal|truncatedNormal|
Expand Down
37 changes: 37 additions & 0 deletions tfjs-converter/python/tensorflowjs/op_list/creation.json
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,43 @@
}
]
},
{
"tfOpName": "RandomStandardNormal",
"category": "creation",
"inputs": [
{
"start": 0,
"name": "shape",
"type": "number[]"
}
],
"attrs": [
{
"tfName": "seed",
"name": "seed",
"type": "number",
"defaultValue": 0
},
{
"tfName": "seed2",
"name": "seed2",
"type": "number",
"defaultValue": 0,
"notSupported": true
},
{
"tfName": "dtype",
"name": "dtype",
"type": "dtype"
},
{
"tfName": "T",
"name": "T",
"type": "number",
"notSupported": true
}
]
},
{
"tfOpName": "RandomUniform",
"category": "creation",
Expand Down
7 changes: 7 additions & 0 deletions tfjs-converter/src/operations/executors/creation_executor.ts
Original file line number Diff line number Diff line change
Expand Up @@ -75,6 +75,13 @@ export const executeOp: InternalOpExecutor =
return [tfOps.onesLike(
getParamValue('x', node, tensorMap, context) as Tensor)];
}
case 'RandomStandardNormal': {
return [tfOps.randomStandardNormal(
getParamValue('shape', node, tensorMap, context) as number[],
getParamValue('dtype', node, tensorMap, context) as 'float32' |
'int32',
getParamValue('seed', node, tensorMap, context) as number)];
}
case 'RandomUniform': {
return [tfOps.randomUniform(
// tslint:disable-next-line:no-any
Expand Down
24 changes: 24 additions & 0 deletions tfjs-converter/src/operations/executors/creation_executor_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,30 @@ describe('creation', () => {
expect(validateParam(node, creation.json)).toBeTruthy();
});
});
describe('RandomStandardNormal', () => {
it('should call tfOps.randomStandardNormal', () => {
spyOn(tfOps, 'randomStandardNormal');
node.op = 'RandomStandardNormal';
node.inputParams['shape'] = createNumericArrayAttrFromIndex(0);
node.inputNames = ['input1'];
node.attrParams['dtype'] = createDtypeAttr('float32');
node.attrParams['seed'] = createNumberAttr(0);

executeOp(node, {input1}, context);

expect(tfOps.randomStandardNormal)
.toHaveBeenCalledWith([1, 2, 3], 'float32', 0);
});
it('should match json def', () => {
node.op = 'RandomStandardNormal';
node.inputParams['shape'] = createNumericArrayAttrFromIndex(0);
node.inputNames = ['input1'];
node.attrParams['dtype'] = createDtypeAttr('float32');
node.attrParams['seed'] = createNumberAttr(0);

expect(validateParam(node, creation.json)).toBeTruthy();
});
});
describe('RandomUniform', () => {
it('should call tfOps.randomUniform', () => {
spyOn(tfOps, 'randomUniform');
Expand Down
1 change: 1 addition & 0 deletions tfjs-core/src/ops/ops.ts
Original file line number Diff line number Diff line change
Expand Up @@ -138,6 +138,7 @@ export {prod} from './prod';
export {rand} from './rand';
export {randomGamma} from './random_gamma';
export {randomNormal} from './random_normal';
export {randomStandardNormal} from './random_standard_normal';
export {randomUniform} from './random_uniform';
export {range} from './range';
export {real} from './real';
Expand Down
47 changes: 47 additions & 0 deletions tfjs-core/src/ops/random_standard_normal.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/**
* @license
* Copyright 2022 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import {Tensor} from '../tensor';
import {DataType, Rank, ShapeMap} from '../types';

import {op} from './operation';
import {randomNormal} from './random_normal';

/**
* Creates a `tf.Tensor` with values sampled from a normal distribution.
*
* The generated values will have mean 0 and standard deviation 1.
*
* ```js
* tf.randomStandardNormal([2, 2]).print();
* ```
*
* @param shape An array of integers defining the output tensor shape.
* @param dtype The data type of the output.
* @param seed The seed for the random number generator.
*
* @doc {heading: 'Tensors', subheading: 'Random'}
*/
function randomStandardNormal_<R extends Rank>(
shape: ShapeMap[R], dtype?: 'float32'|'int32', seed?: number): Tensor<R> {
if (dtype != null && (dtype as DataType) === 'bool') {
throw new Error(`Unsupported data type ${dtype}`);
}
return randomNormal(shape, 0, 1, dtype, seed);
}

export const randomStandardNormal = op({randomStandardNormal_});
146 changes: 146 additions & 0 deletions tfjs-core/src/ops/random_standard_normal_test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,146 @@

/**
* @license
* Copyright 2022 Google LLC. All Rights Reserved.
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
* =============================================================================
*/

import * as tf from '../index';
import {ALL_ENVS, describeWithFlags} from '../jasmine_util';

import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util';

describeWithFlags('randomStandardNormal', ALL_ENVS, () => {
const SEED = 42;
const EPSILON = 0.05;

it('should return a float32 1D of random standard normal values',
async () => {
const SAMPLES = 10000;

// Ensure defaults to float32.
let result = tf.randomStandardNormal([SAMPLES], null, SEED);
expect(result.dtype).toBe('float32');
expect(result.shape).toEqual([SAMPLES]);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);

result = tf.randomStandardNormal([SAMPLES], 'float32', SEED);
expect(result.dtype).toBe('float32');
expect(result.shape).toEqual([SAMPLES]);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);
});

it('should return a int32 1D of random standard normal values', async () => {
const SAMPLES = 10000;
const result = tf.randomStandardNormal([SAMPLES], 'int32', SEED);
expect(result.dtype).toBe('int32');
expect(result.shape).toEqual([SAMPLES]);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);
});

it('should return a float32 2D of random standard normal values',
async () => {
const SAMPLES = 100;

// Ensure defaults to float32.
let result = tf.randomStandardNormal([SAMPLES, SAMPLES], null, SEED);
expect(result.dtype).toBe('float32');
expect(result.shape).toEqual([SAMPLES, SAMPLES]);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);

result = tf.randomStandardNormal([SAMPLES, SAMPLES], 'float32', SEED);
expect(result.dtype).toBe('float32');
expect(result.shape).toEqual([SAMPLES, SAMPLES]);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);
});

it('should return a int32 2D of random standard normal values', async () => {
const SAMPLES = 100;
const result = tf.randomStandardNormal([SAMPLES, SAMPLES], 'int32', SEED);
expect(result.dtype).toBe('int32');
expect(result.shape).toEqual([SAMPLES, SAMPLES]);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);
});

it('should return a float32 3D of random standard normal values',
async () => {
const SAMPLES_SHAPE = [20, 20, 20];

// Ensure defaults to float32.
let result = tf.randomStandardNormal(SAMPLES_SHAPE, null, SEED);
expect(result.dtype).toBe('float32');
expect(result.shape).toEqual(SAMPLES_SHAPE);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);

result = tf.randomStandardNormal(SAMPLES_SHAPE, 'float32', SEED);
expect(result.dtype).toBe('float32');
expect(result.shape).toEqual(SAMPLES_SHAPE);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);
});

it('should return a int32 3D of random standard normal values', async () => {
const SAMPLES_SHAPE = [20, 20, 20];
const result = tf.randomStandardNormal(SAMPLES_SHAPE, 'int32', SEED);
expect(result.dtype).toBe('int32');
expect(result.shape).toEqual(SAMPLES_SHAPE);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);
});

it('should return a float32 4D of random standard normal values',
async () => {
const SAMPLES_SHAPE = [10, 10, 10, 10];

// Ensure defaults to float32.
let result = tf.randomStandardNormal(SAMPLES_SHAPE, null, SEED);
expect(result.dtype).toBe('float32');
expect(result.shape).toEqual(SAMPLES_SHAPE);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);

result = tf.randomStandardNormal(SAMPLES_SHAPE, 'float32', SEED);
expect(result.dtype).toBe('float32');
expect(result.shape).toEqual(SAMPLES_SHAPE);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);
});

it('should return a int32 4D of random standard normal values', async () => {
const SAMPLES_SHAPE = [10, 10, 10, 10];

const result = tf.randomStandardNormal(SAMPLES_SHAPE, 'int32', SEED);
expect(result.dtype).toBe('int32');
expect(result.shape).toEqual(SAMPLES_SHAPE);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);
});

it('should return a int32 5D of random standard normal values', async () => {
const SAMPLES_SHAPE = [10, 10, 10, 10, 10];

const result = tf.randomStandardNormal(SAMPLES_SHAPE, 'int32', SEED);
expect(result.dtype).toBe('int32');
expect(result.shape).toEqual(SAMPLES_SHAPE);
jarqueBeraNormalityTest(await result.data());
expectArrayInMeanStdRange(await result.data(), 0, 1, EPSILON);
});
});

0 comments on commit 9206578

Please sign in to comment.