Skip to content

Commit

Permalink
Created working torch tester - todo refactor #14
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Aug 15, 2022
1 parent 8487450 commit 92eda0f
Showing 1 changed file with 27 additions and 36 deletions.
63 changes: 27 additions & 36 deletions raygun/torch/tests/network_test_torch.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,44 +3,37 @@
import numpy as np
import torch
from raygun.torch.networks import *
from raygun.torch.networks.ResNet import *
from skimage import data
import matplotlib.pyplot as plt
from tqdm import trange
torch.cuda.set_device(1)

# %%

class TorchTest():
def __init__(self,
net=None,
activation=None,
norm=None,
size=24,
seed=42,
noise_factor=3,
img='astronaut',
ind=31,
name='',
**network_kwargs):
name=''):

torch.manual_seed(seed)
if net is None:
if norm is None:
norm = partial(torch.nn.InstanceNorm2d, track_running_stats=True, momentum=0.01)
if activation is None:
activation = torch.nn.ReLU
self.net = torch.nn.Sequential(
ResNet(1, 1, 32, norm, n_blocks=4, activation=activation),
torch.nn.Tanh()
).to('cuda')
if net is None: # FIX PARAM MODE
self.net = ResNet(2)

else:
self.net = net(network_kwargs)
self.net = net.to('cuda')

self.size = size
self.mode = 'train'
self.ind = ind
self.name = name
self.noise_factor = noise_factor
# self.loss_fn = torch.nn.MSELoss() JAX
# self.optim = torch.optim.Adam(self.net.parameters(), lr=1e-5) JAX
self.loss_fn = torch.nn.MSELoss()
self.optim = torch.optim.Adam(self.net.parameters(), lr=1e-5)
if img is not None:
self.data = getattr(data, img)()
if len(self.data.shape) > 2:
Expand Down Expand Up @@ -117,6 +110,21 @@ def step(self, show=False, patches=None, gt=None):
self.show()
return loss.item()


def training_loop(model=ResNet(2), name='ResNet', steps=1000, show_every=200, **network_kwargs):
losses = {}
losses[name] = np.zeros((steps,))
ticker = trange(steps)
model = TorchTest(net=model)
data_src = TorchTest()
for step in ticker:
ticker_postfix = {}
patches, gt, is_face = data_src.get_data()
losses[name][step] = model.step((step % show_every)==0, patches=patches, gt=gt)
ticker_postfix[name] = losses[name][step]
ticker.set_postfix(ticker_postfix)


def eval_models(data_src, models):
outs = {}
patches, gt, is_face = data_src.get_data()
Expand All @@ -133,26 +141,9 @@ def eval_models(data_src, models):
ax.imshow(outs[name], cmap='gray', vmin=-1, vmax=1)
mse = torch.mean((gt - outs[name])**2)
ax.set_title(f'{name}: MSE={mse}')


#%%
model = TorchTest()
patches, gt, out, is_face = model.forward()
model.show()
model.step(True)

#%%
def training_loop(model=UNet, steps=100, show_every=200):
losses = np.zeros((steps,))

ticker = trange(steps)
model = TorchTest(net=model)
data_src = TorchTest()
for step in ticker:
ticker_postfix = {}
patches, gt, is_face = data_src.get_data()
losses[step] = model.step((step % show_every)==0, patches=patches, gt=gt)
ticker_postfix = losses[step]
ticker.set_postfix(ticker_postfix)
def eval_plot(data_src, model, losses):
eval_models((data_src, model))
plt.figure(figsize=(15,10))
for name, loss in losses.items():
Expand Down

0 comments on commit 92eda0f

Please sign in to comment.