diff --git a/source/pytorch.js b/source/pytorch.js index 11d713dd2f..8bfe8d2587 100644 --- a/source/pytorch.js +++ b/source/pytorch.js @@ -3532,13 +3532,18 @@ pytorch.Utility = class { if (value && value.__class__ && value.__class__.__module__ === 'datetime' && value.__class__.__name__ === 'datetime') { continue; } + if (value && Number.isInteger(value.epoch) && value.state_dict) { + target[key] = value.state_dict; + continue; + } if ((key.startsWith('dico_') && Object(value) === value) || (key.startsWith('best_metrics') && Object(value) === value) || (key === 'args' && Object(value) === value) || (key.startsWith('params') && Object(value) === value && (value.id2lang || value.lang2id)) || (key.startsWith('spk_dict_') && Object(value) === value && Object.keys(value).length === 0) || (key === 'blk_det') || - (key === 'random_state')) { + (key === 'random_state') || + (key === 'train_cfg' || key === 'test_cfg' || key === '_is_full_backward_hook')) { continue; } target[key] = value; @@ -3588,22 +3593,22 @@ pytorch.Utility = class { if (Object(obj) !== obj) { return null; } - const map = new Map(Object.keys(obj).map((key) => [ key, obj[key] ])); + const map = new Map(Object.entries(obj).map((entry) => [ entry[0], entry[1] ])); if (validate(map)) { return map; } - map.clear(); - for (const key of Object.keys(obj)) { - const value = flatten(obj[key]); + const target = new Map(); + for (const entry of map) { + const value = flatten(entry[1]); if (value && value instanceof Map) { for (const pair of value) { - map.set(key + '.' + pair[0], pair[1]); + target.set(entry[0] + '.' + pair[0], pair[1]); } continue; } return null; } - return map; + return target; }; if (!obj) { return null;