-
Notifications
You must be signed in to change notification settings - Fork 8
/
tensor.js
141 lines (130 loc) · 3.24 KB
/
tensor.js
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
'use strict';
/**
 * Count how many elements a tensor of the given shape holds.
 * The count is the product of every dimension; an empty shape yields 1.
 * @param {Array} shape
 * @return {Number}
 */
export function sizeOfShape(shape) {
  let elementCount = 1;
  shape.forEach((dimension) => {
    elementCount *= dimension;
  });
  return elementCount;
}
/**
 * Tensor: the multidimensional array.
 *
 * Elements live in the flat array `data`. `strides` maps a location (one
 * coordinate per axis) to a flat index in row-major order: the last axis has
 * stride 1 and each earlier axis has the stride of the next axis multiplied
 * by that axis' extent.
 */
export class Tensor {
  /**
   * Construct a Tensor object
   * @param {Array} shape The extent of each axis; it is copied.
   * @param {Array} [data] The flat element values; copied when given,
   *     otherwise the tensor is filled with zeros.
   * @throws {Error} When data is given and its length does not equal the
   *     number of elements implied by shape.
   */
  constructor(shape, data = undefined) {
    const size = sizeOfShape(shape);
    if (data !== undefined) {
      if (size !== data.length) {
        throw new Error(`The length of data ${data.length} is invalid, expected ${size}.`);
      }
      // Copy the data so later changes to the caller's array do not leak in.
      this.data = data.slice();
    } else {
      this.data = new Array(size).fill(0);
    }
    // Copy the shape.
    this.shape = shape.slice();
    // Calculate the strides.
    const rank = this.shape.length;
    this.strides = new Array(rank);
    // A rank-0 tensor has no axes; writing strides[rank - 1] there would
    // create a stray "-1" property on the array.
    if (rank > 0) {
      this.strides[rank - 1] = 1;
      for (let i = rank - 2; i >= 0; --i) {
        this.strides[i] = this.strides[i + 1] * this.shape[i + 1];
      }
    }
  }

  /**
   * The number of axes.
   * @return {Number}
   */
  get rank() {
    return this.shape.length;
  }

  /**
   * The number of elements held in the flat array.
   * @return {Number}
   */
  get size() {
    return this.data.length;
  }

  /**
   * Get index in the flat array given the location.
   * @param {Array} location One coordinate per axis.
   * @return {Number}
   * @throws {Error} When the location length differs from the rank, or a
   *     coordinate is not an integer within [0, shape[axis]).
   */
  indexFromLocation(location) {
    if (location.length !== this.rank) {
      throw new Error(`The location length ${location.length} is not equal to rank ${this.rank}.`);
    }
    let index = 0;
    for (let i = 0; i < this.rank; ++i) {
      const coordinate = location[i];
      // Negative or fractional coordinates must be rejected as well: they
      // would otherwise alias another element or fall outside the data.
      if (!Number.isInteger(coordinate) || coordinate < 0 || coordinate >= this.shape[i]) {
        throw new Error(`The location value ${location[i]} at axis ${i} is invalid.`);
      }
      index += this.strides[i] * coordinate;
    }
    return index;
  }

  /**
   * Get location from the index of the flat array.
   * @param {Number} index
   * @return {Array}
   * @throws {Error} When index is not an integer within [0, size).
   */
  locationFromIndex(index) {
    if (!Number.isInteger(index) || index < 0 || index >= this.size) {
      throw new Error('The index is invalid.');
    }
    const location = new Array(this.rank);
    // Peel off one axis at a time, carrying the remainder to the next axis.
    let remainder = index;
    for (let i = 0; i < location.length; ++i) {
      location[i] = Math.floor(remainder / this.strides[i]);
      remainder -= location[i] * this.strides[i];
    }
    return location;
  }

  /**
   * Set value given the location.
   * @param {Array} location
   * @param {Number} value
   * @throws {Error} When the location is invalid (see indexFromLocation).
   */
  setValueByLocation(location, value) {
    this.data[this.indexFromLocation(location)] = value;
  }

  /**
   * Get value given the location.
   * @param {Array} location
   * @return {Number}
   * @throws {Error} When the location is invalid (see indexFromLocation).
   */
  getValueByLocation(location) {
    return this.data[this.indexFromLocation(location)];
  }

  /**
   * Set value given the index.
   * @param {Number} index
   * @param {Number} value
   * @throws {Error} When index is not an integer within [0, size).
   */
  setValueByIndex(index, value) {
    if (!Number.isInteger(index) || index < 0 || index >= this.size) {
      throw new Error('The index is invalid.');
    }
    this.data[index] = value;
  }

  /**
   * Get value given the index.
   * @param {Number} index
   * @return {Number}
   * @throws {Error} When index is not an integer within [0, size).
   */
  getValueByIndex(index) {
    if (!Number.isInteger(index) || index < 0 || index >= this.size) {
      throw new Error('The index is invalid.');
    }
    return this.data[index];
  }
}
/**
 * Scalar: a convenience subclass that produces a one-element Tensor.
 */
export class Scalar extends Tensor {
  /**
   * Build a Tensor of shape [1] whose only element is the given value.
   * @param {Number} value
   */
  constructor(value) {
    const shape = [1];
    const data = [value];
    super(shape, data);
  }
}