
Add layer: tf.layers.dot #330

Merged · 13 commits · Oct 4, 2018
15 changes: 14 additions & 1 deletion src/exports_layers.ts
@@ -16,7 +16,7 @@ import {Conv1D, Conv2D, Conv2DTranspose, ConvLayerConfig, Cropping2D, Cropping2D
import {DepthwiseConv2D, DepthwiseConv2DLayerConfig} from './layers/convolutional_depthwise';
import {Activation, ActivationLayerConfig, Dense, DenseLayerConfig, Dropout, DropoutLayerConfig, Flatten, Permute, PermuteLayerConfig, RepeatVector, RepeatVectorLayerConfig, Reshape, ReshapeLayerConfig} from './layers/core';
import {Embedding, EmbeddingLayerConfig} from './layers/embeddings';
import {Add, Average, Concatenate, ConcatenateLayerConfig, Maximum, Minimum, Multiply} from './layers/merge';
import {Add, Average, Concatenate, ConcatenateLayerConfig, Dot, DotLayerConfig, Maximum, Minimum, Multiply} from './layers/merge';
import {BatchNormalization, BatchNormalizationLayerConfig} from './layers/normalization';
import {ZeroPadding2D, ZeroPadding2DLayerConfig} from './layers/padding';
import {AveragePooling1D, AveragePooling2D, GlobalAveragePooling1D, GlobalAveragePooling2D, GlobalMaxPooling1D, GlobalMaxPooling2D, GlobalPooling2DLayerConfig, MaxPooling1D, MaxPooling2D, Pooling1DLayerConfig, Pooling2DLayerConfig} from './layers/pooling';
@@ -392,6 +392,19 @@ export function multiply(config?: LayerConfig): Layer {
return new Multiply(config);
}

/**
* @doc {
* heading: 'Layers',
* subheading: 'Merge',
* namespace: 'layers',
* useDocsFrom: 'Dot',
* configParamIndices: [0]
* }
*/
export function dot(config: DotLayerConfig): Layer {
return new Dot(config);
}
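
For context, a minimal usage sketch of the factory exported above (the input shapes and values here are illustrative, not part of this diff):

```ts
import * as tf from '@tensorflow/tfjs';

// Two symbolic inputs, each a batch of 4-dimensional vectors.
const a = tf.input({shape: [4]});
const b = tf.input({shape: [4]});

// Per-sample dot product along the last axis; output shape is [null, 1].
const out = tf.layers.dot({axes: -1}).apply([a, b]) as tf.SymbolicTensor;
const model = tf.model({inputs: [a, b], outputs: out});

// Each row's dot product of ones is 1*1 + 1*1 + 1*1 + 1*1 = 4.
(model.predict([tf.ones([2, 4]), tf.ones([2, 4])]) as tf.Tensor).print();
```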

// Normalization Layers.

267 changes: 264 additions & 3 deletions src/layers/merge.ts
@@ -19,6 +19,7 @@ import {getScalar} from '../backend/state';
import * as K from '../backend/tfjs_backend';
import {Layer, LayerConfig, SymbolicTensor} from '../engine/topology';
import {NotImplementedError, ValueError} from '../errors';
import {l2Normalize} from '../losses';
import {Kwargs, Shape} from '../types';
import * as generic_utils from '../utils/generic_utils';
import * as mathUtils from '../utils/math_utils';
@@ -227,7 +228,11 @@ export abstract class Merge extends Layer {
return outputShape;
}

// TODO(cais): Implement computeMask();
computeMask(inputs: Tensor|Tensor[], mask?: Tensor|Tensor[]): Tensor {
// TODO(cais): Implement computeMask();
throw new NotImplementedError(
'computeMask has not been implemented for Merge yet');
}
}

/**
@@ -799,7 +804,11 @@ export class Concatenate extends Merge {
return outputShape;
}

// TODO(cais): Implement computeMask();
computeMask(inputs: Tensor|Tensor[], mask?: Tensor|Tensor[]): Tensor {
// TODO(cais): Implement computeMask();
throw new NotImplementedError(
'computeMask has not been implemented for Concatenate yet');
}

getConfig(): serialization.ConfigDict {
const config: serialization.ConfigDict = {
@@ -872,6 +881,258 @@ export function concatenate(config?: SymbolicTensor[]|Tensor[]|
}
}

// TODO(cais): Add class Dot.
export interface DotLayerConfig extends LayerConfig {
/**
* Axis or axes along which the dot product will be taken.
*
* A single integer or an Array of two integers.
*/
axes: number|[number, number];

/**
* Whether to L2-normalize samples along the dot product axis
* before taking the dot product.
*
* If set to `true`, the output of the dot product is the cosine
* proximity between the two samples.
*/
normalize?: boolean;
}
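
A sketch of the `normalize` option described above (values chosen for illustration): with `normalize: true` the layer returns cosine proximity rather than the raw dot product.

```ts
import * as tf from '@tensorflow/tfjs';

const cosineLayer = tf.layers.dot({axes: -1, normalize: true});
const u = tf.tensor2d([[1, 0], [0, 2]]);
const v = tf.tensor2d([[1, 0], [0, -3]]);

// Rows are L2-normalized before the dot product, so each output entry
// is the cosine of the angle between the paired rows: [[1], [-1]].
(cosineLayer.apply([u, v]) as tf.Tensor).print();
```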

/**
* Interpret a potentially negative axis index as a non-negative one.
*
* For example, given axis = -1, and dim = 3, this function will return 2.
*
* @param axis The axis index, may be a positive, zero or negative integer.
* @param dim Total number of dimensions, a positive integer.
* @returns A non-negative axis index equivalent to the input `axis`.
*/
function interpretAxis(axis: number, dim: number): number {
while (axis < 0) {
// Review comment (Contributor): Also, it seems possible that this results in
// an infinite loop if the user tries to dot product a scalar. Can we throw an
// error message instead?

axis += dim;
}
return axis;
}
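
A sketch of the guard the reviewer requests above; the name `interpretAxisGuarded` and the choice of `ValueError` are assumptions for illustration, not part of this diff:

```ts
function interpretAxisGuarded(axis: number, dim: number): number {
  if (dim <= 0) {
    // With dim === 0 (a scalar input), the while-loop above would never
    // terminate for a negative axis; fail fast instead.
    throw new ValueError(
        `Cannot interpret axis ${axis} for a tensor of ${dim} dimensions.`);
  }
  while (axis < 0) {
    axis += dim;
  }
  return axis;
}
```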

function batchDot(x: Tensor, y: Tensor, axes: number|[number, number]): Tensor {
if (x.shape.length > 3 || y.shape.length > 3) {
throw new NotImplementedError(
'batchDot is not implemented for tensors of 4D or higher rank yet');
}
tfc.util.assert(
x.shape.length >= 2,
`batchDot requires the rank of x to be >= 2, ` +
`but got ${x.shape.length}`);
tfc.util.assert(
y.shape.length >= 2,
`batchDot requires the rank of y to be >= 2, ` +
`but got ${y.shape.length}`);

if (typeof axes === 'number') {
axes = [axes, axes];
}

if (x.dtype === 'complex64' || y.dtype === 'complex64') {
throw new NotImplementedError(
'batchDot is not implemented for complex64-type Tensors yet.');
}

const xNDim = x.shape.length;
const yNDim = y.shape.length;
if (axes == null) {
// Behave like batchMatmul by default.
axes = [xNDim - 1, yNDim - 2];
}
const axesArray = axes as [number, number];

return tfc.tidy(() => {
let diff: number;
if (xNDim > yNDim) {
diff = xNDim - yNDim;
const diffShape: Shape = [];
for (let i = 0; i < diff; ++i) {
diffShape.push(1);
}
y = y.reshape(y.shape.concat(diffShape));
} else if (yNDim > xNDim) {
diff = yNDim - xNDim;
const diffShape: Shape = [];
for (let i = 0; i < diff; ++i) {
diffShape.push(1);
}
x = x.reshape(x.shape.concat(diffShape));
} else {
diff = 0;
}

let out: Tensor;
if (x.shape.length === 2 && y.shape.length === 2) {
if (axesArray[0] === axesArray[1]) {
out = x.mulStrict(y).sum(axesArray[0]);
} else {
out = x.transpose([1, 0]).mulStrict(y).sum(axesArray[1]);
}
} else {
const adjX = axesArray[0] === x.shape.length - 1 ? null : true;
const adjY = axesArray[1] === y.shape.length - 1 ? true : null;
out = x.matMul(y, adjX, adjY);
}

if (diff > 0) {
let idx: number;
if (xNDim > yNDim) {
idx = xNDim + yNDim - 3;
} else {
idx = xNDim - 1;
}
const squeezeAxes: number[] = [];
for (let i = idx; i < idx + diff; ++i) {
squeezeAxes.push(i);
}
out = out.squeeze(squeezeAxes);
}
if (out.shape.length === 1) {
out = out.expandDims(1);
}
return out;
});
}
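
To make the shape handling concrete, a usage sketch of this helper (it is file-private, so this would only run inside merge.ts or its tests; values are illustrative):

```ts
// 2D case: per-sample dot product along axis 1.
const x = tfc.tensor2d([[1, 2], [3, 4]]);   // shape [2, 2]
const y = tfc.tensor2d([[5, 6], [7, 8]]);   // shape [2, 2]
batchDot(x, y, 1).print();                  // shape [2, 1]: [[17], [53]]

// 3D case: falls through to matMul; [2, 1, 3] x [2, 3, 1] -> [2, 1, 1].
const a = tfc.ones([2, 1, 3]);
const b = tfc.ones([2, 3, 1]);
console.log(batchDot(a, b, [2, 1]).shape);  // [2, 1, 1]
```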

/**
* Layer that computes a dot product between samples in two tensors.
*
* E.g., if applied to a list of two tensors `a` and `b` both of shape
* `[batchSize, n]`, the output will be a tensor of shape `[batchSize, 1]`,
* where each entry at index `[i, 0]` will be the dot product between
* `a[i, :]` and `b[i, :]`.
*
* Example:
*
* ```js
* const dotLayer = tf.layers.dot({axes: -1});
* const x1 = tf.tensor2d([[10, 20], [30, 40]]);
* const x2 = tf.tensor2d([[-1, -2], [-3, -4]]);
*
* // Invoke the layer's apply() method in eager (imperative) mode.
* const y = dotLayer.apply([x1, x2]);
* ```
*/
export class Dot extends Merge {
static className = 'Dot';

private axes: number|[number, number];
private normalize: boolean;

constructor(config: DotLayerConfig) {
super(config);
this.axes = config.axes;
this.normalize = config.normalize == null ? false : config.normalize;
this.supportsMasking = true;
this.reshapeRequired = false;
}

build(inputShape: Shape|Shape[]): void {
tfc.util.assert(
Array.isArray(inputShape) && inputShape.length === 2 &&
Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]),
'A `Dot` layer should be called on a list of exactly 2 inputs.');
const shape1 = inputShape[0] as Shape;
const shape2 = inputShape[1] as Shape;
if (shape1.length > 3 || shape2.length > 3) {
throw new NotImplementedError(
'Dot layer does not support tensors of 4D or higher rank yet.');
}

const axes = this.interpretAxes(shape1, shape2);
if (shape1[axes[0]] !== shape2[axes[1]]) {
throw new ValueError(
`Dimension incompatibility: ` +
`${shape1[axes[0]]} !== ${shape2[axes[1]]}`);
}
}
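
For illustration, the dimension check above in action (a hedged sketch; the shapes are arbitrary):

```ts
import * as tf from '@tensorflow/tfjs';

const layer = tf.layers.dot({axes: -1});
// The sizes along the dot axes (3 vs. 4) do not match, so this throws:
// ValueError: Dimension incompatibility: 3 !== 4
layer.build([[null, 3], [null, 4]]);
```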

protected mergeFunction(inputs: Tensor[]): Tensor {
if (inputs.length !== 2) {
throw new ValueError(
'A `Dot` layer must be called on exactly 2 inputs, ' +
`but received ${inputs.length} input(s).`);
}

let x1 = inputs[0];
let x2 = inputs[1];
let axes: [number, number];
if (!Array.isArray(this.axes)) {
axes = [
interpretAxis(this.axes, x1.shape.length),
interpretAxis(this.axes, x2.shape.length)
];
} else {
axes = this.axes.map(
(axis, i) => interpretAxis(
axis, inputs[i].shape.length)) as [number, number];
}
if (this.normalize) {
x1 = l2Normalize(x1, axes[0]);
x2 = l2Normalize(x2, axes[1]);
}
return batchDot(x1, x2, axes);
}

private interpretAxes(shape1: Shape, shape2: Shape): number[] {
let axes: number[];
if (!Array.isArray(this.axes)) {
// `this.axes` is a single integer.
axes = [
interpretAxis(this.axes, shape1.length),
interpretAxis(this.axes, shape2.length)
];
} else {
// `this.axes` is an Array of integers.
axes = this.axes;
}
return axes;
}

computeOutputShape(inputShape: Shape|Shape[]): Shape|Shape[] {
tfc.util.assert(
Array.isArray(inputShape) && inputShape.length === 2 &&
Array.isArray(inputShape[0]) && Array.isArray(inputShape[1]),
'A `Dot` layer should be called on a list of exactly 2 inputs.');
// Copy the shapes so the splices below do not mutate the caller's arrays.
const shape1 = (inputShape[0] as Shape).slice();
const shape2 = (inputShape[1] as Shape).slice();
if (shape1.length > 3 || shape2.length > 3) {
throw new NotImplementedError(
'Dot layer does not support tensors of 4D or higher rank yet.');
}

const axes = this.interpretAxes(shape1, shape2);
shape1.splice(axes[0], 1);
shape2.splice(axes[1], 1);
shape2.splice(0, 1);
const outputShape = shape1.concat(shape2);
if (outputShape.length === 1) {
outputShape.push(1);
}
return outputShape;
}
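
A worked example of the shape arithmetic above (shapes illustrative): the dot axes are dropped from both inputs, and the second input additionally loses its batch dimension.

```ts
import * as tf from '@tensorflow/tfjs';

const layer = tf.layers.dot({axes: [2, 1]});
// shape1 = [null, 3, 4]: drop axis 2 -> [null, 3]
// shape2 = [null, 4, 5]: drop axis 1 -> [null, 5], drop batch -> [5]
// concatenated output shape          -> [null, 3, 5]
console.log(layer.computeOutputShape([[null, 3, 4], [null, 4, 5]]));
```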

computeMask(inputs: Tensor|Tensor[], mask?: Tensor|Tensor[]): Tensor {
// TODO(cais): Implement computeMask();
throw new NotImplementedError(
'computeMask has not been implemented for Dot yet');
}

getConfig(): serialization.ConfigDict {
const config: serialization.ConfigDict = {
'axes': this.axes,
'normalize': this.normalize
};
const baseConfig = super.getConfig();
Object.assign(config, baseConfig);
return config;
}
}
serialization.registerClass(Dot);
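
Registering the class makes `Dot` deserializable by its class name; a quick sketch of the config round-trip (illustrative):

```ts
import * as tf from '@tensorflow/tfjs';

const layer = tf.layers.dot({axes: -1, normalize: true});
console.log(layer.getClassName());  // 'Dot'
console.log(layer.getConfig());     // includes {axes: -1, normalize: true}
```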

// TODO(cais): Add functional interfaces for the merge layers.