Skip to content

Commit

Permalink
Unittest/Pytest implementation for testing Add testing modules to use…
Browse files Browse the repository at this point in the history
… with pytest #14
  • Loading branch information
brianreicher committed Aug 15, 2022
1 parent 14a7412 commit adc2134
Showing 1 changed file with 35 additions and 21 deletions.
56 changes: 35 additions & 21 deletions raygun/torch/tests/network_test_torch.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
#%%
from functools import partial
from operator import mod
import unittest
import pytest
import numpy as np
import torch
from raygun.torch.networks import *
Expand Down Expand Up @@ -57,7 +57,7 @@ def batch2im(self, batch): # TODO JAX
batch = batch.detach().cpu().squeeze()
return torch.cat((torch.cat((batch[0], batch[1])), torch.cat((batch[2], batch[3]))), axis=1)

def get_data(self): # TODO JAX
def get_data(self):
if self.data is None:
ind = torch.randint(low=0, high=200, size=(1,))[0]
is_face = ind >= 100
Expand Down Expand Up @@ -112,14 +112,15 @@ def step(self, show=False, patches=None, gt=None):
self.show()
return loss.item()

#%%

class TorchTrainTest():

def __init__(self, net=ResNet(2), name='ResNet', **net_kwargs) -> None:
self.model = TorchBuild(net=net) # TODO fix **net_kwargs
self.data_src = TorchBuild()
self.name = name


def train_network(self, steps=1000, show_every=200):
self.losses = {}
name = self.name + '-loss'
Expand All @@ -134,28 +135,41 @@ def train_network(self, steps=1000, show_every=200):

# TODO fix eval_models()
def eval_models(model, name):
outs = {}
test = TorchBuild()
patches, gt, is_face = test.get_data()
outs[name] = model.eval(show=False, patches=patches, gt=gt)
# num = len(models.keys()) + 2
fig, axs = plt.subplots(1, figsize=(5, 5))
axs[0].imshow(test.batch2im(patches), cmap='gray', vmin=-1, vmax=1)
axs[0].set_title('Input')
gt = test.batch2im(gt)
axs[-1].imshow(gt, cmap='gray', vmin=-1, vmax=1)
axs[-1].set_title('Real')
for ax, name in zip(axs[1:-1]):
ax.imshow(outs[name], cmap='gray', vmin=-1, vmax=1)
mse = torch.mean((gt - outs[name])**2)
ax.set_title(f'{name}: MSE={mse}')
# outs = {}
# test = TorchBuild()
# patches, gt, is_face = test.get_data()
# outs[name] = model.eval(show=False, patches=patches, gt=gt)
# # num = len(models.keys()) + 2
# fig, axs = plt.subplots(1, figsize=(5, 5))
# axs[0].imshow(test.batch2im(patches), cmap='gray', vmin=-1, vmax=1)
# axs[0].set_title('Input')
# gt = test.batch2im(gt)
# axs[-1].imshow(gt, cmap='gray', vmin=-1, vmax=1)
# axs[-1].set_title('Real')
# for ax, name in zip(axs[1:-1]):
# ax.imshow(outs[name], cmap='gray', vmin=-1, vmax=1)
# mse = torch.mean((gt - outs[name])**2)
# ax.set_title(f'{name}: MSE={mse}')
pass


def eval_plot(self):
plt.figure(figsize=(15,10))
for name, loss in self.losses.items():
plt.plot(loss, label=name)
plt.title('Losses')
plt.title('Loss vs. Step')
plt.ylim([0,.1])
plt.legend()
# %%


class TorchUnitTest(unittest.TestCase):

def test_gpu(self):
x = TorchBuild()
net = x.net
assert next(x.net.parameters()).is_cuda == True # Fix params staying on CPU

def test_weight_size(self):
pass


0 comments on commit adc2134

Please sign in to comment.