Skip to content

Commit 1b16826

Browse files
committed
Update pytorch.js (#842)
1 parent a021521 commit 1b16826

File tree

1 file changed

+137
-20
lines changed

1 file changed

+137
-20
lines changed

source/python.js

+137-20
Original file line numberDiff line numberDiff line change
@@ -5460,7 +5460,15 @@ python.Execution = class {
54605460
const kind = lhs.kindOf(name);
54615461
switch (kind) {
54625462
case 'i':
5463-
case 's': {
5463+
case 'f':
5464+
case 's':
5465+
case 't': {
5466+
if (lhs[kind](name) !== rhs[kind](name)) {
5467+
return false;
5468+
}
5469+
break;
5470+
}
5471+
case 'ival': {
54645472
if (lhs[kind](name) !== rhs[kind](name)) {
54655473
return false;
54665474
}
@@ -5473,20 +5481,31 @@ python.Execution = class {
54735481
}
54745482
return true;
54755483
});
5484+
this.registerFunction('torch._C.get_hash', (...args) => {
5485+
let hash = 0;
5486+
for (const value of args) {
5487+
if (typeof value === 'number') {
5488+
hash += (value | 0);
5489+
} else if (typeof value === 'string') {
5490+
hash += (value.length | 0);
5491+
} else if (Array.isArray(value)) {
5492+
for (const item of value) {
5493+
hash += torch._C.get_hash(item);
5494+
}
5495+
}
5496+
}
5497+
return hash;
5498+
});
54765499
this.registerFunction('torch._C.HashNode', (k) => {
54775500
torch._C.AT_ASSERT(k !== null);
54785501
let constant_hash = 0;
54795502
if (k.kind() === 'prim::Constant') {
54805503
const type = k.output().type();
5481-
if (type.isSubtypeOf(torch.NumberType.get()) &&
5482-
k.kindOf('value') === 'i') {
5504+
if (type.isSubtypeOf(torch.NumberType.get()) && k.kindOf('value') === 'i') {
54835505
constant_hash = k.i('value');
5484-
} else if (type.isSubtypeOf(torch.NumberType.get()) &&
5485-
k.kindOf('value') === 'f') {
5506+
} else if (type.isSubtypeOf(torch.NumberType.get()) && k.kindOf('value') === 'f') {
54865507
constant_hash = k.f('value');
5487-
} else if (
5488-
type.isSubtypeOf(torch.NumberType.get()) &&
5489-
k.kindOf('value') === 'c') {
5508+
} else if (type.isSubtypeOf(torch.NumberType.get()) && k.kindOf('value') === 'c') {
54905509
constant_hash = k.c('value');
54915510
} else if (type.isSubtypeOf(torch.BoolType.get())) {
54925511
constant_hash = k.i('value');
@@ -5874,6 +5893,18 @@ python.Execution = class {
58745893
}
58755894
return false;
58765895
}
5896+
safeToChangeAliasingRelationship(a, b) {
5897+
if (torch._C.hasWriters(a) || torch._C.hasWriters(b)) {
5898+
return false;
5899+
}
5900+
return !(torch._C.escapesScope(a) && torch._C.escapesScope(b));
5901+
}
5902+
});
5903+
this.registerFunction('torch._C.hasWriters', () => {
5904+
5905+
});
5906+
this.registerFunction('torch._C.escapesScope', () => {
5907+
58775908
});
58785909
this.registerFunction('torch._C.TORCH_INTERNAL_ASSERT', (cond) => {
58795910
if (!cond) {
@@ -5910,12 +5941,11 @@ python.Execution = class {
59105941
if (args.length === 1 && args[0] instanceof torch.Graph) {
59115942
const [graph] = args;
59125943
const aliasDb = new torch._C.AliasDb(graph);
5913-
const constants = new Set();
5944+
const constants = new torch._C.NodeSet();
59145945
torch._C.ConstantPooling(graph.block(), constants, aliasDb);
59155946
} else if (args.length === 3 && args[0] instanceof torch.Block) {
59165947
const [block, constants, aliasDb] = args;
59175948
for (const node of block.nodes()) {
5918-
// const it = node.next;
59195949
if (node.blocks().length > 0) {
59205950
for (const block of node.blocks()) {
59215951
torch._C.ConstantPooling(block, constants, aliasDb);
@@ -5937,7 +5967,7 @@ python.Execution = class {
59375967
node.destroy();
59385968
continue;
59395969
} else {
5940-
constants.add(node);
5970+
constants.insert(node);
59415971
}
59425972
const [first_node] = node.owningGraph().block().nodes();
59435973
if (node !== first_node) {
@@ -5948,13 +5978,30 @@ python.Execution = class {
59485978
throw new python.Error('Not implemented.');
59495979
}
59505980
});
5981+
this.registerFunction('torch._C.handleBlock', () =>{
5982+
//
5983+
});
5984+
this.registerFunction('torch._C.autocastEnabled', () => {
5985+
return true;
5986+
});
5987+
this.registerFunction('torch._C.Autocast', (graph) => {
5988+
if (torch._C.autocastEnabled()) {
5989+
const init = null;
5990+
/* AutocastContext init = {
5991+
at::autocast::is_autocast_enabled(at::kCUDA),
5992+
at::autocast::is_autocast_enabled(at::kCPU),
5993+
at::autocast::get_autocast_dtype(at::kCUDA),
5994+
at::autocast::get_autocast_dtype(at::kCPU)}; */
5995+
torch._C.handleBlock(graph.block(), init);
5996+
}
5997+
});
59515998
this.registerFunction('torch._C.preoptimizeGraph', (graph, disable_autocast) => {
59525999
disable_autocast = disable_autocast || false;
59536000
torch._C.Inline(graph);
59546001
// torch._C.PeepholeOptimize(graph, true);
59556002
torch._C.ConstantPropagationImmutableTypes(graph);
59566003
if (!disable_autocast) {
5957-
// torch._C.Autocast(graph);
6004+
torch._C.Autocast(graph);
59586005
}
59596006
torch._C.ConstantPooling(graph);
59606007
});
@@ -8355,7 +8402,8 @@ python.Execution = class {
83558402
} else if (rhs instanceof torch.UnionType) {
83568403
throw new python.Error('Not implemented.');
83578404
}
8358-
return super.isSubtypeOf(rhs);
8405+
// return super.isSubtypeOf(rhs);
8406+
return torch.Type.prototype.isSubtypeOf.call(this, rhs);
83598407
}
83608408
containedTypes() {
83618409
return [this._contained];
@@ -10135,9 +10183,7 @@ python.Execution = class {
1013510183
return this._block.addInput(name);
1013610184
}
1013710185
insertNode(node) {
10138-
if (!this._insert_before.inBlockList()) {
10139-
throw new python.Error('Invalid insert point.');
10140-
}
10186+
torch._C.AT_ASSERT(this._insert_before.inBlockList());
1014110187
return node.insertBefore(this._insert_before);
1014210188
}
1014310189
insertConstant(val, loc, scope) {
@@ -10707,6 +10753,13 @@ python.Execution = class {
1070710753
}
1070810754
this._graph.freeNode(this);
1070910755
}
10756+
replaceAllUsesWith(n) {
10757+
torch._C.AT_ASSERT(this.outputs().length === n.outputs().length);
10758+
const nOutputs = this.outputs().length;
10759+
for (let i = 0; i < nOutputs; i++) {
10760+
this.outputs()[i].replaceAllUsesWith(n.outputs()[i]);
10761+
}
10762+
}
1071010763
s_(name, value) {
1071110764
this._values.set(name, [value, 's']);
1071210765
return this;
@@ -10793,6 +10846,17 @@ python.Execution = class {
1079310846
}
1079410847
out.write(']');
1079510848
}
10849+
printTypeList(out, items) {
10850+
out.write('[');
10851+
for (let i = 0; i < items.length; i++) {
10852+
const item = items[i];
10853+
if (i++ > 0) {
10854+
out.write(', ');
10855+
}
10856+
out.write(item.str());
10857+
}
10858+
out.write(']');
10859+
}
1079610860
printAttrValue(out, name) {
1079710861
const kind = this.kindOf(name);
1079810862
switch (kind) {
@@ -10809,6 +10873,7 @@ python.Execution = class {
1080910873
case 'ts': out.write('[<Tensors>]'); break;
1081010874
case 'g': out.write('[<Graph>]'); break;
1081110875
case 'gs': out.write('[<Graphs>]'); break;
10876+
case 'tys': this.printTypeList(out, this.tys(name)); break;
1081210877
default: throw new python.Error(`Unknown attribute kind '${kind}'.`);
1081310878
}
1081410879
}
@@ -11020,6 +11085,9 @@ python.Execution = class {
1102011085
throw new python.Error('Unsupported type.');
1102111086
}
1102211087
}
11088+
isNone() {
11089+
return this.tag === 'None';
11090+
}
1102311091
isBool() {
1102411092
return this.tag === 'Bool';
1102511093
}
@@ -11055,6 +11123,29 @@ python.Execution = class {
1105511123
}
1105611124
throw new python.Error('Expected int.');
1105711125
}
11126+
isString() {
11127+
return this.tag === 'String';
11128+
}
11129+
toString() {
11130+
return this.value;
11131+
}
11132+
equals(rhs) {
11133+
const lhs = this;
11134+
switch (lhs.tag) {
11135+
case 'None': return rhs.isNone();
11136+
case 'Bool': return rhs.isBool() && lhs.toBool() === rhs.toBool();
11137+
case 'Int': return rhs.isInt() && lhs.toInt() === rhs.toInt();
11138+
case 'Double': return rhs.isDouble() && lhs.toDouble() === rhs.toDouble();
11139+
case 'String': return rhs.isString() && lhs.toString() === rhs.toString();
11140+
case 'Tensor': return rhs.isTensor() && lhs.toTensor() === rhs.toTensor();
11141+
case 'Object': return rhs.isObject() && lhs.toObject() === rhs.toObject();
11142+
default: throw new python.Error(`IValue.equals() not implemented for '${lhs.tag}.`);
11143+
}
11144+
}
11145+
is(rhs) {
11146+
const lhs = this;
11147+
return lhs.equals(rhs);
11148+
}
1105811149
});
1105911150
this.registerFunction('torch._C.indent', (out, level) => {
1106011151
for (let i = 0; i < level; i++) {
@@ -12945,7 +13036,7 @@ python.Execution = class {
1294513036
const node = v.node();
1294613037
const type = v.type();
1294713038
if (type.isSubtypeOf(torch.TensorType.get())) {
12948-
return node.t('value');
13039+
return new torch._C.IValue(node.t('value'), 'Tensor');
1294913040
} else if (type.isSubtypeOf(torch.BoolType.get())) {
1295013041
return new torch._C.IValue(Boolean(node.i('value'), 'Bool'));
1295113042
} else if (type.isSubtypeOf(torch.NumberType.get()) && node.kindOf('value') === 'i') {
@@ -12988,7 +13079,7 @@ python.Execution = class {
1298813079
return enum_val;
1298913080
} else if (type instanceof torch.ClassType && !type.is_module()) {
1299013081
const class_val = node.ival('value');
12991-
return class_val;
13082+
return new torch._C.IValue(class_val, 'Object');
1299213083
}
1299313084
throw new python.Error('Unsupported constant literal.');
1299413085
});
@@ -14103,6 +14194,9 @@ python.Execution = class {
1410314194
} else {
1410414195
throw new python.Error(`Unrecognized statement kind '${stmt.__class__.__name__}'.`);
1410514196
}
14197+
if (this.exit_blocks.has(this.environment_stack.block())) {
14198+
return;
14199+
}
1410614200
}
1410714201
}
1410814202
emitWith(stmt) {
@@ -16025,11 +16119,34 @@ python.Execution = class {
1602516119
torch._C.inlineConsecutiveIfs(graph.block());
1602616120
torch._C.convertWithBlocksToEnterExitNodes(graph);
1602716121
});
16028-
this.registerFunction('torch._C.normalizeRSub', (/* iter */) => {
16122+
this.registerFunction('torch._C.normalizeRSub', (iter) => {
16123+
if (iter.kind() === 'aten::rsub' && iter.schema() && iter.schema().overload === 'Tensor') {
16124+
const args = iter.inputs();
16125+
const newSub = iter.replaceWithNewSymbol('aten::sub');
16126+
newSub.replaceInput(0, args[1]);
16127+
newSub.replaceInput(1, args[0]);
16128+
iter.destroyCurrent();
16129+
return true;
16130+
}
16131+
return false;
1602916132
});
1603016133
this.registerFunction('torch._C.normalizeOpAliases', (/* iter */) => {
1603116134
});
16032-
this.registerFunction('torch._C.normalizeIsBool', (/* iter */) => {
16135+
this.registerFunction('torch._C.normalizeIsBool', (iter) => {
16136+
const args = iter.inputs();
16137+
if (args.length === 2 && args[0].type() === torch.BoolType.get() && args[1].type() === torch.BoolType.get()) {
16138+
if (iter.kind() === 'aten::__is__') {
16139+
iter.replaceWithNewSymbol('aten::eq');
16140+
iter.destroyCurrent();
16141+
return true;
16142+
}
16143+
if (iter.kind() === 'aten::__isnot__') {
16144+
iter.replaceWithNewSymbol('aten::ne');
16145+
iter.destroyCurrent();
16146+
return true;
16147+
}
16148+
}
16149+
return false;
1603316150
});
1603416151
this.registerFunction('torch._C.NormalizeOps', (block) => {
1603516152
for (const it of block.nodes()) {

0 commit comments

Comments
 (0)