[Feature Request] Support different randomness settings to train an ensemble of models with TorchOpt #996
Comments
Thanks for the thorough writeup @Benjamin-eecs! To check a couple of things: is this a current bottleneck for your examples? We had been under the assumption that training would be much more expensive than initialization, but that may not be true (or it may be fair that we're losing out on performance by not doing this). We're also currently looking at different ways that JAX libraries build neural nets, and this is a great axis I hadn't thought of before. It looks like you might be using Flax or Haiku, and I was wondering if you had tried this with Equinox at all?

cc @zou3519 This seems to be the same thing that the federated learning people were asking for. I forget if we got a clear answer for them.
Hi there @samdow, thanks for your quick and detailed feedback.

I think I can call it a bottleneck in some way; we can definitely initialize the ensemble of models and optimizers using a for loop, but our TorchOpt example mainly wants to show that we can support

I am not sure I fully understood; the JAX code snippet I showed in the writeup is just to present that our TorchOpt example turns the functorch example into JAX style, with an extra optimizer such as Adam rather than only SGD.
To be clear, if this is the end goal, it will probably always be easier to write this as a for loop. Most of the hyperparameters are scalar values, and right now we can't vmap over lists or tensors of scalar values (1D tensors that we vmap over are going to be treated as scalar tensors instead of scalars). As an example, if we had an ensemble of models like the ones in the TorchOpt PR but where the hidden dimension was being changed:

```python
class MLP(nn.Module):
    def __init__(self, hidden_dim=32, n_classes=2):
        ...
        self.fc1 = nn.Linear(2, self.hidden_dim)
        self.fc2 = nn.Linear(self.hidden_dim, self.n_classes)
        ...
```

we would never be able to vmap over different values for `hidden_dim`.
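For concreteness, here is a small sketch (not from the thread) of the for-loop alternative being suggested, reusing the `MLP` above: hyperparameters that determine layer shapes have to stay plain Python values, so varying them across the ensemble means constructing each member separately.

```python
# hidden_dim must stay a plain Python int because it sets a layer shape,
# so an ensemble that varies it has to be built member by member.
hidden_dims = [16, 32, 64]
models = [MLP(hidden_dim=h, n_classes=2) for h in hidden_dims]
```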
I see! Thanks for that clarification. I saw
Motivation
We recently used TorchOpt as the functional optimizer API mentioned in the functorch parallel training example, to achieve batchable optimization when training small neural networks on one GPU with `functorch.vmap`. With TorchOpt, we can mimic the JAX implementation and use vmap on the init function:
JAX:
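The original snippet did not survive here; the following is a minimal sketch of the pattern being referred to, assuming a Flax-style `model.init(key, x)` MLP and an optax Adam optimizer (the layer sizes, `num_models`, and library choice are illustrative assumptions, not the original code).

```python
import jax
import jax.numpy as jnp
import flax.linen as nn
import optax

class MLP(nn.Module):
    @nn.compact
    def __call__(self, x):
        x = nn.Dense(32)(x)
        x = nn.relu(x)
        return nn.Dense(2)(x)

num_models = 5
model = MLP()
dummy_x = jnp.ones((1, 2))

# One PRNG key per ensemble member: vmap over init gives different weights.
keys = jax.random.split(jax.random.PRNGKey(0), num_models)
batched_params = jax.vmap(model.init, in_axes=(0, None))(keys, dummy_x)

# The optimizer state can be initialized the same way.
optimizer = optax.adam(1e-3)
batched_opt_states = jax.vmap(optimizer.init)(batched_params)
```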
TorchOpt + functorch:
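Again, the original block is not preserved; the sketch below shows the intended functorch-side mirror of the JAX pattern, assuming TorchOpt's optax-style functional API (`torchopt.adam(...).init`) and a dummy batched tensor used only to fix the ensemble size. As written it only runs with `randomness='same'`, which gives identical members; that limitation is exactly what this issue is about.

```python
import torch
import torch.nn as nn
import functorch
import torchopt

class MLP(nn.Module):
    def __init__(self, hidden_dim=32, n_classes=2):
        super().__init__()
        self.fc1 = nn.Linear(2, hidden_dim)
        self.fc2 = nn.Linear(hidden_dim, n_classes)

    def forward(self, x):
        return self.fc2(torch.relu(self.fc1(x)))

num_models = 5
optimizer = torchopt.adam(lr=1e-3)

def init_fn(dummy):
    # Build one ensemble member; module construction draws random weights.
    _, params = functorch.make_functional(MLP())
    opt_state = optimizer.init(params)
    return params, opt_state

# The dummy tensor only fixes the ensemble size. randomness='same' runs but
# gives every member identical weights; 'different' is what we need.
params, opt_states = functorch.vmap(init_fn, randomness='same')(
    torch.zeros(num_models))
```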
instead of `combine_state_for_ensemble`.
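For contrast, a minimal sketch (not from the original issue) of the `combine_state_for_ensemble` route used by the functorch example, which builds the members eagerly in a loop and only stacks their states afterwards:

```python
import torch.nn as nn
import functorch

num_models = 5

# Build the ensemble members eagerly, then stack their parameters and buffers.
models = [nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 2))
          for _ in range(num_models)]
fmodel, params, buffers = functorch.combine_state_for_ensemble(models)
# params is a tuple of tensors with a leading ensemble dimension of size num_models.
```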
However, any other `randomness` setting in `functorch.vmap(init_fn)` threw a bug (i.e. `randomness='different'`). `functorch.vmap(init_fn, randomness='same')` gives identical inits for each net in the ensemble, which is not desirable if we want to train ensembles averaged across random seeds. Therefore, having `functorch.vmap(init_fn)` support different randomness settings is a needed feature for this kind of usage.

cc @waterhorse1 @JieRen98 @XuehaiPan
Solution
metaopt/torchopt#32 can be run with `functorch.vmap(init_fn, randomness='different')`.
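A short sketch of what the requested setting would allow, reusing `init_fn` and `num_models` from the Motivation sketch above (the stacked parameter layout is an assumption):

```python
# Per-member randomness during vmapped initialization (the requested feature).
params, opt_states = functorch.vmap(init_fn, randomness='different')(
    torch.zeros(num_models))

# With 'different' randomness, the stacked fc1 weights should no longer be
# identical across ensemble members.
fc1_weight = params[0]  # assumed shape: (num_models, hidden_dim, 2)
assert not torch.allclose(fc1_weight[0], fc1_weight[1])
```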
Resource
Checklist
I have checked `combine_state_for_ensemble` for initialization of an ensemble of models, and the related issue Train ensemble models with vmap #782 asking for an implementation of this usage, but my request is more about giving a specific usage that requires this feature.