
dtype assertion bug in train (possible JAX-version issue) #2

Open

xmax1 opened this issue Nov 1, 2022 · 1 comment

xmax1 commented Nov 1, 2022

Hi, thanks for the great work!

There is an assertion error when the dtype of the dataset is checked, which is confusing because, as far as I can tell, it should fail for everyone.

Possibly a version issue (maybe some versions of JAX recognise tf dtypes as jnp dtypes?).

```
AssertionError                            Traceback (most recent call last)
/home/amawi/projects/denoising-diffusion-flax/denoising_diffusion_flax/ddpm_flax_oxford102_end_to_end.ipynb Cell 5 in <cell line: 2>()
      [1] work_dir = './fashion_mnist'
----> [2] state = train.train(my_config, work_dir)

File ~/projects/denoising-diffusion-flax/denoising_diffusion_flax/train.py:436, in train(config, workdir, wandb_artifact)
    434 rng, *train_step_rng = jax.random.split(rng, num=jax.local_device_count() + 1)
    435 train_step_rng = jnp.asarray(train_step_rng)
--> 436 state, metrics = p_train_step(train_step_rng, state, batch)
    437 for h in hooks:
    438     h(step)

    [... skipping hidden 17 frame]

File ~/projects/denoising-diffusion-flax/denoising_diffusion_flax/train.py:252, in p_loss(rng, state, batch, ddpm_params, loss_fn, self_condition, is_pred_x0, pmap_axis)
    248 def p_loss(rng, state, batch, ddpm_params, loss_fn, self_condition=False, is_pred_x0=False, pmap_axis='batch'):
    249
    250     # run the forward diffusion process to generate noisy image x_t at timestep t
    251     x = batch['image']
--> 252     assert x.dtype in [jnp.float32, jnp.float64]
    254     # create batched timesteps: t with shape (B,)
    255     B, H, W, C = x.shape

AssertionError:
```
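For what it's worth, a minimal standalone sketch of the failing check (toy shapes of my own). Note that only float32/float64 pass it, so a half-precision batch would also trip this assert regardless of any tf-vs-jnp question:

```python
import jax.numpy as jnp

# which dtypes satisfy the assert at train.py:252?
for dtype in (jnp.float16, jnp.bfloat16, jnp.float32):
    x = jnp.zeros((2, 32, 32, 3), dtype=dtype)
    print(x.dtype, x.dtype in [jnp.float32, jnp.float64])
# float16 False
# bfloat16 False
# float32 True
```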

get_dataset is shown below, with the fix lines commented out:

```python
def get_dataset(rng, config):

    if config.data.batch_size % jax.device_count() > 0:
        raise ValueError('Batch size must be divisible by the number of devices')

    batch_size = config.data.batch_size // jax.process_count()

    platform = jax.local_devices()[0].platform
    if config.training.half_precision:
        if platform == 'tpu':
            # input_dtype = tf.bfloat16
            input_dtype = jnp.bfloat16
        else:
            # input_dtype = tf.float16
            input_dtype = jnp.float16
    else:
        input_dtype = tf.float32
        # input_dtype = jnp.float32
```
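As a sanity check (a debugging sketch of my own, assuming `get_dataset` returns the batch iterator and `my_config` is the notebook config), this is how I'd inspect what actually reaches `p_loss`:

```python
import jax

# grab one batch before training starts; if this prints a tf.Tensor /
# tf dtype rather than a jax or numpy array, the batches are not being
# converted before p_loss sees them
ds = get_dataset(jax.random.PRNGKey(0), my_config)
batch = next(iter(ds))
img = batch['image']
print(type(img), img.dtype, img.shape)
```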

For anyone reading: I'm on jax 0.3.21 with CUDA (not TPU).

yiyixuxu (Owner) commented Nov 1, 2022

Hi @xmax1

JAX doesn't recognise tf types as jnp. This line in get_dataset prefetches your batch to the devices and shards it for you, so you shouldn't need to change the input_dtype as long as prefetch_to_device runs correctly (I just ran the notebook with a GPU on Colab and it seems fine):

```python
it = jax_utils.prefetch_to_device(it, 2)
```
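For reference, here is a minimal standalone sketch (dummy iterator and toy shapes of my own) of what prefetch_to_device does with host-side batches:

```python
import jax
import numpy as np
from flax import jax_utils

def host_batches():
    # dummy iterator yielding numpy batches whose leading axis equals the
    # local device count, which is the layout prefetch_to_device shards over
    n = jax.local_device_count()
    while True:
        yield {'image': np.zeros((n, 8, 32, 32, 3), dtype=np.float32)}

it = jax_utils.prefetch_to_device(host_batches(), size=2)
batch = next(it)
# the leaves come back as jax arrays already placed on the devices
print(type(batch['image']), batch['image'].dtype)
```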
