Skip to content

Commit

Permalink
Remove shape checks
Browse files Browse the repository at this point in the history
  • Loading branch information
mathpluscode committed Dec 31, 2023
1 parent 86e2c88 commit 277eee4
Show file tree
Hide file tree
Showing 2 changed files with 0 additions and 8 deletions.
4 changes: 0 additions & 4 deletions imgx/task/diffusion_segmentation/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -387,10 +387,6 @@ def train_init(
aug_rng = jax.random.PRNGKey(self.config["seed"])
batch = aug_fn(aug_rng, batch)

# check image size
image_shape = self.dataset_info.image_spatial_shape
chex.assert_equal(batch[IMAGE].shape[1:-1], image_shape)

# init train state on cpu first
dtype = get_half_precision_dtype(self.config.half_precision)
model = instantiate(self.config.task.model, dtype=dtype)
Expand Down
4 changes: 0 additions & 4 deletions imgx/task/segmentation/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -261,10 +261,6 @@ def train_init(
aug_rng = jax.random.PRNGKey(self.config["seed"])
batch = aug_fn(aug_rng, batch)

# check image size
image_shape = self.dataset_info.image_spatial_shape
chex.assert_equal(batch[IMAGE].shape[1:-1], image_shape)

# init train state on cpu first
dtype = get_half_precision_dtype(self.config.half_precision)
model = instantiate(self.config.task.model, dtype=dtype)
Expand Down

0 comments on commit 277eee4

Please sign in to comment.