-
Notifications
You must be signed in to change notification settings - Fork 14
/
Copy pathgraph.py
138 lines (120 loc) · 5.51 KB
/
graph.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
import os, sys, math, random, itertools, heapq
from collections import namedtuple, defaultdict
from functools import partial, reduce
import numpy as np
import IPython
import torch
import torch.nn as nn
import torch.nn.functional as F
from utils import *
from models import TrainableModel, WrapperModel
from datasets import TaskDataset
from task_configs import get_task, task_map, tasks, get_model, RealityTask
from transfers import Transfer, RealityTransfer, get_transfer_name
#from modules.gan_dis import GanDisNet
import pdb
class TaskGraph(TrainableModel):
    """Basic graph that encapsulates a set of edge constraints. Can be saved
    and loaded from directories.

    Nodes are tasks. Each directed edge is a ``Transfer``, or a
    ``RealityTransfer`` when the source node is a ``RealityTask``. Edges are
    indexed in ``self.edge_map`` by the string ``str((src_task.name, dest_task.name))``.
    """

    def __init__(
        self, tasks=tasks, edges=None, edges_exclude=None,
        pretrained=True, finetuned=False,
        reality=None, task_filter=[tasks.segment_semantic],
        freeze_list=[], lazy=False, initialize_from_transfer=True,
    ):
        """Build the transfer graph over ``tasks``.

        Args:
            tasks: candidate tasks (nodes). For every task exposing a ``base``
                attribute, that base task is added as a node as well.
            edges: if not None, only ``(src_task, dest_task)`` pairs contained
                in it become edges.
            edges_exclude: if not None, ``(src_task, dest_task)`` pairs
                contained in it are skipped.
            pretrained, finetuned: forwarded to every non-reality ``Transfer``.
            reality: stored as ``self.reality``; ``self.reality[0]`` is the
                default start node of ``sample_path``. ``None`` means a fresh
                empty list.
            task_filter: tasks removed from ``tasks`` before building edges.
                Only read, never mutated. Its default is evaluated at
                definition time against the module-level ``tasks``, not the
                parameter of the same name.
            freeze_list: ``str((src_name, dest_name))`` keys of edges that are
                kept out of ``self.params``. Only read, never mutated.
            lazy: if True, ``transfer.load_model()`` is not called here.
            initialize_from_transfer: if False, ``transfer.path`` is set to
                None on each non-reality transfer (presumably so no pretrained
                checkpoint is used; confirm in ``Transfer``).
        """
        super().__init__()
        self.tasks = list(set(tasks) - set(task_filter))
        # Add the base task of every derived task, but only once. A base that
        # is already a node would otherwise be visited twice by
        # itertools.product below, creating and loading duplicate transfers.
        for task in list(self.tasks):
            if hasattr(task, "base") and task.base not in self.tasks:
                self.tasks.append(task.base)
        self.edge_list, self.edge_list_exclude = edges, edges_exclude
        self.pretrained, self.finetuned = pretrained, finetuned
        self.edges, self.adj, self.in_adj = [], defaultdict(list), defaultdict(list)
        # Use a fresh list per instance. A shared `[]` default would leak
        # mutations of `self.reality` into every later graph.
        self.edge_map = {}
        self.reality = [] if reality is None else reality
        self.initialize_from_transfer = initialize_from_transfer
        print('Creating graph with tasks:', self.tasks)
        self.params = {}

        # construct transfer graph
        for src_task, dest_task in itertools.product(self.tasks, self.tasks):
            key = (src_task, dest_task)
            if edges is not None and key not in edges: continue
            if edges_exclude is not None and key in edges_exclude: continue
            if src_task == dest_task: continue
            # Reality tasks are only ever edge sources, never destinations.
            if isinstance(dest_task, RealityTask): continue

            if isinstance(src_task, RealityTask):
                # A reality node only links to the tasks listed in its `tasks`.
                if dest_task not in src_task.tasks: continue
                transfer = RealityTransfer(src_task, dest_task)
            else:
                transfer = Transfer(src_task, dest_task,
                    pretrained=pretrained, finetuned=finetuned
                )
                transfer.name = get_transfer_name(transfer)
                if not self.initialize_from_transfer:
                    transfer.path = None
            if transfer.model_type is None:
                continue

            name_key = str((src_task.name, dest_task.name))
            self.edges += [transfer]
            self.adj[src_task.name] += [transfer]
            self.in_adj[dest_task.name] += [transfer]
            self.edge_map[name_key] = transfer
            if isinstance(transfer, nn.Module):
                if name_key not in freeze_list:
                    self.params[name_key] = transfer
                else:
                    print("Setting link: " + name_key + " not trainable.")
                try:
                    if not lazy: transfer.load_model()
                except Exception as e:
                    # Interactive debugging hook: print the error and open an
                    # IPython shell; construction resumes once the shell exits.
                    print(e)
                    IPython.embed()
        self.params = nn.ModuleDict(self.params)

    def edge(self, src_task, dest_task):
        """Return the transfer from ``src_task`` to ``dest_task``.

        The lookup uses ``str((name, name))`` first and falls back to
        ``str((kind, kind))``.

        Raises:
            KeyError: if neither key is present in ``self.edge_map``.
        """
        key1 = str((src_task.name, dest_task.name))
        key2 = str((src_task.kind, dest_task.kind))
        if key1 in self.edge_map: return self.edge_map[key1]
        return self.edge_map[key2]

    def sample_path(self, path, reality=None, use_cache=False, cache=None):
        """Apply the transfers along ``path`` and return the final output.

        Args:
            path: list of tasks to traverse. The start node (``reality``, or
                ``self.reality[0]`` when ``reality`` is falsy) is prepended,
                and the first transfer is called with ``None``.
            reality: start node of the path.
            use_cache: if True, every intermediate output is stored in
                ``cache``.
            cache: dict mapping path-prefix tuples (start node included) to
                outputs. A prefix already present is reused instead of
                recomputed. ``None`` means a fresh dict local to this call.

        Returns:
            The output of the last transfer, or None if an edge on the path
            is missing (``KeyError``).
        """
        if cache is None:
            # A shared `{}` default would carry stale outputs between calls.
            cache = {}
        path = [reality or self.reality[0]] + path
        x = None
        for i in range(1, len(path)):
            prefix = tuple(path[0:(i + 1)])
            try:
                # Check the cache before running the transfer. The old
                # `cache.get(prefix, transfer(x))` evaluated the transfer
                # eagerly, so cache hits still paid the full forward pass.
                if prefix in cache:
                    x = cache[prefix]
                else:
                    x = self.edge(path[i - 1], path[i])(x)
            except KeyError:
                return None
            except Exception as e:
                # Interactive debugging hook: print the error and open an
                # IPython shell; traversal continues with the previous `x`.
                print(e)
                IPython.embed()
            if use_cache: cache[prefix] = x
        return x

    def save(self, weights_file=None, weights_dir=None):
        """Persist the graph's non-reality transfers.

        Args:
            weights_file: if given, ``torch.save`` a single dict mapping each
                ``edge_map`` key to that transfer's ``state_dict()``.
            weights_dir: if given, create the directory and save each edge
                whose ``.model`` is a ``TrainableModel`` to
                ``{weights_dir}/{transfer.name}.pth``. Also write
                ``self.optimizer`` to ``optimizer.pth``; this assumes
                ``self.optimizer`` was set elsewhere (e.g. by
                ``TrainableModel``), so confirm before relying on it.
        """
        ### TODO: save optimizers here too
        if weights_file:
            torch.save({
                key: model.state_dict() for key, model in self.edge_map.items()
                if not isinstance(model, RealityTransfer)
            }, weights_file)

        if weights_dir:
            os.makedirs(weights_dir, exist_ok=True)
            for model in self.edge_map.values():
                if isinstance(model, RealityTransfer): continue
                if not isinstance(model.model, TrainableModel): continue
                model.model.save(f"{weights_dir}/{model.name}.pth")
            torch.save(self.optimizer, f"{weights_dir}/optimizer.pth")

    def load_weights(self, weights_file=None):
        """Load per-edge state dicts written by ``save(weights_file=...)``.

        For every key in the file that is also in ``self.edge_map``, the
        transfer's ``load_model()`` is called first and then
        ``load_state_dict``. Keys not in the graph are ignored.

        Raises:
            RuntimeError: if no key in the file matched an edge of this graph.
        """
        loaded_something = False
        for key, state_dict in torch.load(weights_file).items():
            if key in self.edge_map:
                loaded_something = True
                self.edge_map[key].load_model()
                self.edge_map[key].load_state_dict(state_dict)
        if not loaded_something:
            raise RuntimeError(f"No edges loaded from file: {weights_file}")