Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Serialize string tensors as encoded (raw bytes) #1816

Merged
merged 16 commits into from
Jun 27, 2019
14 changes: 7 additions & 7 deletions src/backends/backend.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
import {Activation} from '../ops/fused_util';
import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
import {DataType, DataValues, PixelData, Rank, ShapeMap} from '../types';
import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types';

export const EPSILON_FLOAT32 = 1e-7;
export const EPSILON_FLOAT16 = 1e-4;
Expand All @@ -31,10 +31,10 @@ export interface BackendTimingInfo {
}

export interface TensorStorage {
read(dataId: DataId): Promise<DataValues>;
readSync(dataId: DataId): DataValues;
read(dataId: DataId): Promise<BackendValues>;
readSync(dataId: DataId): BackendValues;
disposeData(dataId: DataId): void;
write(dataId: DataId, values: DataValues): void;
write(dataId: DataId, values: BackendValues): void;
fromPixels(
pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|
HTMLVideoElement,
Expand Down Expand Up @@ -92,16 +92,16 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
time(f: () => void): Promise<BackendTimingInfo> {
throw new Error('Not yet implemented.');
}
read(dataId: object): Promise<DataValues> {
read(dataId: object): Promise<BackendValues> {
throw new Error('Not yet implemented.');
}
readSync(dataId: object): DataValues {
readSync(dataId: object): BackendValues {
throw new Error('Not yet implemented.');
}
disposeData(dataId: object): void {
throw new Error('Not yet implemented.');
}
write(dataId: object, values: DataValues): void {
write(dataId: object, values: BackendValues): void {
throw new Error('Not yet implemented.');
}
fromPixels(
Expand Down
22 changes: 16 additions & 6 deletions src/backends/cpu/backend_cpu.ts
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,7 @@ import * as scatter_nd_util from '../../ops/scatter_nd_util';
import * as selu_util from '../../ops/selu_util';
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
import {DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
import {BackendValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util';
import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend';
Expand All @@ -58,7 +58,7 @@ function mapActivation(
}

interface TensorData<D extends DataType> {
values?: DataTypeMap[D];
values?: BackendValues;
dtype: D;
// For complex numbers, the real and imaginary parts are stored as their own
// individual tensors, with a parent joining the two with the
Expand Down Expand Up @@ -116,7 +116,7 @@ export class MathBackendCPU implements KernelBackend {
}
this.data.set(dataId, {dtype});
}
write(dataId: DataId, values: DataValues): void {
write(dataId: DataId, values: BackendValues): void {
if (values == null) {
throw new Error('MathBackendCPU.write(): values can not be null');
}
Expand Down Expand Up @@ -186,10 +186,10 @@ export class MathBackendCPU implements KernelBackend {
[pixels.height, pixels.width, numChannels];
return tensor3d(values, outShape, 'int32');
}
async read(dataId: DataId): Promise<DataValues> {
async read(dataId: DataId): Promise<BackendValues> {
return this.readSync(dataId);
}
readSync(dataId: DataId): DataValues {
readSync(dataId: DataId): BackendValues {
const {dtype, complexTensors} = this.data.get(dataId);
if (dtype === 'complex64') {
const realValues =
Expand All @@ -202,7 +202,17 @@ export class MathBackendCPU implements KernelBackend {
}

private bufferSync<R extends Rank>(t: Tensor<R>): TensorBuffer<R> {
return buffer(t.shape, t.dtype, this.readSync(t.dataId)) as TensorBuffer<R>;
const data = this.readSync(t.dataId);
let decodedData = data as DataValues;
if (t.dtype === 'string') {
try {
// Decode the bytes into string.
decodedData = (data as Uint8Array[]).map(d => util.decodeString(d));
} catch {
throw new Error('Failed to decode encoded string bytes into utf-8');
}
}
return buffer(t.shape, t.dtype, decodedData) as TensorBuffer<R>;
}

disposeData(dataId: DataId): void {
Expand Down
29 changes: 23 additions & 6 deletions src/backends/cpu/backend_cpu_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,21 @@ import * as tf from '../../index';
import {describeWithFlags} from '../../jasmine_util';
import {tensor2d} from '../../ops/ops';
import {expectArraysClose, expectArraysEqual} from '../../test_util';
import {decodeString, encodeString} from '../../util';

import {MathBackendCPU} from './backend_cpu';
import {CPU_ENVS} from './backend_cpu_test_registry';

/** Private test util for encoding array of strings in utf-8. */
function encodeStrings(a: string[]): Uint8Array[] {
return a.map(s => encodeString(s));
}

/** Private test util for decoding array of strings in utf-8. */
function decodeStrings(bytes: Uint8Array[]): string[] {
return bytes.map(b => decodeString(b));
}

describeWithFlags('backendCPU', CPU_ENVS, () => {
let backend: MathBackendCPU;
beforeEach(() => {
Expand All @@ -36,19 +47,25 @@ describeWithFlags('backendCPU', CPU_ENVS, () => {

it('register empty string tensor and write', () => {
const t = tf.Tensor.make([3], {}, 'string');
backend.write(t.dataId, ['c', 'a', 'b']);
expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']);
backend.write(t.dataId, encodeStrings(['c', 'a', 'b']));
expectArraysEqual(
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
['c', 'a', 'b']);
});

it('register string tensor with values', () => {
const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string');
expectArraysEqual(backend.readSync(t.dataId), ['a', 'b', 'c']);
expectArraysEqual(
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
['a', 'b', 'c']);
});

it('register string tensor with values and overwrite', () => {
const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string');
backend.write(t.dataId, ['c', 'a', 'b']);
expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']);
backend.write(t.dataId, encodeStrings(['c', 'a', 'b']));
expectArraysEqual(
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
['c', 'a', 'b']);
});

it('register string tensor with values and mismatched shape', () => {
Expand Down Expand Up @@ -129,7 +146,7 @@ describeWithFlags('memory cpu', CPU_ENVS, () => {
const mem = tf.memory();
expect(mem.numTensors).toBe(2);
expect(mem.numDataBuffers).toBe(2);
expect(mem.numBytes).toBe(6);
expect(mem.numBytes).toBe(5);
expect(mem.unreliable).toBe(true);

const expectedReasonGC =
Expand Down
12 changes: 7 additions & 5 deletions src/backends/webgl/backend_webgl.ts
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../o
import {softmax} from '../../ops/softmax';
import {range, scalar, tensor} from '../../ops/tensor_ops';
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
import {DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';
import {BackendValues, DataType, DataTypeMap, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';
import * as util from '../../util';
import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util';
import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend';
Expand Down Expand Up @@ -339,7 +339,7 @@ export class MathBackendWebGL implements KernelBackend {
return {dataId, shape, dtype};
}

write(dataId: DataId, values: DataValues): void {
write(dataId: DataId, values: BackendValues): void {
if (values == null) {
throw new Error('MathBackendWebGL.write(): values can not be null');
}
Expand All @@ -366,7 +366,7 @@ export class MathBackendWebGL implements KernelBackend {
texData.values = values;
}

readSync(dataId: DataId): DataValues {
readSync(dataId: DataId): BackendValues {
const texData = this.texData.get(dataId);
const {values, dtype, complexTensors, slice, shape} = texData;
if (slice != null) {
Expand Down Expand Up @@ -403,7 +403,7 @@ export class MathBackendWebGL implements KernelBackend {
return this.convertAndCacheOnCPU(dataId, result);
}

async read(dataId: DataId): Promise<DataValues> {
async read(dataId: DataId): Promise<BackendValues> {
if (this.pendingRead.has(dataId)) {
const subscribers = this.pendingRead.get(dataId);
return new Promise<TypedArray>(resolve => subscribers.push(resolve));
Expand Down Expand Up @@ -961,7 +961,9 @@ export class MathBackendWebGL implements KernelBackend {

tile<T extends Tensor>(x: T, reps: number[]): T {
if (x.dtype === 'string') {
const buf = buffer(x.shape, x.dtype, this.readSync(x.dataId) as string[]);
const data = this.readSync(x.dataId) as Uint8Array[];
const decodedData = data.map(d => util.decodeString(d));
const buf = buffer(x.shape, x.dtype, decodedData);
return tile(buf, reps) as T;
}
const program = new TileProgram(x.shape, reps);
Expand Down
26 changes: 21 additions & 5 deletions src/backends/webgl/backend_webgl_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,19 @@
import * as tf from '../../index';
import {describeWithFlags} from '../../jasmine_util';
import {expectArraysClose, expectArraysEqual} from '../../test_util';
import {decodeString, encodeString} from '../../util';

import {MathBackendWebGL, WebGLMemoryInfo} from './backend_webgl';
import {WEBGL_ENVS} from './backend_webgl_test_registry';

function encodeStrings(a: string[]): Uint8Array[] {
return a.map(s => encodeString(s));
}

function decodeStrings(bytes: Uint8Array[]): string[] {
return bytes.map(b => decodeString(b));
}

describeWithFlags('lazy packing and unpacking', WEBGL_ENVS, () => {
let webglLazilyUnpackFlagSaved: boolean;
let webglCpuForwardFlagSaved: boolean;
Expand Down Expand Up @@ -126,8 +136,10 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => {
tf.setBackend('test-storage');

const t = tf.Tensor.make([3], {}, 'string');
backend.write(t.dataId, ['c', 'a', 'b']);
expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']);
backend.write(t.dataId, encodeStrings(['c', 'a', 'b']));
expectArraysEqual(
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
['c', 'a', 'b']);
});

it('register string tensor with values', () => {
Expand All @@ -136,7 +148,9 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => {
tf.setBackend('test-storage');

const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string');
expectArraysEqual(backend.readSync(t.dataId), ['a', 'b', 'c']);
expectArraysEqual(
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
['a', 'b', 'c']);
});

it('register string tensor with values and overwrite', () => {
Expand All @@ -145,8 +159,10 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => {
tf.setBackend('test-storage');

const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string');
backend.write(t.dataId, ['c', 'a', 'b']);
expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']);
backend.write(t.dataId, encodeStrings(['c', 'a', 'b']));
expectArraysEqual(
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
['c', 'a', 'b']);
});

it('register string tensor with values and wrong shape throws error', () => {
Expand Down
4 changes: 2 additions & 2 deletions src/backends/webgl/tex_util.ts
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
*/

import {DataId, Tensor} from '../../tensor';
import {DataType, DataValues} from '../../types';
import {BackendValues, DataType} from '../../types';
import * as util from '../../util';

export enum TextureUsage {
Expand All @@ -40,7 +40,7 @@ export interface TextureData {
dtype: DataType;

// Optional.
values?: DataValues;
values?: BackendValues;
texture?: WebGLTexture;
// For complex numbers, the real and imaginary parts are stored as their own
// individual tensors, with a parent joining the two with the
Expand Down
11 changes: 6 additions & 5 deletions src/engine.ts
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode
import {DataId, setTensorTracker, Tensor, Tensor3D, TensorTracker, Variable} from './tensor';
import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
import {getTensorsInContainer} from './tensor_util';
import {DataType, DataValues, PixelData} from './types';
import {BackendValues, DataType, PixelData} from './types';
import * as util from './util';
import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util';

Expand Down Expand Up @@ -830,15 +830,16 @@ export class Engine implements TensorManager, TensorTracker, DataMover {
}

// Forwarding to backend.
write(destBackend: KernelBackend, dataId: DataId, values: DataValues): void {
write(destBackend: KernelBackend, dataId: DataId, values: BackendValues):
void {
const info = this.state.tensorInfo.get(dataId);

const srcBackend = info.backend;
destBackend = destBackend || this.backend;

// Bytes for string tensors are counted when writing.
if (info.dtype === 'string') {
const newBytes = bytesFromStringArray(values as string[]);
const newBytes = bytesFromStringArray(values as Uint8Array[]);
this.state.numBytes += newBytes - info.bytes;
info.bytes = newBytes;
}
Expand All @@ -852,12 +853,12 @@ export class Engine implements TensorManager, TensorTracker, DataMover {
}
destBackend.write(dataId, values);
}
readSync(dataId: DataId): DataValues {
readSync(dataId: DataId): BackendValues {
// Route the read to the correct backend.
const info = this.state.tensorInfo.get(dataId);
return info.backend.readSync(dataId);
}
read(dataId: DataId): Promise<DataValues> {
read(dataId: DataId): Promise<BackendValues> {
// Route the read to the correct backend.
const info = this.state.tensorInfo.get(dataId);
return info.backend.read(dataId);
Expand Down
2 changes: 1 addition & 1 deletion src/engine_test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -349,7 +349,7 @@ describeWithFlags('memory', ALL_ENVS, () => {
const a = tf.tensor([['a', 'bb'], ['c', 'd']]);

expect(tf.memory().numTensors).toBe(1);
expect(tf.memory().numBytes).toBe(10); // 5 letters, each 2 bytes.
expect(tf.memory().numBytes).toBe(5); // 5 letters, each 1 byte in utf8.

a.dispose();

Expand Down
3 changes: 1 addition & 2 deletions src/io/io.ts
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ import {browserHTTPRequest, http, isHTTPScheme} from './http';
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils';
import {fromMemory, withSaveHandler} from './passthrough';
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, StringWeightsManifestEntry, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';
import {loadWeights, weightsLoaderFactory} from './weights_loader';

export {copyModel, listModels, moveModel, removeModel} from './model_management';
Expand Down Expand Up @@ -54,7 +54,6 @@ export {
SaveConfig,
SaveHandler,
SaveResult,
StringWeightsManifestEntry,
WeightGroup,
weightsLoaderFactory,
WeightsManifestConfig,
Expand Down
Loading