Skip to content
This repository was archived by the owner on Aug 15, 2019. It is now read-only.

Commit b5854a2

Browse files
authored
Serialize string tensors as encoded (raw bytes) (#1816)
FEATURE To align with TensorFlow Python/C++, this PR changes the way we serialize strings in the weights format and in our engine. - We store the underlying encoded string bytes as `Uint8Array`. Thus, a string tensor (which holds multiple strings) is backed by `Uint8Array[]`. - To keep backwards compatibility, `tensor.data()` still returns `string[]`, which means that we try to UTF-8 decode each string. - In the weights format, string bytes are kept unchanged, with their original encoding. Each string is prefixed with 4 bytes denoting the number of bytes in the string. Thus, a string tensor of 3 values will be encoded as `[4 bytes][string1...][4 bytes][string2...][4 bytes][string3...]`. - Add `util.encodeString(text: string, encoding?: string)` and `util.decodeString(bytes: Uint8Array, encoding?: string)`, along with the respective `Platform` methods. - Add `tensor.bytes()`, which gives the underlying bytes of the data: `Uint8Array` for any numeric tensor, and `Uint8Array[]` for string tensors. Corresponding change in tfjs-converter: tensorflow/tfjs-converter#386
1 parent 4cb7fda commit b5854a2

24 files changed

+590
-227
lines changed

src/backends/backend.ts

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import {Conv2DInfo, Conv3DInfo} from '../ops/conv_util';
1919
import {Activation} from '../ops/fused_util';
2020
import {Backend, DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../tensor';
21-
import {DataType, DataValues, PixelData, Rank, ShapeMap} from '../types';
21+
import {BackendValues, DataType, PixelData, Rank, ShapeMap} from '../types';
2222

2323
export const EPSILON_FLOAT32 = 1e-7;
2424
export const EPSILON_FLOAT16 = 1e-4;
@@ -31,10 +31,10 @@ export interface BackendTimingInfo {
3131
}
3232

3333
export interface TensorStorage {
34-
read(dataId: DataId): Promise<DataValues>;
35-
readSync(dataId: DataId): DataValues;
34+
read(dataId: DataId): Promise<BackendValues>;
35+
readSync(dataId: DataId): BackendValues;
3636
disposeData(dataId: DataId): void;
37-
write(dataId: DataId, values: DataValues): void;
37+
write(dataId: DataId, values: BackendValues): void;
3838
fromPixels(
3939
pixels: PixelData|ImageData|HTMLImageElement|HTMLCanvasElement|
4040
HTMLVideoElement,
@@ -92,16 +92,16 @@ export class KernelBackend implements TensorStorage, Backend, BackendTimer {
9292
time(f: () => void): Promise<BackendTimingInfo> {
9393
throw new Error('Not yet implemented.');
9494
}
95-
read(dataId: object): Promise<DataValues> {
95+
read(dataId: object): Promise<BackendValues> {
9696
throw new Error('Not yet implemented.');
9797
}
98-
readSync(dataId: object): DataValues {
98+
readSync(dataId: object): BackendValues {
9999
throw new Error('Not yet implemented.');
100100
}
101101
disposeData(dataId: object): void {
102102
throw new Error('Not yet implemented.');
103103
}
104-
write(dataId: object, values: DataValues): void {
104+
write(dataId: object, values: BackendValues): void {
105105
throw new Error('Not yet implemented.');
106106
}
107107
fromPixels(

src/backends/cpu/backend_cpu.ts

Lines changed: 16 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -34,7 +34,7 @@ import * as scatter_nd_util from '../../ops/scatter_nd_util';
3434
import * as selu_util from '../../ops/selu_util';
3535
import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../ops/slice_util';
3636
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D, TensorBuffer} from '../../tensor';
37-
import {DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
37+
import {BackendValues, DataType, DataValues, NumericDataType, PixelData, Rank, ShapeMap, TypedArray, upcastType} from '../../types';
3838
import * as util from '../../util';
3939
import {getArrayFromDType, inferDtype, now, sizeFromShape} from '../../util';
4040
import {BackendTimingInfo, DataStorage, EPSILON_FLOAT32, KernelBackend} from '../backend';
@@ -58,7 +58,7 @@ function mapActivation(
5858
}
5959

6060
interface TensorData<D extends DataType> {
61-
values?: DataTypeMap[D];
61+
values?: BackendValues;
6262
dtype: D;
6363
// For complex numbers, the real and imaginary parts are stored as their own
6464
// individual tensors, with a parent joining the two with the
@@ -116,7 +116,7 @@ export class MathBackendCPU implements KernelBackend {
116116
}
117117
this.data.set(dataId, {dtype});
118118
}
119-
write(dataId: DataId, values: DataValues): void {
119+
write(dataId: DataId, values: BackendValues): void {
120120
if (values == null) {
121121
throw new Error('MathBackendCPU.write(): values can not be null');
122122
}
@@ -186,10 +186,10 @@ export class MathBackendCPU implements KernelBackend {
186186
[pixels.height, pixels.width, numChannels];
187187
return tensor3d(values, outShape, 'int32');
188188
}
189-
async read(dataId: DataId): Promise<DataValues> {
189+
async read(dataId: DataId): Promise<BackendValues> {
190190
return this.readSync(dataId);
191191
}
192-
readSync(dataId: DataId): DataValues {
192+
readSync(dataId: DataId): BackendValues {
193193
const {dtype, complexTensors} = this.data.get(dataId);
194194
if (dtype === 'complex64') {
195195
const realValues =
@@ -202,7 +202,17 @@ export class MathBackendCPU implements KernelBackend {
202202
}
203203

204204
private bufferSync<R extends Rank>(t: Tensor<R>): TensorBuffer<R> {
205-
return buffer(t.shape, t.dtype, this.readSync(t.dataId)) as TensorBuffer<R>;
205+
const data = this.readSync(t.dataId);
206+
let decodedData = data as DataValues;
207+
if (t.dtype === 'string') {
208+
try {
209+
// Decode the bytes into string.
210+
decodedData = (data as Uint8Array[]).map(d => util.decodeString(d));
211+
} catch {
212+
throw new Error('Failed to decode encoded string bytes into utf-8');
213+
}
214+
}
215+
return buffer(t.shape, t.dtype, decodedData) as TensorBuffer<R>;
206216
}
207217

208218
disposeData(dataId: DataId): void {

src/backends/cpu/backend_cpu_test.ts

Lines changed: 23 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -19,10 +19,21 @@ import * as tf from '../../index';
1919
import {describeWithFlags} from '../../jasmine_util';
2020
import {tensor2d} from '../../ops/ops';
2121
import {expectArraysClose, expectArraysEqual} from '../../test_util';
22+
import {decodeString, encodeString} from '../../util';
2223

2324
import {MathBackendCPU} from './backend_cpu';
2425
import {CPU_ENVS} from './backend_cpu_test_registry';
2526

27+
/** Private test util for encoding array of strings in utf-8. */
28+
function encodeStrings(a: string[]): Uint8Array[] {
29+
return a.map(s => encodeString(s));
30+
}
31+
32+
/** Private test util for decoding array of strings in utf-8. */
33+
function decodeStrings(bytes: Uint8Array[]): string[] {
34+
return bytes.map(b => decodeString(b));
35+
}
36+
2637
describeWithFlags('backendCPU', CPU_ENVS, () => {
2738
let backend: MathBackendCPU;
2839
beforeEach(() => {
@@ -36,19 +47,25 @@ describeWithFlags('backendCPU', CPU_ENVS, () => {
3647

3748
it('register empty string tensor and write', () => {
3849
const t = tf.Tensor.make([3], {}, 'string');
39-
backend.write(t.dataId, ['c', 'a', 'b']);
40-
expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']);
50+
backend.write(t.dataId, encodeStrings(['c', 'a', 'b']));
51+
expectArraysEqual(
52+
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
53+
['c', 'a', 'b']);
4154
});
4255

4356
it('register string tensor with values', () => {
4457
const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string');
45-
expectArraysEqual(backend.readSync(t.dataId), ['a', 'b', 'c']);
58+
expectArraysEqual(
59+
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
60+
['a', 'b', 'c']);
4661
});
4762

4863
it('register string tensor with values and overwrite', () => {
4964
const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string');
50-
backend.write(t.dataId, ['c', 'a', 'b']);
51-
expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']);
65+
backend.write(t.dataId, encodeStrings(['c', 'a', 'b']));
66+
expectArraysEqual(
67+
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
68+
['c', 'a', 'b']);
5269
});
5370

5471
it('register string tensor with values and mismatched shape', () => {
@@ -129,7 +146,7 @@ describeWithFlags('memory cpu', CPU_ENVS, () => {
129146
const mem = tf.memory();
130147
expect(mem.numTensors).toBe(2);
131148
expect(mem.numDataBuffers).toBe(2);
132-
expect(mem.numBytes).toBe(6);
149+
expect(mem.numBytes).toBe(5);
133150
expect(mem.unreliable).toBe(true);
134151

135152
const expectedReasonGC =

src/backends/webgl/backend_webgl.ts

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ import {computeFlatOffset, getStridedSlicedInfo, isSliceContinous} from '../../o
3737
import {softmax} from '../../ops/softmax';
3838
import {range, scalar, tensor} from '../../ops/tensor_ops';
3939
import {DataId, Scalar, Tensor, Tensor1D, Tensor2D, Tensor3D, Tensor4D, Tensor5D} from '../../tensor';
40-
import {DataType, DataTypeMap, DataValues, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';
40+
import {BackendValues, DataType, DataTypeMap, NumericDataType, PixelData, Rank, RecursiveArray, ShapeMap, sumOutType, TypedArray, upcastType} from '../../types';
4141
import * as util from '../../util';
4242
import {getArrayFromDType, getTypedArrayFromDType, inferDtype, sizeFromShape} from '../../util';
4343
import {DataStorage, EPSILON_FLOAT16, EPSILON_FLOAT32, KernelBackend} from '../backend';
@@ -339,7 +339,7 @@ export class MathBackendWebGL implements KernelBackend {
339339
return {dataId, shape, dtype};
340340
}
341341

342-
write(dataId: DataId, values: DataValues): void {
342+
write(dataId: DataId, values: BackendValues): void {
343343
if (values == null) {
344344
throw new Error('MathBackendWebGL.write(): values can not be null');
345345
}
@@ -366,7 +366,7 @@ export class MathBackendWebGL implements KernelBackend {
366366
texData.values = values;
367367
}
368368

369-
readSync(dataId: DataId): DataValues {
369+
readSync(dataId: DataId): BackendValues {
370370
const texData = this.texData.get(dataId);
371371
const {values, dtype, complexTensors, slice, shape} = texData;
372372
if (slice != null) {
@@ -403,7 +403,7 @@ export class MathBackendWebGL implements KernelBackend {
403403
return this.convertAndCacheOnCPU(dataId, result);
404404
}
405405

406-
async read(dataId: DataId): Promise<DataValues> {
406+
async read(dataId: DataId): Promise<BackendValues> {
407407
if (this.pendingRead.has(dataId)) {
408408
const subscribers = this.pendingRead.get(dataId);
409409
return new Promise<TypedArray>(resolve => subscribers.push(resolve));
@@ -961,7 +961,9 @@ export class MathBackendWebGL implements KernelBackend {
961961

962962
tile<T extends Tensor>(x: T, reps: number[]): T {
963963
if (x.dtype === 'string') {
964-
const buf = buffer(x.shape, x.dtype, this.readSync(x.dataId) as string[]);
964+
const data = this.readSync(x.dataId) as Uint8Array[];
965+
const decodedData = data.map(d => util.decodeString(d));
966+
const buf = buffer(x.shape, x.dtype, decodedData);
965967
return tile(buf, reps) as T;
966968
}
967969
const program = new TileProgram(x.shape, reps);

src/backends/webgl/backend_webgl_test.ts

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,9 +18,19 @@
1818
import * as tf from '../../index';
1919
import {describeWithFlags} from '../../jasmine_util';
2020
import {expectArraysClose, expectArraysEqual} from '../../test_util';
21+
import {decodeString, encodeString} from '../../util';
22+
2123
import {MathBackendWebGL, WebGLMemoryInfo} from './backend_webgl';
2224
import {WEBGL_ENVS} from './backend_webgl_test_registry';
2325

26+
function encodeStrings(a: string[]): Uint8Array[] {
27+
return a.map(s => encodeString(s));
28+
}
29+
30+
function decodeStrings(bytes: Uint8Array[]): string[] {
31+
return bytes.map(b => decodeString(b));
32+
}
33+
2434
describeWithFlags('lazy packing and unpacking', WEBGL_ENVS, () => {
2535
let webglLazilyUnpackFlagSaved: boolean;
2636
let webglCpuForwardFlagSaved: boolean;
@@ -126,8 +136,10 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => {
126136
tf.setBackend('test-storage');
127137

128138
const t = tf.Tensor.make([3], {}, 'string');
129-
backend.write(t.dataId, ['c', 'a', 'b']);
130-
expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']);
139+
backend.write(t.dataId, encodeStrings(['c', 'a', 'b']));
140+
expectArraysEqual(
141+
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
142+
['c', 'a', 'b']);
131143
});
132144

133145
it('register string tensor with values', () => {
@@ -136,7 +148,9 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => {
136148
tf.setBackend('test-storage');
137149

138150
const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string');
139-
expectArraysEqual(backend.readSync(t.dataId), ['a', 'b', 'c']);
151+
expectArraysEqual(
152+
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
153+
['a', 'b', 'c']);
140154
});
141155

142156
it('register string tensor with values and overwrite', () => {
@@ -145,8 +159,10 @@ describeWithFlags('backendWebGL', WEBGL_ENVS, () => {
145159
tf.setBackend('test-storage');
146160

147161
const t = tf.Tensor.make([3], {values: ['a', 'b', 'c']}, 'string');
148-
backend.write(t.dataId, ['c', 'a', 'b']);
149-
expectArraysEqual(backend.readSync(t.dataId), ['c', 'a', 'b']);
162+
backend.write(t.dataId, encodeStrings(['c', 'a', 'b']));
163+
expectArraysEqual(
164+
decodeStrings(backend.readSync(t.dataId) as Uint8Array[]),
165+
['c', 'a', 'b']);
150166
});
151167

152168
it('register string tensor with values and wrong shape throws error', () => {

src/backends/webgl/tex_util.ts

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
*/
1717

1818
import {DataId, Tensor} from '../../tensor';
19-
import {DataType, DataValues} from '../../types';
19+
import {BackendValues, DataType} from '../../types';
2020
import * as util from '../../util';
2121

2222
export enum TextureUsage {
@@ -40,7 +40,7 @@ export interface TextureData {
4040
dtype: DataType;
4141

4242
// Optional.
43-
values?: DataValues;
43+
values?: BackendValues;
4444
texture?: WebGLTexture;
4545
// For complex numbers, the real and imaginary parts are stored as their own
4646
// individual tensors, with a parent joining the two with the

src/engine.ts

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@ import {backpropagateGradients, getFilteredNodesXToY, NamedGradientMap, TapeNode
2222
import {DataId, setTensorTracker, Tensor, Tensor3D, TensorTracker, Variable} from './tensor';
2323
import {GradSaveFunc, NamedTensorMap, NamedVariableMap, TensorContainer} from './tensor_types';
2424
import {getTensorsInContainer} from './tensor_util';
25-
import {DataType, DataValues, PixelData} from './types';
25+
import {BackendValues, DataType, PixelData} from './types';
2626
import * as util from './util';
2727
import {bytesFromStringArray, makeOnesTypedArray, now, sizeFromShape} from './util';
2828

@@ -830,15 +830,16 @@ export class Engine implements TensorManager, TensorTracker, DataMover {
830830
}
831831

832832
// Forwarding to backend.
833-
write(destBackend: KernelBackend, dataId: DataId, values: DataValues): void {
833+
write(destBackend: KernelBackend, dataId: DataId, values: BackendValues):
834+
void {
834835
const info = this.state.tensorInfo.get(dataId);
835836

836837
const srcBackend = info.backend;
837838
destBackend = destBackend || this.backend;
838839

839840
// Bytes for string tensors are counted when writing.
840841
if (info.dtype === 'string') {
841-
const newBytes = bytesFromStringArray(values as string[]);
842+
const newBytes = bytesFromStringArray(values as Uint8Array[]);
842843
this.state.numBytes += newBytes - info.bytes;
843844
info.bytes = newBytes;
844845
}
@@ -852,12 +853,12 @@ export class Engine implements TensorManager, TensorTracker, DataMover {
852853
}
853854
destBackend.write(dataId, values);
854855
}
855-
readSync(dataId: DataId): DataValues {
856+
readSync(dataId: DataId): BackendValues {
856857
// Route the read to the correct backend.
857858
const info = this.state.tensorInfo.get(dataId);
858859
return info.backend.readSync(dataId);
859860
}
860-
read(dataId: DataId): Promise<DataValues> {
861+
read(dataId: DataId): Promise<BackendValues> {
861862
// Route the read to the correct backend.
862863
const info = this.state.tensorInfo.get(dataId);
863864
return info.backend.read(dataId);

src/engine_test.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -349,7 +349,7 @@ describeWithFlags('memory', ALL_ENVS, () => {
349349
const a = tf.tensor([['a', 'bb'], ['c', 'd']]);
350350

351351
expect(tf.memory().numTensors).toBe(1);
352-
expect(tf.memory().numBytes).toBe(10); // 5 letters, each 2 bytes.
352+
expect(tf.memory().numBytes).toBe(5); // 5 letters, each 1 byte in utf8.
353353

354354
a.dispose();
355355

src/io/io.ts

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,7 @@ import {browserHTTPRequest, http, isHTTPScheme} from './http';
2525
import {concatenateArrayBuffers, decodeWeights, encodeWeights, getModelArtifactsInfoForJSON} from './io_utils';
2626
import {fromMemory, withSaveHandler} from './passthrough';
2727
import {getLoadHandlers, getSaveHandlers, registerLoadRouter, registerSaveRouter} from './router_registry';
28-
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, StringWeightsManifestEntry, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';
28+
import {IOHandler, LoadHandler, LoadOptions, ModelArtifacts, ModelJSON, ModelStoreManager, OnProgressCallback, SaveConfig, SaveHandler, SaveResult, WeightGroup, WeightsManifestConfig, WeightsManifestEntry} from './types';
2929
import {loadWeights, weightsLoaderFactory} from './weights_loader';
3030

3131
export {copyModel, listModels, moveModel, removeModel} from './model_management';
@@ -54,7 +54,6 @@ export {
5454
SaveConfig,
5555
SaveHandler,
5656
SaveResult,
57-
StringWeightsManifestEntry,
5857
WeightGroup,
5958
weightsLoaderFactory,
6059
WeightsManifestConfig,

0 commit comments

Comments
 (0)