@@ -5460,7 +5460,15 @@ python.Execution = class {
5460
5460
const kind = lhs.kindOf(name);
5461
5461
switch (kind) {
5462
5462
case 'i':
5463
- case 's': {
5463
+ case 'f':
5464
+ case 's':
5465
+ case 't': {
5466
+ if (lhs[kind](name) !== rhs[kind](name)) {
5467
+ return false;
5468
+ }
5469
+ break;
5470
+ }
5471
+ case 'ival': {
5464
5472
if (lhs[kind](name) !== rhs[kind](name)) {
5465
5473
return false;
5466
5474
}
@@ -5473,20 +5481,31 @@ python.Execution = class {
5473
5481
}
5474
5482
return true;
5475
5483
});
5484
+ this.registerFunction('torch._C.get_hash', (...args) => {
5485
+ let hash = 0;
5486
+ for (const value of args) {
5487
+ if (typeof value === 'number') {
5488
+ hash += (value | 0);
5489
+ } else if (typeof value === 'string') {
5490
+ hash += (value.length | 0);
5491
+ } else if (Array.isArray(value)) {
5492
+ for (const item of value) {
5493
+ hash += torch._C.get_hash(item);
5494
+ }
5495
+ }
5496
+ }
5497
+ return hash;
5498
+ });
5476
5499
this.registerFunction('torch._C.HashNode', (k) => {
5477
5500
torch._C.AT_ASSERT(k !== null);
5478
5501
let constant_hash = 0;
5479
5502
if (k.kind() === 'prim::Constant') {
5480
5503
const type = k.output().type();
5481
- if (type.isSubtypeOf(torch.NumberType.get()) &&
5482
- k.kindOf('value') === 'i') {
5504
+ if (type.isSubtypeOf(torch.NumberType.get()) && k.kindOf('value') === 'i') {
5483
5505
constant_hash = k.i('value');
5484
- } else if (type.isSubtypeOf(torch.NumberType.get()) &&
5485
- k.kindOf('value') === 'f') {
5506
+ } else if (type.isSubtypeOf(torch.NumberType.get()) && k.kindOf('value') === 'f') {
5486
5507
constant_hash = k.f('value');
5487
- } else if (
5488
- type.isSubtypeOf(torch.NumberType.get()) &&
5489
- k.kindOf('value') === 'c') {
5508
+ } else if (type.isSubtypeOf(torch.NumberType.get()) && k.kindOf('value') === 'c') {
5490
5509
constant_hash = k.c('value');
5491
5510
} else if (type.isSubtypeOf(torch.BoolType.get())) {
5492
5511
constant_hash = k.i('value');
@@ -5874,6 +5893,18 @@ python.Execution = class {
5874
5893
}
5875
5894
return false;
5876
5895
}
5896
+ safeToChangeAliasingRelationship(a, b) {
5897
+ if (torch._C.hasWriters(a) || torch._C.hasWriters(b)) {
5898
+ return false;
5899
+ }
5900
+ return !(torch._C.escapesScope(a) && torch._C.escapesScope(b));
5901
+ }
5902
+ });
5903
+ this.registerFunction('torch._C.hasWriters', () => {
5904
+
5905
+ });
5906
+ this.registerFunction('torch._C.escapesScope', () => {
5907
+
5877
5908
});
5878
5909
this.registerFunction('torch._C.TORCH_INTERNAL_ASSERT', (cond) => {
5879
5910
if (!cond) {
@@ -5910,12 +5941,11 @@ python.Execution = class {
5910
5941
if (args.length === 1 && args[0] instanceof torch.Graph) {
5911
5942
const [graph] = args;
5912
5943
const aliasDb = new torch._C.AliasDb(graph);
5913
- const constants = new Set ();
5944
+ const constants = new torch._C.NodeSet ();
5914
5945
torch._C.ConstantPooling(graph.block(), constants, aliasDb);
5915
5946
} else if (args.length === 3 && args[0] instanceof torch.Block) {
5916
5947
const [block, constants, aliasDb] = args;
5917
5948
for (const node of block.nodes()) {
5918
- // const it = node.next;
5919
5949
if (node.blocks().length > 0) {
5920
5950
for (const block of node.blocks()) {
5921
5951
torch._C.ConstantPooling(block, constants, aliasDb);
@@ -5937,7 +5967,7 @@ python.Execution = class {
5937
5967
node.destroy();
5938
5968
continue;
5939
5969
} else {
5940
- constants.add (node);
5970
+ constants.insert (node);
5941
5971
}
5942
5972
const [first_node] = node.owningGraph().block().nodes();
5943
5973
if (node !== first_node) {
@@ -5948,13 +5978,30 @@ python.Execution = class {
5948
5978
throw new python.Error('Not implemented.');
5949
5979
}
5950
5980
});
5981
+ this.registerFunction('torch._C.handleBlock', () =>{
5982
+ //
5983
+ });
5984
+ this.registerFunction('torch._C.autocastEnabled', () => {
5985
+ return true;
5986
+ });
5987
+ this.registerFunction('torch._C.Autocast', (graph) => {
5988
+ if (torch._C.autocastEnabled()) {
5989
+ const init = null;
5990
+ /* AutocastContext init = {
5991
+ at::autocast::is_autocast_enabled(at::kCUDA),
5992
+ at::autocast::is_autocast_enabled(at::kCPU),
5993
+ at::autocast::get_autocast_dtype(at::kCUDA),
5994
+ at::autocast::get_autocast_dtype(at::kCPU)}; */
5995
+ torch._C.handleBlock(graph.block(), init);
5996
+ }
5997
+ });
5951
5998
this.registerFunction('torch._C.preoptimizeGraph', (graph, disable_autocast) => {
5952
5999
disable_autocast = disable_autocast || false;
5953
6000
torch._C.Inline(graph);
5954
6001
// torch._C.PeepholeOptimize(graph, true);
5955
6002
torch._C.ConstantPropagationImmutableTypes(graph);
5956
6003
if (!disable_autocast) {
5957
- // torch._C.Autocast(graph);
6004
+ torch._C.Autocast(graph);
5958
6005
}
5959
6006
torch._C.ConstantPooling(graph);
5960
6007
});
@@ -8355,7 +8402,8 @@ python.Execution = class {
8355
8402
} else if (rhs instanceof torch.UnionType) {
8356
8403
throw new python.Error('Not implemented.');
8357
8404
}
8358
- return super.isSubtypeOf(rhs);
8405
+ // return super.isSubtypeOf(rhs);
8406
+ return torch.Type.prototype.isSubtypeOf.call(this, rhs);
8359
8407
}
8360
8408
containedTypes() {
8361
8409
return [this._contained];
@@ -10135,9 +10183,7 @@ python.Execution = class {
10135
10183
return this._block.addInput(name);
10136
10184
}
10137
10185
insertNode(node) {
10138
- if (!this._insert_before.inBlockList()) {
10139
- throw new python.Error('Invalid insert point.');
10140
- }
10186
+ torch._C.AT_ASSERT(this._insert_before.inBlockList());
10141
10187
return node.insertBefore(this._insert_before);
10142
10188
}
10143
10189
insertConstant(val, loc, scope) {
@@ -10707,6 +10753,13 @@ python.Execution = class {
10707
10753
}
10708
10754
this._graph.freeNode(this);
10709
10755
}
10756
+ replaceAllUsesWith(n) {
10757
+ torch._C.AT_ASSERT(this.outputs().length === n.outputs().length);
10758
+ const nOutputs = this.outputs().length;
10759
+ for (let i = 0; i < nOutputs; i++) {
10760
+ this.outputs()[i].replaceAllUsesWith(n.outputs()[i]);
10761
+ }
10762
+ }
10710
10763
s_(name, value) {
10711
10764
this._values.set(name, [value, 's']);
10712
10765
return this;
@@ -10793,6 +10846,17 @@ python.Execution = class {
10793
10846
}
10794
10847
out.write(']');
10795
10848
}
10849
+ printTypeList(out, items) {
10850
+ out.write('[');
10851
+ for (let i = 0; i < items.length; i++) {
10852
+ const item = items[i];
10853
+ if (i++ > 0) {
10854
+ out.write(', ');
10855
+ }
10856
+ out.write(item.str());
10857
+ }
10858
+ out.write(']');
10859
+ }
10796
10860
printAttrValue(out, name) {
10797
10861
const kind = this.kindOf(name);
10798
10862
switch (kind) {
@@ -10809,6 +10873,7 @@ python.Execution = class {
10809
10873
case 'ts': out.write('[<Tensors>]'); break;
10810
10874
case 'g': out.write('[<Graph>]'); break;
10811
10875
case 'gs': out.write('[<Graphs>]'); break;
10876
+ case 'tys': this.printTypeList(out, this.tys(name)); break;
10812
10877
default: throw new python.Error(`Unknown attribute kind '${kind}'.`);
10813
10878
}
10814
10879
}
@@ -11020,6 +11085,9 @@ python.Execution = class {
11020
11085
throw new python.Error('Unsupported type.');
11021
11086
}
11022
11087
}
11088
+ isNone() {
11089
+ return this.tag === 'None';
11090
+ }
11023
11091
isBool() {
11024
11092
return this.tag === 'Bool';
11025
11093
}
@@ -11055,6 +11123,29 @@ python.Execution = class {
11055
11123
}
11056
11124
throw new python.Error('Expected int.');
11057
11125
}
11126
+ isString() {
11127
+ return this.tag === 'String';
11128
+ }
11129
+ toString() {
11130
+ return this.value;
11131
+ }
11132
+ equals(rhs) {
11133
+ const lhs = this;
11134
+ switch (lhs.tag) {
11135
+ case 'None': return rhs.isNone();
11136
+ case 'Bool': return rhs.isBool() && lhs.toBool() === rhs.toBool();
11137
+ case 'Int': return rhs.isInt() && lhs.toInt() === rhs.toInt();
11138
+ case 'Double': return rhs.isDouble() && lhs.toDouble() === rhs.toDouble();
11139
+ case 'String': return rhs.isString() && lhs.toString() === rhs.toString();
11140
+ case 'Tensor': return rhs.isTensor() && lhs.toTensor() === rhs.toTensor();
11141
+ case 'Object': return rhs.isObject() && lhs.toObject() === rhs.toObject();
11142
+ default: throw new python.Error(`IValue.equals() not implemented for '${lhs.tag}.`);
11143
+ }
11144
+ }
11145
+ is(rhs) {
11146
+ const lhs = this;
11147
+ return lhs.equals(rhs);
11148
+ }
11058
11149
});
11059
11150
this.registerFunction('torch._C.indent', (out, level) => {
11060
11151
for (let i = 0; i < level; i++) {
@@ -12945,7 +13036,7 @@ python.Execution = class {
12945
13036
const node = v.node();
12946
13037
const type = v.type();
12947
13038
if (type.isSubtypeOf(torch.TensorType.get())) {
12948
- return node.t('value');
13039
+ return new torch._C.IValue( node.t('value'), 'Tensor ');
12949
13040
} else if (type.isSubtypeOf(torch.BoolType.get())) {
12950
13041
return new torch._C.IValue(Boolean(node.i('value'), 'Bool'));
12951
13042
} else if (type.isSubtypeOf(torch.NumberType.get()) && node.kindOf('value') === 'i') {
@@ -12988,7 +13079,7 @@ python.Execution = class {
12988
13079
return enum_val;
12989
13080
} else if (type instanceof torch.ClassType && !type.is_module()) {
12990
13081
const class_val = node.ival('value');
12991
- return class_val;
13082
+ return new torch._C.IValue( class_val, 'Object') ;
12992
13083
}
12993
13084
throw new python.Error('Unsupported constant literal.');
12994
13085
});
@@ -14103,6 +14194,9 @@ python.Execution = class {
14103
14194
} else {
14104
14195
throw new python.Error(`Unrecognized statement kind '${stmt.__class__.__name__}'.`);
14105
14196
}
14197
+ if (this.exit_blocks.has(this.environment_stack.block())) {
14198
+ return;
14199
+ }
14106
14200
}
14107
14201
}
14108
14202
emitWith(stmt) {
@@ -16025,11 +16119,34 @@ python.Execution = class {
16025
16119
torch._C.inlineConsecutiveIfs(graph.block());
16026
16120
torch._C.convertWithBlocksToEnterExitNodes(graph);
16027
16121
});
16028
- this.registerFunction('torch._C.normalizeRSub', (/* iter */) => {
16122
+ this.registerFunction('torch._C.normalizeRSub', (iter) => {
16123
+ if (iter.kind() === 'aten::rsub' && iter.schema() && iter.schema().overload === 'Tensor') {
16124
+ const args = iter.inputs();
16125
+ const newSub = iter.replaceWithNewSymbol('aten::sub');
16126
+ newSub.replaceInput(0, args[1]);
16127
+ newSub.replaceInput(1, args[0]);
16128
+ iter.destroyCurrent();
16129
+ return true;
16130
+ }
16131
+ return false;
16029
16132
});
16030
16133
this.registerFunction('torch._C.normalizeOpAliases', (/* iter */) => {
16031
16134
});
16032
- this.registerFunction('torch._C.normalizeIsBool', (/* iter */) => {
16135
+ this.registerFunction('torch._C.normalizeIsBool', (iter) => {
16136
+ const args = iter.inputs();
16137
+ if (args.length === 2 && args[0].type() === torch.BoolType.get() && args[1].type() === torch.BoolType.get()) {
16138
+ if (iter.kind() === 'aten::__is__') {
16139
+ iter.replaceWithNewSymbol('aten::eq');
16140
+ iter.destroyCurrent();
16141
+ return true;
16142
+ }
16143
+ if (iter.kind() === 'aten::__isnot__') {
16144
+ iter.replaceWithNewSymbol('aten::ne');
16145
+ iter.destroyCurrent();
16146
+ return true;
16147
+ }
16148
+ }
16149
+ return false;
16033
16150
});
16034
16151
this.registerFunction('torch._C.NormalizeOps', (block) => {
16035
16152
for (const it of block.nodes()) {
0 commit comments