From 555fb9a449f7eb25da89b0aba23737afbbd3b957 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Fri, 28 Sep 2018 17:19:58 -0400 Subject: [PATCH 01/12] WIP --- src/exports_layers.ts | 15 ++- src/layers/merge.ts | 212 ++++++++++++++++++++++++++++++++++++++- src/layers/merge_test.ts | 24 +++++ 3 files changed, 249 insertions(+), 2 deletions(-) diff --git a/src/exports_layers.ts b/src/exports_layers.ts index ff6be2625..139bcf7ee 100644 --- a/src/exports_layers.ts +++ b/src/exports_layers.ts @@ -16,7 +16,7 @@ import {Conv1D, Conv2D, Conv2DTranspose, ConvLayerConfig, Cropping2D, Cropping2D import {DepthwiseConv2D, DepthwiseConv2DLayerConfig} from './layers/convolutional_depthwise'; import {Activation, ActivationLayerConfig, Dense, DenseLayerConfig, Dropout, DropoutLayerConfig, Flatten, Permute, PermuteLayerConfig, RepeatVector, RepeatVectorLayerConfig, Reshape, ReshapeLayerConfig} from './layers/core'; import {Embedding, EmbeddingLayerConfig} from './layers/embeddings'; -import {Add, Average, Concatenate, ConcatenateLayerConfig, Maximum, Minimum, Multiply} from './layers/merge'; +import {Add, Average, Concatenate, ConcatenateLayerConfig, Dot, DotLayerConfig, Maximum, Minimum, Multiply} from './layers/merge'; import {BatchNormalization, BatchNormalizationLayerConfig} from './layers/normalization'; import {ZeroPadding2D, ZeroPadding2DLayerConfig} from './layers/padding'; import {AveragePooling1D, AveragePooling2D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerConfig, MaxPooling1D, MaxPooling2D, Pooling1DLayerConfig, Pooling2DLayerConfig} from './layers/pooling'; @@ -392,6 +392,19 @@ export function multiply(config?: LayerConfig): Layer { return new Multiply(config); } +/** + * @doc { + * heading: 'Layers', + * subheading: 'Merge', + * namespace: 'layers', + * useDocsFrom: 'Dot', + * configParamIndices: [0] + * } + */ +export function dot(config?: DotLayerConfig): Layer { + return new Dot(config); +} + // Normalization Layers. /** diff --git a/src/layers/merge.ts b/src/layers/merge.ts index 6eea618b9..545eee506 100644 --- a/src/layers/merge.ts +++ b/src/layers/merge.ts @@ -19,6 +19,7 @@ import {getScalar} from '../backend/state'; import * as K from '../backend/tfjs_backend'; import {Layer, LayerConfig, SymbolicTensor} from '../engine/topology'; import {NotImplementedError, ValueError} from '../errors'; +import {l2Normalize} from '../losses'; import {Kwargs, Shape} from '../types'; import * as generic_utils from '../utils/generic_utils'; import * as mathUtils from '../utils/math_utils'; @@ -872,6 +873,215 @@ export function concatenate(config?: SymbolicTensor[]|Tensor[]| } } -// TODO(cais): Add class Dot. +export interface DotLayerConfig extends LayerConfig { + /** + * Axis or axes along which the dot product will be taken. + * + * Integer or an Array of integers. + */ + axes: number|number[]; + + /** + * Whether to L2-normalize samples along the dot product axis + * before taking the dot product. + * + * If set to `true`, the output of the dot product isthe cosine + * proximity between the two samples. + */ + normalize?: boolean; +} + +function normalizeAxis(axis: number, dim: number): number { + while (axis < 0) { + axis += dim; + } + return axis; +} + +function batchDot(x: Tensor, y: Tensor, axes: number|number[]): Tensor { + if (typeof axes === 'number') { + axes = [axes, axes]; + } + + if (x.dtype === 'complex64' || y.dtype === 'complex64') { + throw new NotImplementedError( + 'batchDot is not implemented for complex64-type Tensors yet.'); + } + + const xNDim = x.shape.length; + const yNDim = y.shape.length; + if (axes == null) { + axes = [xNDim - 1, yNDim - 2]; + } + let diff: number; + if (xNDim > yNDim) { + diff = xNDim - yNDim; + const diffShape: Shape = []; + for (let i = 0; i < diff; ++i) { + diffShape.push(1); + } + y = y.reshape(y.shape.concat(diffShape)); + } else if (yNDim > xNDim) { + diff = yNDim - xNDim; + const diffShape: Shape = []; + for (let i = 0; i < diff; ++i) { + diffShape.push(1); + } + x = x.reshape(x.shape.concat(diffShape)); + } else { + diff = 0; + } + + let out: Tensor; + if (x.shape.length === 2 && y.shape.length === 2) { + if (axes[0] === axes[1]) { + out = x.mulStrict(y).sum(axes[0]); + } else { + out = x.transpose([1, 0]).mulStrict(y).sum(axes[1]); + } + } else { + const adjX = axes[0] === x.shape.length - 1 ? null : true; + const adjY = axes[1] === y.shape.length - 1 ? true : null; + out = x.matMul(y, adjX, adjY); + } + + if (diff > 0) { + let idx: number; + if (xNDim > yNDim) { + idx = xNDim + yNDim - 3; + } else { + idx = xNDim - 1; + } + const squeezeAxes: number[] = []; + for (let i = idx; i < idx + diff; ++i) { + squeezeAxes.push(i); + } + out = out.squeeze(squeezeAxes); + } + if (out.shape.length === 1) { + out = out.expandDims(1); + } + return out; +} + +/** + * Layer that computes a dot product between samples in two tensors. + * + * E.g., if applied to a list of two tensors `a` and `b` both of shape + * `[batchSize, n]`, the output will be a tensor of shape `[batchSize, 1]`, + * where each entry at index `[i, 0]` will be the dot product between + * `a[i, :]` and `b[i, :]`. + * + * TODO(cais): Add code snippet. + * + */ +export class Dot extends Merge { + static className = 'Add'; + + private axes: number|number[]; + private normalize: boolean; + + constructor(config: DotLayerConfig) { + super(config); + this.axes = config.axes; + this.normalize = config.normalize == null ? false : config.normalize; + this.supportsMasking = true; + this.reshapeRequired = false; + } + + build(inputShape: Shape|Shape[]): void { + tfc.util.assert( + Array.isArray(inputShape) && inputShape.length === 2 && + Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), + 'A `Dot` layer should be called on a list of exactly 2 inputs.'); + const shape1 = inputShape[0] as Shape; + const shape2 = inputShape[1] as Shape; + let axes: number[]; // TODO(cais): Refactor into a function? + if (!Array.isArray(this.axes)) { + // `this.axes` is a single integer. + axes = [ + normalizeAxis(this.axes, shape1.length), + normalizeAxis(this.axes, shape2.length) + ]; + // `this.axes` is an Array of integers. + } else { + axes = this.axes; + } + if (shape1[axes[0]] !== shape2[axes[1]]) { + throw new ValueError( + `Dimension incompability: ${shape1[axes[0]]} !== ${shape2[axes[1]]}`); + } + } + + protected mergeFunction(inputs: Tensor[]): Tensor { + if (inputs.length !== 2) { + throw new ValueError( + 'A `Dot` layer must be called on exactly 2 inputs, ' + + `but received ${inputs.length} input(s).`); + } + + let x1 = inputs[0]; + let x2 = inputs[1]; + let axes: number[]; + if (!Array.isArray(this.axes)) { + axes = [ + normalizeAxis(this.axes, x1.shape.length), + normalizeAxis(this.axes, x2.shape.length) + ]; + } else { + axes = []; + for (let i = 0; i < this.axes.length; ++i) { + axes.push(normalizeAxis(this.axes[i], inputs[i].shape.length)); + } + } + if (this.normalize) { + x1 = l2Normalize(x1, axes[0]); + x2 = l2Normalize(x2, axes[1]); + } + return batchDot(x1, x2, axes); + } + + computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] { + tfc.util.assert( + Array.isArray(inputShape) && inputShape.length === 2 && + Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), + 'A `Dot` layer should be called on a list of exactly 2 inputs.'); + const shape1 = inputShape[0] as Shape; + const shape2 = inputShape[1] as Shape; + console.log('shape1:', shape1); // DEBUG + console.log('shape2:', shape2); // DEBUG + let axes: number[]; // TODO(cais): Refactor into a function? + if (!Array.isArray(this.axes)) { + // `this.axes` is a single integer. + axes = [ + normalizeAxis(this.axes, shape1.length), + normalizeAxis(this.axes, shape2.length) + ]; + // `this.axes` is an Array of integers. + } else { + axes = this.axes; + } + shape1.splice(axes[0], 1); + shape2.splice(axes[1], 1); + shape2.splice(0, 1); + const outputShape = shape1.concat(shape2); + if (outputShape.length === 1) { + outputShape.push(1); + } + return outputShape; + } + + // TODO(cais): Implement computeMask(); + + getConfig(): serialization.ConfigDict { + const config: serialization.ConfigDict = { + 'axes': this.axes, + 'normalize': this.normalize + }; + const baseConfig = super.getConfig(); + Object.assign(config, baseConfig); + return config; + } +} // TODO(cais): Add functional interfaces for the merge layers. diff --git a/src/layers/merge_test.ts b/src/layers/merge_test.ts index 91f9dc11e..4d7fec9c4 100644 --- a/src/layers/merge_test.ts +++ b/src/layers/merge_test.ts @@ -551,3 +551,27 @@ describeMathCPU('Deserialize Merge Layers', () => { expect(model.outputs[0].shape).toEqual([null, 8]); }); }); + +describeMathCPU('Dot-Layer: Symbolic', () => { + it('2D x 2D', () => { + const x1 = new tfl.SymbolicTensor('float32', [null, 8], null, [], null); + const x2 = new tfl.SymbolicTensor('float32', [null, 8], null, [], null); + const y1 = tfl.layers.dot({axes: -1}).apply([x1, x2]) as tfl.SymbolicTensor; + expect(y1.shape).toEqual([null, 1]); + const y2 = tfl.layers.dot({axes: 1}).apply([x1, x2]) as tfl.SymbolicTensor; + expect(y2.shape).toEqual([null, 1]); + }); + + it('3D x 3D', () => { + const x1 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + const x2 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + const y1 = tfl.layers.dot({axes: -1}).apply([x1, x2]) as tfl.SymbolicTensor; + expect(y1.shape).toEqual([null, 2, 2]); + console.log('===='); // DEBUG + const y2 = tfl.layers.dot({axes: 2}).apply([x1, x2]) as tfl.SymbolicTensor; + console.log('y2.shape:', y2.shape); // DEBUG + expect(y2.shape).toEqual([null, 2, 2]); // TODO(cais): Fix test. + }); + + // TODO(cais): Cover incorrect number of inputs. +}); From 6a4f32652edc4c6eedd5052ecfbca364d4347404 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Sat, 29 Sep 2018 00:21:13 -0400 Subject: [PATCH 02/12] WIP --- src/layers/merge.ts | 7 ++- src/layers/merge_test.ts | 113 ++++++++++++++++++++++++++++++++++++--- 2 files changed, 109 insertions(+), 11 deletions(-) diff --git a/src/layers/merge.ts b/src/layers/merge.ts index 545eee506..fc3c6be7e 100644 --- a/src/layers/merge.ts +++ b/src/layers/merge.ts @@ -1009,7 +1009,8 @@ export class Dot extends Merge { } if (shape1[axes[0]] !== shape2[axes[1]]) { throw new ValueError( - `Dimension incompability: ${shape1[axes[0]]} !== ${shape2[axes[1]]}`); + `Dimension incompatibility: ` + + `${shape1[axes[0]]} !== ${shape2[axes[1]]}`); } } @@ -1048,8 +1049,6 @@ export class Dot extends Merge { 'A `Dot` layer should be called on a list of exactly 2 inputs.'); const shape1 = inputShape[0] as Shape; const shape2 = inputShape[1] as Shape; - console.log('shape1:', shape1); // DEBUG - console.log('shape2:', shape2); // DEBUG let axes: number[]; // TODO(cais): Refactor into a function? if (!Array.isArray(this.axes)) { // `this.axes` is a single integer. @@ -1057,8 +1056,8 @@ export class Dot extends Merge { normalizeAxis(this.axes, shape1.length), normalizeAxis(this.axes, shape2.length) ]; - // `this.axes` is an Array of integers. } else { + // `this.axes` is an Array of integers. axes = this.axes; } shape1.splice(axes[0], 1); diff --git a/src/layers/merge_test.ts b/src/layers/merge_test.ts index 4d7fec9c4..7ddc021c5 100644 --- a/src/layers/merge_test.ts +++ b/src/layers/merge_test.ts @@ -553,6 +553,21 @@ describeMathCPU('Deserialize Merge Layers', () => { }); describeMathCPU('Dot-Layer: Symbolic', () => { + // Example refernce Python Keras code: + // + // ```py + // import keras + // + // x1 = keras.Input(shape=[3, 4]) + // x2 = keras.Input(shape=[3]) + // dot_layer = keras.layers.Dot(1) + // y = dot_layer([x1, x2]) + // + // print(x1.shape) + // print(x2.shape) + // print(y.shape) + // ``` + it('2D x 2D', () => { const x1 = new tfl.SymbolicTensor('float32', [null, 8], null, [], null); const x2 = new tfl.SymbolicTensor('float32', [null, 8], null, [], null); @@ -562,16 +577,100 @@ describeMathCPU('Dot-Layer: Symbolic', () => { expect(y2.shape).toEqual([null, 1]); }); - it('3D x 3D', () => { + it('3D x 3D, axes = -1', () => { + const x1 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + const x2 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + const y = tfl.layers.dot({axes: -1}).apply([x1, x2]) as tfl.SymbolicTensor; + expect(y.shape).toEqual([null, 2, 2]); + }); + + it('3D x 3D, axes = 1', () => { + const x1 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + const x2 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + const y2 = tfl.layers.dot({axes: 1}).apply([x1, x2]) as tfl.SymbolicTensor; + expect(y2.shape).toEqual([null, 3, 3]); + }); + + it('3D x 3D, axes = 2', () => { const x1 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); const x2 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); - const y1 = tfl.layers.dot({axes: -1}).apply([x1, x2]) as tfl.SymbolicTensor; - expect(y1.shape).toEqual([null, 2, 2]); - console.log('===='); // DEBUG const y2 = tfl.layers.dot({axes: 2}).apply([x1, x2]) as tfl.SymbolicTensor; - console.log('y2.shape:', y2.shape); // DEBUG - expect(y2.shape).toEqual([null, 2, 2]); // TODO(cais): Fix test. + expect(y2.shape).toEqual([null, 2, 2]); }); - // TODO(cais): Cover incorrect number of inputs. + it('2D x 3D, axes = -1', () => { + const x1 = new tfl.SymbolicTensor('float32', [null, 3], null, [], null); + const x2 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + const y2 = tfl.layers.dot({axes: -1}).apply([x1, x2]) as tfl.SymbolicTensor; + expect(y2.shape).toEqual([null, 2]); + }); + + it('2D x 3D, axes = 1', () => { + const x1 = new tfl.SymbolicTensor('float32', [null, 3], null, [], null); + const x2 = new tfl.SymbolicTensor('float32', [null, 3, 4], null, [], null); + const y2 = tfl.layers.dot({axes: 1}).apply([x1, x2]) as tfl.SymbolicTensor; + expect(y2.shape).toEqual([null, 4]); + }); + + it('3D x 2D, axes = -1', () => { + const x1 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + const x2 = new tfl.SymbolicTensor('float32', [null, 3], null, [], null); + const y2 = tfl.layers.dot({axes: -1}).apply([x1, x2]) as tfl.SymbolicTensor; + expect(y2.shape).toEqual([null, 2]); + }); + + it('3D x 2D, axes = -1', () => { + const x1 = new tfl.SymbolicTensor('float32', [null, 3, 4], null, [], null); + const x2 = new tfl.SymbolicTensor('float32', [null, 3], null, [], null); + const y2 = tfl.layers.dot({axes: 1}).apply([x1, x2]) as tfl.SymbolicTensor; + expect(y2.shape).toEqual([null, 4]); + }); + + it('4D x 4D, axes = -1', () => { + const x1 = new tfl.SymbolicTensor( + 'float32', [null, 2, 3, 4], null, [], null); + const x2 = new tfl.SymbolicTensor( + 'float32', [null, 2, 3, 4], null, [], null); + const y = tfl.layers.dot({axes: -1}).apply([x1, x2]) as tfl.SymbolicTensor; + expect(y.shape).toEqual([null, 2, 3, 2, 3]); + }); + + it('Dimension mismatch leads to error', () => { + const x1 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + const x2 = new tfl.SymbolicTensor('float32', [null, 4], null, [], null); + expect(() => tfl.layers.dot({axes: -1}).apply([x1, x2])) + .toThrowError('Dimension incompatibility: 3 !== 4'); + }); + + it('Incorrect number of inputs leads to error', () => { + const x1 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + const x2 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + const x3 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); + expect(() => tfl.layers.dot({axes: -1}).apply([x1])) + .toThrowError(/should be called on a list of exactly 2 inputs/); + expect(() => tfl.layers.dot({axes: -1}).apply(x1)) + .toThrowError(/should be called on a list of exactly 2 inputs/); + expect(() => tfl.layers.dot({axes: -1}).apply([x1, x2, x3])) + .toThrowError(/should be called on a list of exactly 2 inputs/); + }); + + it('Serialization round trip', () => { + const layer = tfl.layers.dot({axes: -1, normalize: true}); + const pythonicConfig = convertTsToPythonic(layer.getConfig()); + // tslint:disable-next-line:no-any + const tsConfig = convertPythonicToTs(pythonicConfig) as any; + const layerPrime = tfl.layers.dot(tsConfig); + expect(layerPrime.getConfig().axes).toEqual(-1); + expect(layerPrime.getConfig().normalize).toEqual(true); + }); }); + +describeMathCPUAndGPU('Dot-Layer: Tensor', () => { + it('2D x 2D, axis = -1', () => { + const x1 = tensor2d([[10, 20], [30, 40]]); + const x2 = tensor2d([[-1, -2], [-3, -4]]); + const addLayer = tfl.layers.dot({axes: -1}); + const y = addLayer.apply([x1, x2]) as Tensor; + expectTensorsClose(y, tensor2d([[-50], [-250]])); + }); +}); \ No newline at end of file From 46c30e684d192b2dced7be8f199a5362746cdee4 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Sat, 29 Sep 2018 00:23:03 -0400 Subject: [PATCH 03/12] WIP2 --- src/layers/merge_test.ts | 19 +++++++++++++++++++ 1 file changed, 19 insertions(+) diff --git a/src/layers/merge_test.ts b/src/layers/merge_test.ts index 7ddc021c5..bec09fc55 100644 --- a/src/layers/merge_test.ts +++ b/src/layers/merge_test.ts @@ -666,6 +666,25 @@ describeMathCPU('Dot-Layer: Symbolic', () => { }); describeMathCPUAndGPU('Dot-Layer: Tensor', () => { + // Example reference Python Keras code: + // + // ```py + // import keras + // import numpy as np + // + // x1 = keras.Input(shape=[2]) + // x2 = keras.Input(shape=[2]) + // dot_layer = keras.layers.Dot(-11) + // y = dot_layer([x1, x2]) + // + // model = keras.Model([x1, x2], y) + // model.summary() + // + // xs1 = np.array([[10, 20], [30, 40]], dtype=np.float32) + // xs2 = np.array([[-1, -2], [-3, -4]], dtype=np.float32) + // print(model.predict([xs1, xs2])) + // ``` + it('2D x 2D, axis = -1', () => { const x1 = tensor2d([[10, 20], [30, 40]]); const x2 = tensor2d([[-1, -2], [-3, -4]]); From 29508467886f07e3180ec401f92a0abe4e52b1e2 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Sun, 30 Sep 2018 00:06:04 -0400 Subject: [PATCH 04/12] Add tensor-based tests --- src/layers/merge.ts | 43 +++++++++++++------------ src/layers/merge_test.ts | 68 ++++++++++++++++++++++++++++++++++++++-- 2 files changed, 88 insertions(+), 23 deletions(-) diff --git a/src/layers/merge.ts b/src/layers/merge.ts index fc3c6be7e..e1e7cdd16 100644 --- a/src/layers/merge.ts +++ b/src/layers/merge.ts @@ -972,8 +972,16 @@ function batchDot(x: Tensor, y: Tensor, axes: number|number[]): Tensor { * where each entry at index `[i, 0]` will be the dot product between * `a[i, :]` and `b[i, :]`. * - * TODO(cais): Add code snippet. + * Example: * + * ```js + * const dotLayer = tf.layers.dot({axis: -1}); + * const x1 = tf.tensor2d([[10, 20], [30, 40]]); + * const x2 = tf.tensor2d([[-1, -2], [-3, -4]]); + * + * // Invoke the layer's apply() method in eager (imperative) mode. + * const y = dotLayer.apply([x1, x2]); + * ``` */ export class Dot extends Merge { static className = 'Add'; @@ -996,17 +1004,7 @@ export class Dot extends Merge { 'A `Dot` layer should be called on a list of exactly 2 inputs.'); const shape1 = inputShape[0] as Shape; const shape2 = inputShape[1] as Shape; - let axes: number[]; // TODO(cais): Refactor into a function? - if (!Array.isArray(this.axes)) { - // `this.axes` is a single integer. - axes = [ - normalizeAxis(this.axes, shape1.length), - normalizeAxis(this.axes, shape2.length) - ]; - // `this.axes` is an Array of integers. - } else { - axes = this.axes; - } + const axes = this.normalizeAxes(shape1, shape2); if (shape1[axes[0]] !== shape2[axes[1]]) { throw new ValueError( `Dimension incompatibility: ` + @@ -1042,14 +1040,8 @@ export class Dot extends Merge { return batchDot(x1, x2, axes); } - computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] { - tfc.util.assert( - Array.isArray(inputShape) && inputShape.length === 2 && - Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), - 'A `Dot` layer should be called on a list of exactly 2 inputs.'); - const shape1 = inputShape[0] as Shape; - const shape2 = inputShape[1] as Shape; - let axes: number[]; // TODO(cais): Refactor into a function? + private normalizeAxes(shape1: Shape, shape2: Shape): number[] { + let axes: number[]; if (!Array.isArray(this.axes)) { // `this.axes` is a single integer. axes = [ @@ -1060,6 +1052,17 @@ export class Dot extends Merge { // `this.axes` is an Array of integers. axes = this.axes; } + return axes; + } + + computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] { + tfc.util.assert( + Array.isArray(inputShape) && inputShape.length === 2 && + Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]), + 'A `Dot` layer should be called on a list of exactly 2 inputs.'); + const shape1 = inputShape[0] as Shape; + const shape2 = inputShape[1] as Shape; + const axes = this.normalizeAxes(shape1, shape2); shape1.splice(axes[0], 1); shape2.splice(axes[1], 1); shape2.splice(0, 1); diff --git a/src/layers/merge_test.ts b/src/layers/merge_test.ts index bec09fc55..c4deac906 100644 --- a/src/layers/merge_test.ts +++ b/src/layers/merge_test.ts @@ -688,8 +688,70 @@ describeMathCPUAndGPU('Dot-Layer: Tensor', () => { it('2D x 2D, axis = -1', () => { const x1 = tensor2d([[10, 20], [30, 40]]); const x2 = tensor2d([[-1, -2], [-3, -4]]); - const addLayer = tfl.layers.dot({axes: -1}); - const y = addLayer.apply([x1, x2]) as Tensor; + const dotLayer = tfl.layers.dot({axes: -1}); + const y = dotLayer.apply([x1, x2]) as Tensor; + expectTensorsClose(y, tensor2d([[-50], [-250]])); + }); + + it('2D x 2D, axis = 1', () => { + const x1 = tensor2d([[10, 20], [30, 40]]); + const x2 = tensor2d([[-1, -2], [-3, -4]]); + const dotLayer = tfl.layers.dot({axes: 1}); + const y = dotLayer.apply([x1, x2]) as Tensor; expectTensorsClose(y, tensor2d([[-50], [-250]])); }); -}); \ No newline at end of file + + it('3D x 2D, axis = -1', () => { + const x1 = tensor3d([[[10, 20], [30, 40]], [[4, 3], [2, 1]]]); + const x2 = tensor2d([[-1, -2], [-3, -4]]); + const dotLayer = tfl.layers.dot({axes: -1}); + const y1 = dotLayer.apply([x1, x2]) as Tensor; + expectTensorsClose(y1, tensor2d([[-50, -110], [-24, -10]])); + const x3 = tensor2d([[1, 2], [3, 4]]); + const y2 = dotLayer.apply([x1, x3]) as Tensor; + expectTensorsClose(y2, tensor2d([[50, 110], [24, 10]])); + }); + + it('2D x 3D, axis = -1', () => { + const x1 = tensor2d([[-1, -2], [-3, -4]]); + const x2 = tensor3d([[[10, 20], [30, 40]], [[4, 3], [2, 1]]]); + const dotLayer = tfl.layers.dot({axes: -1}); + const y = dotLayer.apply([x1, x2]) as Tensor; + expectTensorsClose(y, tensor2d([[-50, -110], [-24, -10]])); + }); + + it('2D x 3D, axis = 1', () => { + const x1 = tensor2d([[-1, -2], [-3, -4]]); + const x2 = tensor3d([[[10, 20], [30, 40]], [[4, 3], [2, 1]]]); + const dotLayer = tfl.layers.dot({axes: 1}); + const y = dotLayer.apply([x1, x2]) as Tensor; + expectTensorsClose(y, tensor2d([[-70, -100], [-20, -13]])); + }); + + it('3D x 3D, axis = -1', () => { + const x1 = tensor3d([[[-1, -2], [-3, -4]], [[5, 6], [7, 8]]]); + const x2 = tensor3d([[[10, 20], [30, 40]], [[4, 3], [2, 1]]]); + const dotLayer = tfl.layers.dot({axes: -1}); + const y = dotLayer.apply([x1, x2]) as Tensor; + expectTensorsClose(y, tensor3d( + [[[-50, -110], [-110, -250]], [[38, 16], [52, 22]]])); + }); + + it('3D x 3D, axis = 1', () => { + const x1 = tensor3d([[[-1, -2], [-3, -4]], [[5, 6], [7, 8]]]); + const x2 = tensor3d([[[10, 20], [30, 40]], [[4, 3], [2, 1]]]); + const dotLayer = tfl.layers.dot({axes: 1}); + const y = dotLayer.apply([x1, x2]) as Tensor; + expectTensorsClose(y, tensor3d( + [[[-100, -140], [-140, -200]], [[34, 22], [40, 26]]])); + }); + + it('3D x 3D, axis = [1, 2]', () => { + const x1 = tensor3d([[[-1, -2], [-3, -4]], [[5, 6], [7, 8]]]); + const x2 = tensor3d([[[10, 20], [30, 40]], [[4, 3], [2, 1]]]); + const dotLayer = tfl.layers.dot({axes: [1, 2]}); + const y = dotLayer.apply([x1, x2]) as Tensor; + expectTensorsClose(y, tensor3d( + [[[-70, -150], [-100, -220]], [[41, 17], [48, 20]]])); + }); +}); From a85fba4be0d860af10bd66bd4f4d7f710db8474b Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Sun, 30 Sep 2018 00:13:36 -0400 Subject: [PATCH 05/12] Cleanups --- src/exports_layers.ts | 2 +- src/layers/merge_test.ts | 8 ++++++++ 2 files changed, 9 insertions(+), 1 deletion(-) diff --git a/src/exports_layers.ts b/src/exports_layers.ts index 139bcf7ee..ca2d4d0cd 100644 --- a/src/exports_layers.ts +++ b/src/exports_layers.ts @@ -401,7 +401,7 @@ export function multiply(config?: LayerConfig): Layer { * configParamIndices: [0] * } */ -export function dot(config?: DotLayerConfig): Layer { +export function dot(config: DotLayerConfig): Layer { return new Dot(config); } diff --git a/src/layers/merge_test.ts b/src/layers/merge_test.ts index c4deac906..0e87112ba 100644 --- a/src/layers/merge_test.ts +++ b/src/layers/merge_test.ts @@ -693,6 +693,14 @@ describeMathCPUAndGPU('Dot-Layer: Tensor', () => { expectTensorsClose(y, tensor2d([[-50], [-250]])); }); + it('2D x 2D, axis = -1, normalize = true', () => { + const x1 = tensor2d([[10, 20], [30, 40]]); + const x2 = tensor2d([[-1, -2], [-4, -3]]); + const dotLayer = tfl.layers.dot({axes: -1, normalize: true}); + const y = dotLayer.apply([x1, x2]) as Tensor; + expectTensorsClose(y, tensor2d([[-1], [-0.96]])); + }); + it('2D x 2D, axis = 1', () => { const x1 = tensor2d([[10, 20], [30, 40]]); const x2 = tensor2d([[-1, -2], [-3, -4]]); From 0e56f2a8b2422e80fb62ae7a7877376042d04d69 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Sun, 30 Sep 2018 00:24:26 -0400 Subject: [PATCH 06/12] Throw NotImplementedError from computeMask() --- src/layers/merge.ts | 18 +++++++++++++++--- 1 file changed, 15 insertions(+), 3 deletions(-) diff --git a/src/layers/merge.ts b/src/layers/merge.ts index e1e7cdd16..3110efd3e 100644 --- a/src/layers/merge.ts +++ b/src/layers/merge.ts @@ -228,7 +228,11 @@ export abstract class Merge extends Layer { return outputShape; } - // TODO(cais): Implement computeMask(); + computeMask(inputs: Tensor|Tensor[], mask?: Tensor|Tensor[]): Tensor { + // TODO(cais): Implement computeMask(); + throw new NotImplementedError( + 'computeMask has not been implemented for Merge yet'); + } } /** @@ -800,7 +804,11 @@ export class Concatenate extends Merge { return outputShape; } - // TODO(cais): Implement computeMask(); + computeMask(inputs: Tensor|Tensor[], mask?: Tensor|Tensor[]): Tensor { + // TODO(cais): Implement computeMask(); + throw new NotImplementedError( + 'computeMask has not been implemented for Concatenate yet'); + } getConfig(): serialization.ConfigDict { const config: serialization.ConfigDict = { @@ -1073,7 +1081,11 @@ export class Dot extends Merge { return outputShape; } - // TODO(cais): Implement computeMask(); + computeMask(inputs: Tensor|Tensor[], mask?: Tensor|Tensor[]): Tensor { + // TODO(cais): Implement computeMask(); + throw new NotImplementedError( + 'computeMask has not been implemented for Dot yet'); + } getConfig(): serialization.ConfigDict { const config: serialization.ConfigDict = { From f34445cb3da792d0c2150fb16ce77976a13b84d8 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 3 Oct 2018 16:50:26 -0400 Subject: [PATCH 07/12] Add guard for tensor inputs of 4+D for Dot --- src/layers/merge.ts | 15 +++++++++++++++ src/layers/merge_test.ts | 19 +++++++++++-------- 2 files changed, 26 insertions(+), 8 deletions(-) diff --git a/src/layers/merge.ts b/src/layers/merge.ts index 3110efd3e..4f135d6ad 100644 --- a/src/layers/merge.ts +++ b/src/layers/merge.ts @@ -907,6 +907,11 @@ function normalizeAxis(axis: number, dim: number): number { } function batchDot(x: Tensor, y: Tensor, axes: number|number[]): Tensor { + if (x.shape.length > 3 || y.shape.length > 3) { + throw new NotImplementedError( + 'batchDot is not implemented for tensors of 4D or higher rank yet'); + } + if (typeof axes === 'number') { axes = [axes, axes]; } @@ -1012,6 +1017,11 @@ export class Dot extends Merge { 'A `Dot` layer should be called on a list of exactly 2 inputs.'); const shape1 = inputShape[0] as Shape; const shape2 = inputShape[1] as Shape; + if (shape1.length > 3 || shape2.length > 3) { + throw new NotImplementedError( + 'Dot layer does not support tensors of 4D or higher rank yet.'); + } + const axes = this.normalizeAxes(shape1, shape2); if (shape1[axes[0]] !== shape2[axes[1]]) { throw new ValueError( @@ -1070,6 +1080,11 @@ export class Dot extends Merge { 'A `Dot` layer should be called on a list of exactly 2 inputs.'); const shape1 = inputShape[0] as Shape; const shape2 = inputShape[1] as Shape; + if (shape1.length > 3 || shape2.length > 3) { + throw new NotImplementedError( + 'Dot layer does not support tensors of 4D or higher rank yet.'); + } + const axes = this.normalizeAxes(shape1, shape2); shape1.splice(axes[0], 1); shape2.splice(axes[1], 1); diff --git a/src/layers/merge_test.ts b/src/layers/merge_test.ts index 0e87112ba..4f74227af 100644 --- a/src/layers/merge_test.ts +++ b/src/layers/merge_test.ts @@ -626,14 +626,17 @@ describeMathCPU('Dot-Layer: Symbolic', () => { expect(y2.shape).toEqual([null, 4]); }); - it('4D x 4D, axes = -1', () => { - const x1 = new tfl.SymbolicTensor( - 'float32', [null, 2, 3, 4], null, [], null); - const x2 = new tfl.SymbolicTensor( - 'float32', [null, 2, 3, 4], null, [], null); - const y = tfl.layers.dot({axes: -1}).apply([x1, x2]) as tfl.SymbolicTensor; - expect(y.shape).toEqual([null, 2, 3, 2, 3]); - }); + // TODO(cais): Uncomment the follow test case when 4D and higher is supported + // by the Dot layer. + // it('4D x 4D, axes = -1', () => { + // const x1 = new tfl.SymbolicTensor( + // 'float32', [null, 2, 3, 4], null, [], null); + // const x2 = new tfl.SymbolicTensor( + // 'float32', [null, 2, 3, 4], null, [], null); + // const y = tfl.layers.dot({axes: -1}).apply([x1, x2]) as + // tfl.SymbolicTensor; + // expect(y.shape).toEqual([null, 2, 3, 2, 3]); + // }); it('Dimension mismatch leads to error', () => { const x1 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); From 1d3403d553bbebed7c7d8f5bade10f288cfb7a9d Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 3 Oct 2018 23:53:06 -0400 Subject: [PATCH 08/12] Respond to reviewer comments --- src/layers/merge.ts | 135 ++++++++++++++++++++++++++------------------ 1 file changed, 80 insertions(+), 55 deletions(-) diff --git a/src/layers/merge.ts b/src/layers/merge.ts index 4f135d6ad..ce083a8c6 100644 --- a/src/layers/merge.ts +++ b/src/layers/merge.ts @@ -887,7 +887,7 @@ export interface DotLayerConfig extends LayerConfig { * * Integer or an Array of integers. */ - axes: number|number[]; + axes: number|[number, number]; /** * Whether to L2-normalize samples along the dot product axis @@ -899,22 +899,43 @@ export interface DotLayerConfig extends LayerConfig { normalize?: boolean; } -function normalizeAxis(axis: number, dim: number): number { +/** + * Interpretable potentially negative axis index. + * + * For example, given axis = -1, and dim = 3, this function will return 2. + * + * @param axis The axis index, may be a positive, zero or negative integer. + * @param dim Total number of dimensions, a positive integer. + * @returns A non-negative axis index equivalent to the input `axis`. + */ +function interpretAxis(axis: number, dim: number): number { while (axis < 0) { axis += dim; } return axis; } -function batchDot(x: Tensor, y: Tensor, axes: number|number[]): Tensor { +function batchDot( + x: Tensor, y: Tensor, axes: number|[number]|[number, number]): Tensor { if (x.shape.length > 3 || y.shape.length > 3) { throw new NotImplementedError( 'batchDot is not implemented for tensors of 4D or higher rank yet'); } + tfc.util.assert( + x.shape.length >= 2, + `batchDot requires the rank of x to be >= 2, ` + + `but got ${x.shape.length}`); + tfc.util.assert( + x.shape.length >= 2, + `batchDot requires the rank of y to be >= 2, ` + + `but got ${y.shape.length}`); if (typeof axes === 'number') { axes = [axes, axes]; } + if (Array.isArray(axes) && axes.length === 1) { + axes = [axes[0], axes[0]]; + } if (x.dtype === 'complex64' || y.dtype === 'complex64') { throw new NotImplementedError( @@ -924,57 +945,62 @@ function batchDot(x: Tensor, y: Tensor, axes: number|number[]): Tensor { const xNDim = x.shape.length; const yNDim = y.shape.length; if (axes == null) { + // Behave like batchMatmul by default. axes = [xNDim - 1, yNDim - 2]; } - let diff: number; - if (xNDim > yNDim) { - diff = xNDim - yNDim; - const diffShape: Shape = []; - for (let i = 0; i < diff; ++i) { - diffShape.push(1); - } - y = y.reshape(y.shape.concat(diffShape)); - } else if (yNDim > xNDim) { - diff = yNDim - xNDim; - const diffShape: Shape = []; - for (let i = 0; i < diff; ++i) { - diffShape.push(1); - } - x = x.reshape(x.shape.concat(diffShape)); - } else { - diff = 0; - } + const axesArray = axes as [number, number]; - let out: Tensor; - if (x.shape.length === 2 && y.shape.length === 2) { - if (axes[0] === axes[1]) { - out = x.mulStrict(y).sum(axes[0]); + return tfc.tidy(() => { + let diff: number; + if (xNDim > yNDim) { + diff = xNDim - yNDim; + const diffShape: Shape = []; + for (let i = 0; i < diff; ++i) { + diffShape.push(1); + } + y = y.reshape(y.shape.concat(diffShape)); + } else if (yNDim > xNDim) { + diff = yNDim - xNDim; + const diffShape: Shape = []; + for (let i = 0; i < diff; ++i) { + diffShape.push(1); + } + x = x.reshape(x.shape.concat(diffShape)); } else { - out = x.transpose([1, 0]).mulStrict(y).sum(axes[1]); + diff = 0; } - } else { - const adjX = axes[0] === x.shape.length - 1 ? null : true; - const adjY = axes[1] === y.shape.length - 1 ? true : null; - out = x.matMul(y, adjX, adjY); - } - if (diff > 0) { - let idx: number; - if (xNDim > yNDim) { - idx = xNDim + yNDim - 3; + let out: Tensor; + if (x.shape.length === 2 && y.shape.length === 2) { + if (axesArray[0] === axesArray[1]) { + out = x.mulStrict(y).sum(axesArray[0]); + } else { + out = x.transpose([1, 0]).mulStrict(y).sum(axesArray[1]); + } } else { - idx = xNDim - 1; + const adjX = axesArray[0] === x.shape.length - 1 ? null : true; + const adjY = axesArray[1] === y.shape.length - 1 ? true : null; + out = x.matMul(y, adjX, adjY); } - const squeezeAxes: number[] = []; - for (let i = idx; i < idx + diff; ++i) { - squeezeAxes.push(i); + + if (diff > 0) { + let idx: number; + if (xNDim > yNDim) { + idx = xNDim + yNDim - 3; + } else { + idx = xNDim - 1; + } + const squeezeAxes: number[] = []; + for (let i = idx; i < idx + diff; ++i) { + squeezeAxes.push(i); + } + out = out.squeeze(squeezeAxes); } - out = out.squeeze(squeezeAxes); - } - if (out.shape.length === 1) { - out = out.expandDims(1); - } - return out; + if (out.shape.length === 1) { + out = out.expandDims(1); + } + return out; + }); } /** @@ -999,7 +1025,7 @@ function batchDot(x: Tensor, y: Tensor, axes: number|number[]): Tensor { export class Dot extends Merge { static className = 'Add'; - private axes: number|number[]; + private axes: number|[number, number]; private normalize: boolean; constructor(config: DotLayerConfig) { @@ -1039,17 +1065,16 @@ export class Dot extends Merge { let x1 = inputs[0]; let x2 = inputs[1]; - let axes: number[]; + let axes: [number, number]; if (!Array.isArray(this.axes)) { axes = [ - normalizeAxis(this.axes, x1.shape.length), - normalizeAxis(this.axes, x2.shape.length) + interpretAxis(this.axes, x1.shape.length), + interpretAxis(this.axes, x2.shape.length) ]; } else { - axes = []; - for (let i = 0; i < this.axes.length; ++i) { - axes.push(normalizeAxis(this.axes[i], inputs[i].shape.length)); - } + axes = this.axes.map( + (axis, i) => interpretAxis( + axis, inputs[i].shape.length)) as [number, number]; } if (this.normalize) { x1 = l2Normalize(x1, axes[0]); @@ -1063,8 +1088,8 @@ export class Dot extends Merge { if (!Array.isArray(this.axes)) { // `this.axes` is a single integer. axes = [ - normalizeAxis(this.axes, shape1.length), - normalizeAxis(this.axes, shape2.length) + interpretAxis(this.axes, shape1.length), + interpretAxis(this.axes, shape2.length) ]; } else { // `this.axes` is an Array of integers. From 5fcf292517cdd8aac509697ffec51a1bc69bc4d5 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 3 Oct 2018 23:54:39 -0400 Subject: [PATCH 09/12] Call registerClass() --- src/layers/merge.ts | 1 + 1 file changed, 1 insertion(+) diff --git a/src/layers/merge.ts b/src/layers/merge.ts index ce083a8c6..786a10618 100644 --- a/src/layers/merge.ts +++ b/src/layers/merge.ts @@ -1137,5 +1137,6 @@ export class Dot extends Merge { return config; } } +serialization.registerClass(Dot); // TODO(cais): Add functional interfaces for the merge layers. From 4aad01fee41d52b87f58401b5466a882c0bf164f Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 3 Oct 2018 23:55:00 -0400 Subject: [PATCH 10/12] Fix incorrect className --- src/layers/merge.ts | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/layers/merge.ts b/src/layers/merge.ts index 786a10618..aee570ddf 100644 --- a/src/layers/merge.ts +++ b/src/layers/merge.ts @@ -1023,7 +1023,7 @@ function batchDot( * ``` */ export class Dot extends Merge { - static className = 'Add'; + static className = 'Dot'; private axes: number|[number, number]; private normalize: boolean; From 81a929067cc1def4e68d1f37bf14d8a1cb585603 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Wed, 3 Oct 2018 23:56:24 -0400 Subject: [PATCH 11/12] Rename private helper method --- src/layers/merge.ts | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/layers/merge.ts b/src/layers/merge.ts index aee570ddf..74a4e9bb1 100644 --- a/src/layers/merge.ts +++ b/src/layers/merge.ts @@ -1048,7 +1048,7 @@ export class Dot extends Merge { 'Dot layer does not support tensors of 4D or higher rank yet.'); } - const axes = this.normalizeAxes(shape1, shape2); + const axes = this.interpretAxes(shape1, shape2); if (shape1[axes[0]] !== shape2[axes[1]]) { throw new ValueError( `Dimension incompatibility: ` + @@ -1083,7 +1083,7 @@ export class Dot extends Merge { return batchDot(x1, x2, axes); } - private normalizeAxes(shape1: Shape, shape2: Shape): number[] { + private interpretAxes(shape1: Shape, shape2: Shape): number[] { let axes: number[]; if (!Array.isArray(this.axes)) { // `this.axes` is a single integer. @@ -1110,7 +1110,7 @@ export class Dot extends Merge { 'Dot layer does not support tensors of 4D or higher rank yet.'); } - const axes = this.normalizeAxes(shape1, shape2); + const axes = this.interpretAxes(shape1, shape2); shape1.splice(axes[0], 1); shape2.splice(axes[1], 1); shape2.splice(0, 1); From d79104f5211f19091854b043454e390f9957d037 Mon Sep 17 00:00:00 2001 From: Shanqing Cai Date: Thu, 4 Oct 2018 10:11:20 -0400 Subject: [PATCH 12/12] Simplify type signature --- src/layers/merge.ts | 6 +----- src/layers/merge_test.ts | 22 ++++++++++++---------- 2 files changed, 13 insertions(+), 15 deletions(-) diff --git a/src/layers/merge.ts b/src/layers/merge.ts index 74a4e9bb1..f718de671 100644 --- a/src/layers/merge.ts +++ b/src/layers/merge.ts @@ -915,8 +915,7 @@ function interpretAxis(axis: number, dim: number): number { return axis; } -function batchDot( - x: Tensor, y: Tensor, axes: number|[number]|[number, number]): Tensor { +function batchDot(x: Tensor, y: Tensor, axes: number|[number, number]): Tensor { if (x.shape.length > 3 || y.shape.length > 3) { throw new NotImplementedError( 'batchDot is not implemented for tensors of 4D or higher rank yet'); @@ -933,9 +932,6 @@ function batchDot( if (typeof axes === 'number') { axes = [axes, axes]; } - if (Array.isArray(axes) && axes.length === 1) { - axes = [axes[0], axes[0]]; - } if (x.dtype === 'complex64' || y.dtype === 'complex64') { throw new NotImplementedError( diff --git a/src/layers/merge_test.ts b/src/layers/merge_test.ts index 4f74227af..4c5fda9db 100644 --- a/src/layers/merge_test.ts +++ b/src/layers/merge_test.ts @@ -641,8 +641,9 @@ describeMathCPU('Dot-Layer: Symbolic', () => { it('Dimension mismatch leads to error', () => { const x1 = new tfl.SymbolicTensor('float32', [null, 2, 3], null, [], null); const x2 = new tfl.SymbolicTensor('float32', [null, 4], null, [], null); - expect(() => tfl.layers.dot({axes: -1}).apply([x1, x2])) - .toThrowError('Dimension incompatibility: 3 !== 4'); + expect(() => tfl.layers.dot({axes: -1}).apply([ + x1, x2 + ])).toThrowError('Dimension incompatibility: 3 !== 4'); }); it('Incorrect number of inputs leads to error', () => { @@ -653,8 +654,9 @@ describeMathCPU('Dot-Layer: Symbolic', () => { .toThrowError(/should be called on a list of exactly 2 inputs/); expect(() => tfl.layers.dot({axes: -1}).apply(x1)) .toThrowError(/should be called on a list of exactly 2 inputs/); - expect(() => tfl.layers.dot({axes: -1}).apply([x1, x2, x3])) - .toThrowError(/should be called on a list of exactly 2 inputs/); + expect(() => tfl.layers.dot({axes: -1}).apply([ + x1, x2, x3 + ])).toThrowError(/should be called on a list of exactly 2 inputs/); }); it('Serialization round trip', () => { @@ -744,8 +746,8 @@ describeMathCPUAndGPU('Dot-Layer: Tensor', () => { const x2 = tensor3d([[[10, 20], [30, 40]], [[4, 3], [2, 1]]]); const dotLayer = tfl.layers.dot({axes: -1}); const y = dotLayer.apply([x1, x2]) as Tensor; - expectTensorsClose(y, tensor3d( - [[[-50, -110], [-110, -250]], [[38, 16], [52, 22]]])); + expectTensorsClose( + y, tensor3d([[[-50, -110], [-110, -250]], [[38, 16], [52, 22]]])); }); it('3D x 3D, axis = 1', () => { @@ -753,8 +755,8 @@ describeMathCPUAndGPU('Dot-Layer: Tensor', () => { const x2 = tensor3d([[[10, 20], [30, 40]], [[4, 3], [2, 1]]]); const dotLayer = tfl.layers.dot({axes: 1}); const y = dotLayer.apply([x1, x2]) as Tensor; - expectTensorsClose(y, tensor3d( - [[[-100, -140], [-140, -200]], [[34, 22], [40, 26]]])); + expectTensorsClose( + y, tensor3d([[[-100, -140], [-140, -200]], [[34, 22], [40, 26]]])); }); it('3D x 3D, axis = [1, 2]', () => { @@ -762,7 +764,7 @@ describeMathCPUAndGPU('Dot-Layer: Tensor', () => { const x2 = tensor3d([[[10, 20], [30, 40]], [[4, 3], [2, 1]]]); const dotLayer = tfl.layers.dot({axes: [1, 2]}); const y = dotLayer.apply([x1, x2]) as Tensor; - expectTensorsClose(y, tensor3d( - [[[-70, -150], [-100, -220]], [[41, 17], [48, 20]]])); + expectTensorsClose( + y, tensor3d([[[-70, -150], [-100, -220]], [[41, 17], [48, 20]]])); }); });