Skip to content

Commit

Permalink
Update jax test to point of rank issue #14 #16
Browse files Browse the repository at this point in the history
  • Loading branch information
brianreicher committed Aug 15, 2022
1 parent 1a8dbf9 commit d37b79f
Show file tree
Hide file tree
Showing 3 changed files with 3 additions and 232 deletions.
18 changes: 3 additions & 15 deletions raygun/jax/tests/network_test_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,21 +7,8 @@
import optax
import jmp
import time

# from funlib.learn.jax.models import UNet, ConvPass

from typing import Tuple, Any, NamedTuple, Dict

'''To test model with some dummy input and output, run with command
`CUDA_VISIBLE_DEVICES=0 python unet_example.py`
for single device training, or
`CUDA_VISIBLE_DEVICES=0,1,2,3 python unet_example.py`
for multi-device training
'''

# PARAMETERS
mp_training = True # mixed-precision training using `jmp`
Expand Down Expand Up @@ -153,7 +140,7 @@ def create_network():
# returns a model that Gunpowder `Predict` and `Train` node can use
return Model()

#%%

my_model = Model()

n_devices = jax.local_device_count()
Expand Down Expand Up @@ -188,8 +175,9 @@ def create_network():
#%%
# test forward
y = jit(my_model.forward)(model_params, {'raw': raw})
assert y['affs'].shape == (batch_size, 3, 40, 40, 40)

assert y['affs'].shape == (batch_size, 3, 40, 40, 40)
#%%
# test train loop
for _ in range(10):
t0 = time.time()
Expand Down
217 changes: 0 additions & 217 deletions raygun/jax/tests/network_test_tri.py

This file was deleted.

File renamed without changes.

0 comments on commit d37b79f

Please sign in to comment.