
robert1003/jax-flax-examples


Contents

  • xor-classifier: a simple classifier in Flax that illustrates the basic operations of Flax, from dataset creation and model building to optimizer and loss construction, the training and evaluation loop, and checkpointing.
  • activation: implementations of common activation functions.
  • initialization: implementations of common initialization functions.
  • optimization: implementations of common optimizers.
  • flow-based-model: implementation of a flow-based model (for MNIST generation).
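The activation, initialization, and optimization entries are described in one line each. As a rough idea of what such from-scratch implementations look like in JAX, here is a minimal sketch (not the repository's code; the function names, the choice of Glorot uniform and heavy-ball SGD with momentum, and the hyperparameters are all assumptions):

```python
import math

import jax
import jax.numpy as jnp


# --- Activations written directly with jax.numpy ---------------------------
def relu(x):
    return jnp.maximum(x, 0.0)


def leaky_relu(x, alpha=0.1):
    # Keeps a small slope `alpha` for negative inputs.
    return jnp.where(x > 0, x, alpha * x)


def swish(x):
    # Swish / SiLU: x * sigmoid(x).
    return x * jax.nn.sigmoid(x)


# --- Initialization ---------------------------------------------------------
def glorot_uniform(key, shape):
    # Glorot/Xavier uniform for a 2-D (fan_in, fan_out) kernel:
    # U(-limit, limit) with limit = sqrt(6 / (fan_in + fan_out)).
    fan_in, fan_out = shape
    limit = math.sqrt(6.0 / (fan_in + fan_out))
    return jax.random.uniform(key, shape, minval=-limit, maxval=limit)


# --- Optimization: SGD with momentum over any parameter pytree --------------
def sgd_momentum_init(params):
    # One zero velocity buffer per parameter leaf.
    return jax.tree_util.tree_map(jnp.zeros_like, params)


def sgd_momentum_update(params, grads, velocity, lr=0.1, beta=0.9):
    # v <- beta * v + g, then p <- p - lr * v (heavy-ball form).
    velocity = jax.tree_util.tree_map(lambda v, g: beta * v + g, velocity, grads)
    params = jax.tree_util.tree_map(lambda p, v: p - lr * v, params, velocity)
    return params, velocity


# Usage: minimize sum((w - 3)^2) starting from w = 0.
def loss(params):
    return jnp.sum((params["w"] - 3.0) ** 2)


params = {"w": jnp.zeros(2)}
velocity = sgd_momentum_init(params)
grad_fn = jax.jit(jax.grad(loss))
for _ in range(300):
    params, velocity = sgd_momentum_update(params, grad_fn(params), velocity)
```

Because the velocity buffer mirrors the parameter pytree via `jax.tree_util.tree_map`, the same update works unchanged for nested Flax parameter dicts.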

JAX / Flax Tips

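  • Preventing JAX from preallocating most of the GPU memory
import os
# Must be set before JAX initializes its GPU backend (safest: before `import jax`).
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'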
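  • PyTorch DataLoader to NumPy DataLoader (ref)
import numpy as np
from torch.utils.data import DataLoader

def NumpyDataLoader(dataset, **kwargs):
    """Wrap a PyTorch DataLoader so batches are NumPy arrays instead of torch tensors."""
    def numpy_collate(batch):
        # `batch` is a list of samples, each returned by dataset[i].
        if isinstance(batch[0], np.ndarray):
            # Stack array samples along a new leading batch axis.
            return np.stack(batch)
        elif isinstance(batch[0], (tuple, list)):
            # e.g. (input, label) pairs: collate each field separately.
            transposed = zip(*batch)
            return [numpy_collate(samples) for samples in transposed]
        else:
            # Anything else (e.g. Python int labels) is converted with np.array.
            return np.array(batch)

    return DataLoader(dataset, collate_fn=numpy_collate, **kwargs)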
  • Flax: saving and loading a model (ref)

Reference
