Skip to content

Commit

Permalink
Debugged UNet upsampling bug and network loading #13 #14
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Aug 10, 2022
1 parent 95865d6 commit 1a8dbf9
Show file tree
Hide file tree
Showing 2 changed files with 2 additions and 3 deletions.
2 changes: 1 addition & 1 deletion raygun/jax/networks/UNet.py
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ def __call__(self, f_left, g_out):
else:
g_cropped = g_up

f_cropped = self.crop(f_left, g_cropped.size()[-self.dims:])
f_cropped = self.crop(f_left, g_cropped.shape[-self.dims:])

return jax.lax.concatenate((f_cropped, g_cropped), dimension=1)

Expand Down
3 changes: 1 addition & 2 deletions raygun/jax/tests/network_test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def forward(self, inputs):
def train_step(self, inputs, pmapped):
raise RuntimeError("Unimplemented")

#%%

class Model(GenericJaxModel):

def __init__(self):
Expand Down Expand Up @@ -144,7 +144,6 @@ def initialize(self, rng_key, inputs, is_training=True):
else:
loss_scale = jmp.NoOpLossScale()
return Params(weight, opt_state, loss_scale)
#%%

def split(arr, n_devices):
"""Splits the first axis of `arr` evenly across the number of devices."""
Expand Down

0 comments on commit 1a8dbf9

Please sign in to comment.