Skip to content

Commit

Permalink
Update pytorch.js (#543)
Browse files Browse the repository at this point in the history
  • Loading branch information
lutzroeder committed Apr 2, 2023
1 parent bc6a323 commit 4eb4d21
Showing 1 changed file with 12 additions and 7 deletions.
19 changes: 12 additions & 7 deletions source/pytorch.js
Original file line number Diff line number Diff line change
Expand Up @@ -3532,13 +3532,18 @@ pytorch.Utility = class {
if (value && value.__class__ && value.__class__.__module__ === 'datetime' && value.__class__.__name__ === 'datetime') {
continue;
}
if (value && Number.isInteger(value.epoch) && value.state_dict) {
target[key] = value.state_dict;
continue;
}
if ((key.startsWith('dico_') && Object(value) === value) ||
(key.startsWith('best_metrics') && Object(value) === value) ||
(key === 'args' && Object(value) === value) ||
(key.startsWith('params') && Object(value) === value && (value.id2lang || value.lang2id)) ||
(key.startsWith('spk_dict_') && Object(value) === value && Object.keys(value).length === 0) ||
(key === 'blk_det') ||
(key === 'random_state')) {
(key === 'random_state') ||
(key === 'train_cfg' || key === 'test_cfg' || key === '_is_full_backward_hook')) {
continue;
}
target[key] = value;
Expand Down Expand Up @@ -3588,22 +3593,22 @@ pytorch.Utility = class {
if (Object(obj) !== obj) {
return null;
}
const map = new Map(Object.keys(obj).map((key) => [ key, obj[key] ]));
const map = new Map(Object.entries(obj).map((entry) => [ entry[0], entry[1] ]));
if (validate(map)) {
return map;
}
map.clear();
for (const key of Object.keys(obj)) {
const value = flatten(obj[key]);
const target = new Map();
for (const entry of map) {
const value = flatten(entry[1]);
if (value && value instanceof Map) {
for (const pair of value) {
map.set(key + '.' + pair[0], pair[1]);
target.set(entry[0] + '.' + pair[0], pair[1]);
}
continue;
}
return null;
}
return map;
return target;
};
if (!obj) {
return null;
Expand Down

0 comments on commit 4eb4d21

Please sign in to comment.