From 14f065266faf1a44f3cc16daf65b07c946c50938 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Wed, 20 Jul 2022 02:06:47 +0800 Subject: [PATCH 1/9] feat(examples): add functorch vmap support examples --- examples/functorch/parallel_train.py | 166 +++++++++++++++++ examples/functorch/parallel_train_torchopt.py | 175 ++++++++++++++++++ 2 files changed, 341 insertions(+) create mode 100644 examples/functorch/parallel_train.py create mode 100644 examples/functorch/parallel_train_torchopt.py diff --git a/examples/functorch/parallel_train.py b/examples/functorch/parallel_train.py new file mode 100644 index 00000000..42fea035 --- /dev/null +++ b/examples/functorch/parallel_train.py @@ -0,0 +1,166 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# ============================================================================== + +# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a +# tutorial on Model Ensembling with JAX by Will Whitney. +# +# The original code comes with the following citation: +# @misc{Whitney2021Parallelizing, +# author = {William F. Whitney}, +# title = { {Parallelizing neural networks on one GPU with JAX} }, +# year = {2021}, +# url = {http://willwhitney.com/parallel-training-jax.html}, +# } + +# GOAL: Demonstrate that it is possible to use eager-mode vmap +# to parallelize training over models. + +import argparse +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functorch import combine_state_for_ensemble, grad_and_value, make_functional, vmap + + +parser = argparse.ArgumentParser(description="Functorch Ensembled Models") +parser.add_argument( + "--device", + type=str, + default="cpu", + help="CPU or GPU ID for this process (default: 'cpu')", +) +args = parser.parse_args() + +DEVICE = args.device + +# Step 1: Make some spirals + + +def make_spirals(n_samples, noise_std=0., rotations=1.): + ts = torch.linspace(0, 1, n_samples, device=DEVICE) + rs = ts**0.5 + thetas = rs * rotations * 2 * math.pi + signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1 + labels = (signs > 0).to(torch.long).to(DEVICE) + + xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std + ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std + points = torch.stack([xs, ys], dim=1) + return points, labels + + +points, labels = make_spirals(100, noise_std=0.05) + + +# Step 2: Define two-layer MLP and loss function +class MLPClassifier(nn.Module): + + def __init__(self, hidden_dim=32, n_classes=2): + super().__init__() + self.hidden_dim = hidden_dim + self.n_classes = n_classes + + self.fc1 = nn.Linear(2, self.hidden_dim) + self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.log_softmax(x, -1) + return x + + +loss_fn = nn.NLLLoss() + +# Step 3: Make the model functional(!!) and define a training function. 
+# NB: this mechanism doesn't exist in PyTorch today, but we want it to: +# https://github.com/pytorch/pytorch/issues/49171 +func_model, weights = make_functional(MLPClassifier().to(DEVICE)) + + +def train_step_fn(weights, batch, targets, lr=0.2): + + def compute_loss(weights, batch, targets): + output = func_model(weights, batch) + loss = loss_fn(output, targets) + return loss + + grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets) + + # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon) + # so we are going to re-implement SGD here. + new_weights = [] + with torch.no_grad(): + for grad_weight, weight in zip(grad_weights, weights): + new_weights.append(weight - grad_weight * lr) + + return loss, new_weights + + +# Step 4: Let's verify this actually trains. +# We should see the loss decrease. +def step4(): + global weights + for i in range(2000): + loss, weights = train_step_fn(weights, points, labels) + if i % 100 == 0: + print(loss) + + +step4() + +# Step 5: We're ready for multiple models. Let's define an init_fn +# that, given a number of models, returns to us all of the weights. + + +def init_fn(num_models): + models = [MLPClassifier().to(DEVICE) for _ in range(num_models)] + _, params, _ = combine_state_for_ensemble(models) + return params + + +# Step 6: Now, can we try multiple models at the same time? +# The answer is: yes! `loss` is a 2-tuple, and we can see that the value keeps +# on decreasing + + +def step6(): + parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None)) + batched_weights = init_fn(num_models=2) + for i in range(2000): + loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels) + if i % 200 == 0: + print(loss) + + +step6() + +# Step 7: Now, the flaw with step 6 is that we were training on the same exact +# data. This can lead to all of the models in the ensemble overfitting in the +# same way. The solution that http://willwhitney.com/parallel-training-jax.html +# applies is to randomly subset the data in a way that the models do not recieve +# exactly the same data in each training step! +# Because the goal of this doc is to show that we can use eager-mode vmap to +# achieve similar things as JAX, the rest of this is left as an exercise to the reader. + +# In conclusion, to achieve what http://willwhitney.com/parallel-training-jax.html +# does, we used the following additional items that PyTorch does not have: +# 1. NN module functional API that turns a module into a (state, state_less_fn) pair +# 2. Functional optimizers +# 3. A "functional" grad API (that effectively wraps autograd.grad) +# 4. Composability between the functional grad API and torch.vmap. diff --git a/examples/functorch/parallel_train_torchopt.py b/examples/functorch/parallel_train_torchopt.py new file mode 100644 index 00000000..eef3e5d4 --- /dev/null +++ b/examples/functorch/parallel_train_torchopt.py @@ -0,0 +1,175 @@ +# Copyright 2022 MetaOPT Team. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +import argparse +import math +from collections import namedtuple +from typing import Any, NamedTuple + +import torch +import torch.nn as nn +import torch.nn.functional as F +from functorch import combine_state_for_ensemble, grad_and_value, make_functional, vmap + +import torchopt + + +def make_spirals(n_samples, noise_std=0., rotations=1.): + ts = torch.linspace(0, 1, n_samples, device=DEVICE) + rs = ts**0.5 + thetas = rs * rotations * 2 * math.pi + signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1 + labels = (signs > 0).to(torch.long).to(DEVICE) + + xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std + ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std + points = torch.stack([xs, ys], dim=1) + return points, labels + + +class MLPClassifier(nn.Module): + + def __init__(self, hidden_dim=32, n_classes=2): + super().__init__() + self.hidden_dim = hidden_dim + self.n_classes = n_classes + + self.fc1 = nn.Linear(2, self.hidden_dim) + self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) + + def forward(self, x): + x = self.fc1(x) + x = F.relu(x) + x = self.fc2(x) + x = F.log_softmax(x, -1) + return x + + +class Net(nn.Module): + + def __init__(self, dim): + super().__init__() + self.fc = nn.Linear(dim, 1, bias=True) + nn.init.ones_(self.fc.weight) + nn.init.zeros_(self.fc.bias) + + def forward(self, x): + return self.fc(x) + + +def train_step_fn(training_state, batch, targets): + weights, opt_state = training_state + + def compute_loss(weights, batch, targets): + output = func_model(weights, batch) + loss = loss_fn(output, targets) + return loss + + grads, loss = grad_and_value(compute_loss)(weights, batch, targets) + + # functional optimizer API is here now + # new_opt_state0 = opt_state[0]._asdict() + # for k, v in new_opt_state0.items(): + # if type(v) is tuple: + # new_opt_state0[k] = tuple(v_el.clone() for v_el in v) + # new_opt_state = (opt_state[0]._make(new_opt_state0.values()), opt_state[1]) + + updates, new_opt_state = optimizer.update(grads, opt_state) + new_weights = torchopt.apply_updates(weights, updates) + # Default `inplace=True` gave me an error + # weights = torchopt.apply_updates(weights, updates, inplace=False) + return loss, (new_weights, new_opt_state) + + +def step4(weights, opt_state): + for i in range(2000): + loss, (weights, opt_state) = train_step_fn((weights, opt_state), points, labels) + if i % 100 == 0: + print(loss) + + +def init_fn(model_idx): + print(model_idx) + # models = [MLPClassifier().to(DEVICE) for _ in range(model_idx)] + # print(len(models)) + # print(models) + # _, weights, _ = combine_state_for_ensemble(models) + #print(weights) + _, weights = make_functional(Net(4).to(DEVICE)) + opt_state = optimizer.init(weights) + print(weights) + #print(opt_state) + print(opt_state) + return weights, opt_state + + +def step6(num_models): + parallel_init_fn = vmap(init_fn, randomness='same') + parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None)) + weights, opt_state = parallel_init_fn(torch.ones(num_models, 1)) + for i in range(2000): + loss, (weights, opt_states) = parallel_train_step_fn((weights, opt_state), points, labels) + if i % 200 == 0: + print(loss) + + +if __name__ == '__main__': + # Adapted from http://willwhitney.com/parallel-training-jax.html , which is a + # tutorial on Model Ensembling with JAX by Will Whitney. 
+ # + # The original code comes with the following citation: + # @misc{Whitney2021Parallelizing, + # author = {William F. Whitney}, + # title = { {Parallelizing neural networks on one GPU with JAX} }, + # year = {2021}, + # url = {http://willwhitney.com/parallel-training-jax.html}, + # } + + # GOAL: Demonstrate that it is possible to use eager-mode vmap + # to parallelize training over models. + parser = argparse.ArgumentParser(description="Functorch Ensembled Models with TorchOpt") + parser.add_argument( + "--device", + type=str, + default="cpu", + help="CPU or GPU ID for this process (default: 'cpu')", + ) + args = parser.parse_args() + + DEVICE = args.device + # Step 1: Make some spirals + points, labels = make_spirals(100, noise_std=0.05) + # Step 2: Define two-layer MLP and loss function + loss_fn = nn.NLLLoss() + # Step 3: Make the model functional(!!) and define a training function. + func_model, weights = make_functional(MLPClassifier().to(DEVICE)) + optimizer = torchopt.adam(lr=0.2) + opt_state = optimizer.init(weights) + # Step 4: Let's verify this actually trains. + # We should see the loss decrease. + step4(weights, opt_state) + # Step 5: We're ready for multiple models. Let's define an init_fn + # that, given a number of models, returns to us all of the weights. + # Step 6: Now, can we try multiple models at the same time? + # The answer is: yes! `loss` is a 2-tuple, and we can see that the value keeps + # on decreasing + step6(5) + # Step 7: Now, the flaw with step 6 is that we were training on the same exact + # data. This can lead to all of the models in the ensemble overfitting in the + # same way. The solution that http://willwhitney.com/parallel-training-jax.html + # applies is to randomly subset the data in a way that the models do not recieve + # exactly the same data in each training step! + # Because the goal of this doc is to show that we can use eager-mode vmap to + # achieve similar things as JAX, the rest of this is left as an exercise to the reader. 
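For reference, the core pattern both example scripts above build on — a purely functional train step (functorch's grad_and_value over a make_functional model, updated with TorchOpt's functional optimizer) batched across ensemble members by vmap — can be condensed as follows. This is a minimal standalone sketch, not part of the patch: the model shape, hyperparameters, and synthetic data are illustrative, and it assumes the torchopt.adam / torchopt.apply_updates API used above with inplace=False, plus the tensor step counters introduced later in this series (vmap can only batch optimizer state made of tensors).

import functorch
import torch
import torch.nn as nn
import torchopt


def make_model():
    # Illustrative stand-in for the MLPClassifier used in the examples.
    return nn.Sequential(nn.Linear(2, 32), nn.ReLU(), nn.Linear(32, 2), nn.LogSoftmax(dim=-1))


func_model, _ = functorch.make_functional(make_model())
loss_fn = nn.NLLLoss()
optimizer = torchopt.adam(lr=0.2)


def init_fn(dummy):
    # One independent weight set + optimizer state per ensemble member; the dummy
    # argument only gives vmap something to map over.
    _, weights = functorch.make_functional(make_model())
    return weights, optimizer.init(weights)


def train_step(state, batch, targets):
    weights, opt_state = state

    def compute_loss(weights):
        return loss_fn(func_model(weights, batch), targets)

    grads, loss = functorch.grad_and_value(compute_loss)(weights)
    # inplace=False keeps the step purely functional so it composes with vmap.
    updates, opt_state = optimizer.update(grads, opt_state, inplace=False)
    weights = torchopt.apply_updates(weights, updates, inplace=False)
    return loss, (weights, opt_state)


num_models = 2
parallel_init = functorch.vmap(init_fn, randomness='same')
parallel_step = functorch.vmap(train_step, in_dims=(0, None, None))
state = parallel_init(torch.ones(num_models, 1))

points, labels = torch.randn(100, 2), torch.randint(0, 2, (100,))
for i in range(200):
    loss, state = parallel_step(state, points, labels)
    if i % 50 == 0:
        print(loss)  # one loss value per ensemble member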
From 4ebcf38608d5385f247c4ca3d23ca796939065b7 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Wed, 20 Jul 2022 02:18:51 +0800 Subject: [PATCH 2/9] feat(examples): add functorch vmap support examples --- examples/functorch/parallel_train.py | 10 ++++------ examples/functorch/parallel_train_torchopt.py | 14 ++++++-------- 2 files changed, 10 insertions(+), 14 deletions(-) diff --git a/examples/functorch/parallel_train.py b/examples/functorch/parallel_train.py index 42fea035..46037df3 100644 --- a/examples/functorch/parallel_train.py +++ b/examples/functorch/parallel_train.py @@ -36,11 +36,11 @@ from functorch import combine_state_for_ensemble, grad_and_value, make_functional, vmap -parser = argparse.ArgumentParser(description="Functorch Ensembled Models") +parser = argparse.ArgumentParser(description='Functorch Ensembled Models') parser.add_argument( - "--device", + '--device', type=str, - default="cpu", + default='cpu', help="CPU or GPU ID for this process (default: 'cpu')", ) args = parser.parse_args() @@ -50,7 +50,7 @@ # Step 1: Make some spirals -def make_spirals(n_samples, noise_std=0., rotations=1.): +def make_spirals(n_samples, noise_std=0.0, rotations=1.0): ts = torch.linspace(0, 1, n_samples, device=DEVICE) rs = ts**0.5 thetas = rs * rotations * 2 * math.pi @@ -68,7 +68,6 @@ def make_spirals(n_samples, noise_std=0., rotations=1.): # Step 2: Define two-layer MLP and loss function class MLPClassifier(nn.Module): - def __init__(self, hidden_dim=32, n_classes=2): super().__init__() self.hidden_dim = hidden_dim @@ -94,7 +93,6 @@ def forward(self, x): def train_step_fn(weights, batch, targets, lr=0.2): - def compute_loss(weights, batch, targets): output = func_model(weights, batch) loss = loss_fn(output, targets) diff --git a/examples/functorch/parallel_train_torchopt.py b/examples/functorch/parallel_train_torchopt.py index eef3e5d4..90913e8d 100644 --- a/examples/functorch/parallel_train_torchopt.py +++ b/examples/functorch/parallel_train_torchopt.py @@ -26,7 +26,7 @@ import torchopt -def make_spirals(n_samples, noise_std=0., rotations=1.): +def make_spirals(n_samples, noise_std=0.0, rotations=1.0): ts = torch.linspace(0, 1, n_samples, device=DEVICE) rs = ts**0.5 thetas = rs * rotations * 2 * math.pi @@ -40,7 +40,6 @@ def make_spirals(n_samples, noise_std=0., rotations=1.): class MLPClassifier(nn.Module): - def __init__(self, hidden_dim=32, n_classes=2): super().__init__() self.hidden_dim = hidden_dim @@ -58,7 +57,6 @@ def forward(self, x): class Net(nn.Module): - def __init__(self, dim): super().__init__() self.fc = nn.Linear(dim, 1, bias=True) @@ -106,11 +104,11 @@ def init_fn(model_idx): # print(len(models)) # print(models) # _, weights, _ = combine_state_for_ensemble(models) - #print(weights) + # print(weights) _, weights = make_functional(Net(4).to(DEVICE)) opt_state = optimizer.init(weights) print(weights) - #print(opt_state) + # print(opt_state) print(opt_state) return weights, opt_state @@ -139,11 +137,11 @@ def step6(num_models): # GOAL: Demonstrate that it is possible to use eager-mode vmap # to parallelize training over models. 
- parser = argparse.ArgumentParser(description="Functorch Ensembled Models with TorchOpt") + parser = argparse.ArgumentParser(description='Functorch Ensembled Models with TorchOpt') parser.add_argument( - "--device", + '--device', type=str, - default="cpu", + default='cpu', help="CPU or GPU ID for this process (default: 'cpu')", ) args = parser.parse_args() From e959bc0fb7e23015dfc3e1f7708c8734979dbb8a Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 21 Jul 2022 21:05:53 +0800 Subject: [PATCH 3/9] feat: working functorch + torchopt parallel training example --- examples/functorch/parallel_train.py | 164 ----------------- examples/functorch/parallel_train_torchopt.py | 171 +++++++++++------- torchopt/_src/transform.py | 37 +++- 3 files changed, 129 insertions(+), 243 deletions(-) delete mode 100644 examples/functorch/parallel_train.py diff --git a/examples/functorch/parallel_train.py b/examples/functorch/parallel_train.py deleted file mode 100644 index 46037df3..00000000 --- a/examples/functorch/parallel_train.py +++ /dev/null @@ -1,164 +0,0 @@ -# Copyright 2022 MetaOPT Team. All Rights Reserved. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# ============================================================================== - -# Adapted from http://willwhitney.com/parallel-training-jax.html , which is a -# tutorial on Model Ensembling with JAX by Will Whitney. -# -# The original code comes with the following citation: -# @misc{Whitney2021Parallelizing, -# author = {William F. Whitney}, -# title = { {Parallelizing neural networks on one GPU with JAX} }, -# year = {2021}, -# url = {http://willwhitney.com/parallel-training-jax.html}, -# } - -# GOAL: Demonstrate that it is possible to use eager-mode vmap -# to parallelize training over models. 
- -import argparse -import math - -import torch -import torch.nn as nn -import torch.nn.functional as F -from functorch import combine_state_for_ensemble, grad_and_value, make_functional, vmap - - -parser = argparse.ArgumentParser(description='Functorch Ensembled Models') -parser.add_argument( - '--device', - type=str, - default='cpu', - help="CPU or GPU ID for this process (default: 'cpu')", -) -args = parser.parse_args() - -DEVICE = args.device - -# Step 1: Make some spirals - - -def make_spirals(n_samples, noise_std=0.0, rotations=1.0): - ts = torch.linspace(0, 1, n_samples, device=DEVICE) - rs = ts**0.5 - thetas = rs * rotations * 2 * math.pi - signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1 - labels = (signs > 0).to(torch.long).to(DEVICE) - - xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std - ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std - points = torch.stack([xs, ys], dim=1) - return points, labels - - -points, labels = make_spirals(100, noise_std=0.05) - - -# Step 2: Define two-layer MLP and loss function -class MLPClassifier(nn.Module): - def __init__(self, hidden_dim=32, n_classes=2): - super().__init__() - self.hidden_dim = hidden_dim - self.n_classes = n_classes - - self.fc1 = nn.Linear(2, self.hidden_dim) - self.fc2 = nn.Linear(self.hidden_dim, self.n_classes) - - def forward(self, x): - x = self.fc1(x) - x = F.relu(x) - x = self.fc2(x) - x = F.log_softmax(x, -1) - return x - - -loss_fn = nn.NLLLoss() - -# Step 3: Make the model functional(!!) and define a training function. -# NB: this mechanism doesn't exist in PyTorch today, but we want it to: -# https://github.com/pytorch/pytorch/issues/49171 -func_model, weights = make_functional(MLPClassifier().to(DEVICE)) - - -def train_step_fn(weights, batch, targets, lr=0.2): - def compute_loss(weights, batch, targets): - output = func_model(weights, batch) - loss = loss_fn(output, targets) - return loss - - grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets) - - # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon) - # so we are going to re-implement SGD here. - new_weights = [] - with torch.no_grad(): - for grad_weight, weight in zip(grad_weights, weights): - new_weights.append(weight - grad_weight * lr) - - return loss, new_weights - - -# Step 4: Let's verify this actually trains. -# We should see the loss decrease. -def step4(): - global weights - for i in range(2000): - loss, weights = train_step_fn(weights, points, labels) - if i % 100 == 0: - print(loss) - - -step4() - -# Step 5: We're ready for multiple models. Let's define an init_fn -# that, given a number of models, returns to us all of the weights. - - -def init_fn(num_models): - models = [MLPClassifier().to(DEVICE) for _ in range(num_models)] - _, params, _ = combine_state_for_ensemble(models) - return params - - -# Step 6: Now, can we try multiple models at the same time? -# The answer is: yes! `loss` is a 2-tuple, and we can see that the value keeps -# on decreasing - - -def step6(): - parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None)) - batched_weights = init_fn(num_models=2) - for i in range(2000): - loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels) - if i % 200 == 0: - print(loss) - - -step6() - -# Step 7: Now, the flaw with step 6 is that we were training on the same exact -# data. This can lead to all of the models in the ensemble overfitting in the -# same way. 
The solution that http://willwhitney.com/parallel-training-jax.html -# applies is to randomly subset the data in a way that the models do not recieve -# exactly the same data in each training step! -# Because the goal of this doc is to show that we can use eager-mode vmap to -# achieve similar things as JAX, the rest of this is left as an exercise to the reader. - -# In conclusion, to achieve what http://willwhitney.com/parallel-training-jax.html -# does, we used the following additional items that PyTorch does not have: -# 1. NN module functional API that turns a module into a (state, state_less_fn) pair -# 2. Functional optimizers -# 3. A "functional" grad API (that effectively wraps autograd.grad) -# 4. Composability between the functional grad API and torch.vmap. diff --git a/examples/functorch/parallel_train_torchopt.py b/examples/functorch/parallel_train_torchopt.py index 90913e8d..203198cd 100644 --- a/examples/functorch/parallel_train_torchopt.py +++ b/examples/functorch/parallel_train_torchopt.py @@ -18,6 +18,7 @@ from collections import namedtuple from typing import Any, NamedTuple +import functorch import torch import torch.nn as nn import torch.nn.functional as F @@ -56,71 +57,89 @@ def forward(self, x): return x -class Net(nn.Module): - def __init__(self, dim): - super().__init__() - self.fc = nn.Linear(dim, 1, bias=True) - nn.init.ones_(self.fc.weight) - nn.init.zeros_(self.fc.bias) - - def forward(self, x): - return self.fc(x) - - -def train_step_fn(training_state, batch, targets): - weights, opt_state = training_state - - def compute_loss(weights, batch, targets): - output = func_model(weights, batch) - loss = loss_fn(output, targets) - return loss - - grads, loss = grad_and_value(compute_loss)(weights, batch, targets) - - # functional optimizer API is here now - # new_opt_state0 = opt_state[0]._asdict() - # for k, v in new_opt_state0.items(): - # if type(v) is tuple: - # new_opt_state0[k] = tuple(v_el.clone() for v_el in v) - # new_opt_state = (opt_state[0]._make(new_opt_state0.values()), opt_state[1]) - - updates, new_opt_state = optimizer.update(grads, opt_state) - new_weights = torchopt.apply_updates(weights, updates) - # Default `inplace=True` gave me an error - # weights = torchopt.apply_updates(weights, updates, inplace=False) - return loss, (new_weights, new_opt_state) - - -def step4(weights, opt_state): - for i in range(2000): - loss, (weights, opt_state) = train_step_fn((weights, opt_state), points, labels) - if i % 100 == 0: - print(loss) - - -def init_fn(model_idx): - print(model_idx) - # models = [MLPClassifier().to(DEVICE) for _ in range(model_idx)] - # print(len(models)) - # print(models) - # _, weights, _ = combine_state_for_ensemble(models) - # print(weights) - _, weights = make_functional(Net(4).to(DEVICE)) - opt_state = optimizer.init(weights) - print(weights) - # print(opt_state) - print(opt_state) - return weights, opt_state - - -def step6(num_models): - parallel_init_fn = vmap(init_fn, randomness='same') - parallel_train_step_fn = vmap(train_step_fn, in_dims=(0, None, None)) - weights, opt_state = parallel_init_fn(torch.ones(num_models, 1)) - for i in range(2000): - loss, (weights, opt_states) = parallel_train_step_fn((weights, opt_state), points, labels) - if i % 200 == 0: - print(loss) +class ParallelTrainFunctorchOriginal: + def __init__(self, loss_fn, lr): + self.loss_fn = loss_fn + self.lr = lr + self.func_model, _ = make_functional(MLPClassifier().to(DEVICE)) + + def init_fn(self, num_models): + models = [MLPClassifier().to(DEVICE) for _ in 
range(num_models)] + _, batched_weights, _ = combine_state_for_ensemble(models) + return batched_weights + + def train_step_fn(self, weights, batch, targets): + def compute_loss(weights, batch, targets): + output = self.func_model(weights, batch) + loss = self.loss_fn(output, targets) + return loss + + grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets) + # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon) + # so we are going to re-implement SGD here. + new_weights = [] + with torch.no_grad(): + for grad_weight, weight in zip(grad_weights, weights): + new_weights.append(weight - grad_weight * self.lr) + + return loss, new_weights + + def test_train_step_fn(self, weights, points, labels): + for i in range(2000): + loss, weights = self.train_step_fn(weights, points, labels) + if i % 100 == 0: + print(loss) + + def test_parallel_train_step_fn(self, num_models): + parallel_train_step_fn = vmap(self.train_step_fn, in_dims=(0, None, None)) + batched_weights = self.init_fn(num_models=num_models) + for i in range(2000): + loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels) + if i % 200 == 0: + print(loss) + + +class ParallelTrainFunctorchTorchOpt: + def __init__(self, loss_fn, optimizer): + self.loss_fn = loss_fn + self.optimizer = optimizer + self.func_model, _ = make_functional(MLPClassifier().to(DEVICE)) + + def init_fn(self, model_idx): + _, weights = make_functional(MLPClassifier().to(DEVICE)) + opt_state = self.optimizer.init(weights) + return weights, opt_state + + def train_step_fn(self, training_state, batch, targets): + weights, opt_state = training_state + + def compute_loss(weights, batch, targets): + output = self.func_model(weights, batch) + loss = self.loss_fn(output, targets) + return loss + + grads, loss = grad_and_value(compute_loss)(weights, batch, targets) + # functional optimizer API is here now + updates, new_opt_state = optimizer.update(grads, opt_state, inplace=False) + new_weights = torchopt.apply_updates(weights, updates, inplace=False) + return loss, (new_weights, new_opt_state) + + def test_train_step_fn(self, weights, opt_state, points, labels): + for i in range(2000): + loss, (weights, opt_state) = self.train_step_fn((weights, opt_state), points, labels) + if i % 100 == 0: + print(loss) + + def test_parallel_train_step_fn(self, num_models): + parallel_init_fn = vmap(self.init_fn, randomness='same') + parallel_train_step_fn = vmap(self.train_step_fn, in_dims=(0, None, None)) + weights, opt_state = parallel_init_fn(torch.ones(num_models, 1)) + for i in range(2000): + loss, (weights, opt_states) = parallel_train_step_fn( + (weights, opt_state), points, labels + ) + if i % 200 == 0: + print(loss) if __name__ == '__main__': @@ -136,7 +155,7 @@ def step6(num_models): # } # GOAL: Demonstrate that it is possible to use eager-mode vmap - # to parallelize training over models. + parser = argparse.ArgumentParser(description='Functorch Ensembled Models with TorchOpt') parser.add_argument( '--device', @@ -153,17 +172,29 @@ def step6(num_models): loss_fn = nn.NLLLoss() # Step 3: Make the model functional(!!) and define a training function. func_model, weights = make_functional(MLPClassifier().to(DEVICE)) + + # original functorch implementation + functorch_original = ParallelTrainFunctorchOriginal(loss_fn=loss_fn, lr=0.2) + # Step 4: Let's verify this actually trains. + # We should see the loss decrease. 
+ functorch_original.test_train_step_fn(weights, points, labels) + # Step 6: Now, can we try multiple models at the same time? + # The answer is: yes! `loss` is a 2-tuple, and we can see that the value keeps + # on decreasing + functorch_original.test_parallel_train_step_fn(num_models=2) + + # functorch + torchopt implementation optimizer = torchopt.adam(lr=0.2) opt_state = optimizer.init(weights) + functorch_original = ParallelTrainFunctorchTorchOpt(loss_fn=loss_fn, optimizer=optimizer) # Step 4: Let's verify this actually trains. # We should see the loss decrease. - step4(weights, opt_state) - # Step 5: We're ready for multiple models. Let's define an init_fn - # that, given a number of models, returns to us all of the weights. + functorch_original.test_train_step_fn(weights, opt_state, points, labels) # Step 6: Now, can we try multiple models at the same time? # The answer is: yes! `loss` is a 2-tuple, and we can see that the value keeps # on decreasing - step6(5) + functorch_original.test_parallel_train_step_fn(num_models=2) + # Step 7: Now, the flaw with step 6 is that we were training on the same exact # data. This can lead to all of the models in the ensemble overfitting in the # same way. The solution that http://willwhitney.com/parallel-training-jax.html diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py index a04d49b5..2d196cb3 100644 --- a/torchopt/_src/transform.py +++ b/torchopt/_src/transform.py @@ -38,17 +38,27 @@ import torch from torchopt._src import base -from torchopt._src.typing import Schedule +from torchopt._src.typing import Numeric, Schedule ScaleState = base.EmptyState -def inc_count(updates, count: Tuple[int]) -> Tuple[int]: - """Increments int counter by one.""" +def inc_count(updates, count: Tuple[Numeric, ...]) -> Tuple[Numeric, ...]: + """Increments int counter by one. + + Returns: + A counter incremeted by one, or max_int if the maximum precision is reached. + """ + max_int32_value = torch.iinfo(torch.int32).max + one = torch.ones(1, dtype=torch.int32, device=count[0].device) def f(c, g): - return c + 1 if g is not None else c + return ( + c + (1 - torch.div(c, max_int32_value, rounding_mode='trunc')) * one + if g is not None + else c + ) return jax.tree_map(f, count, updates) @@ -87,7 +97,7 @@ def f(g): class ScaleByScheduleState(NamedTuple): """Maintains count for scale scheduling.""" - count: Tuple[int, ...] # type: ignore + count: Tuple[Numeric, ...] # type: ignore def scale_by_schedule(step_size_fn: Schedule) -> base.GradientTransformation: @@ -103,7 +113,10 @@ def scale_by_schedule(step_size_fn: Schedule) -> base.GradientTransformation: """ def init_fn(params): - return ScaleByScheduleState(count=tuple(0 for _ in range(len(params)))) + zero = jax.tree_map( # First moment + lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params + ) + return ScaleByScheduleState(count=tuple(zero)) def update_fn(updates, state, inplace=True): step_size = step_size_fn(state.count) @@ -149,7 +162,7 @@ def f(g, t): class ScaleByAdamState(NamedTuple): """State for the Adam algorithm.""" - count: Tuple[int, ...] # type: ignore + count: Tuple[Numeric, ...] 
# type: ignore mu: base.Updates nu: base.Updates @@ -199,13 +212,16 @@ def scale_by_adam( """ def init_fn(params): + zero = jax.tree_map( # First moment + lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params + ) mu = jax.tree_map( # First moment lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params ) nu = jax.tree_map( # Second moment lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params ) - return ScaleByAdamState(count=tuple(0 for _ in range(len(mu))), mu=tuple(mu), nu=tuple(nu)) + return ScaleByAdamState(count=tuple(zero), mu=tuple(mu), nu=tuple(nu)) def update_fn(updates, state, inplace=True): mu = _update_moment(updates, state.mu, b1, 1, inplace) @@ -262,13 +278,16 @@ def scale_by_accelerated_adam( from torchopt._src.accelerated_op import AdamOp # pylint: disable=import-outside-toplevel def init_fn(params): + zero = jax.tree_map( # First moment + lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params + ) mu = jax.tree_map( # First moment lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params ) nu = jax.tree_map( # Second moment lambda t: torch.zeros_like(t, requires_grad=moment_requires_grad), params ) - return ScaleByAdamState(count=tuple(0 for _ in range(len(params))), mu=mu, nu=nu) + return ScaleByAdamState(count=tuple(zero), mu=mu, nu=nu) def update_fn(updates, state, inplace=True): count_inc = inc_count(updates, state.count) From c9ecd25454ea2221471721369c5d12f385c59f61 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 21 Jul 2022 21:29:48 +0800 Subject: [PATCH 4/9] fix: pass lint --- torchopt/_src/transform.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py index 2d196cb3..45779860 100644 --- a/torchopt/_src/transform.py +++ b/torchopt/_src/transform.py @@ -38,13 +38,13 @@ import torch from torchopt._src import base -from torchopt._src.typing import Numeric, Schedule +from torchopt._src.typing import Schedule ScaleState = base.EmptyState -def inc_count(updates, count: Tuple[Numeric, ...]) -> Tuple[Numeric, ...]: +def inc_count(updates, count: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: """Increments int counter by one. Returns: @@ -97,7 +97,7 @@ def f(g): class ScaleByScheduleState(NamedTuple): """Maintains count for scale scheduling.""" - count: Tuple[Numeric, ...] # type: ignore + count: Tuple[torch.Tensor, ...] # type: ignore def scale_by_schedule(step_size_fn: Schedule) -> base.GradientTransformation: @@ -162,7 +162,7 @@ def f(g, t): class ScaleByAdamState(NamedTuple): """State for the Adam algorithm.""" - count: Tuple[Numeric, ...] # type: ignore + count: Tuple[torch.Tensor, ...] 
# type: ignore mu: base.Updates nu: base.Updates From 05960bc50725534162a2f12cb6ecdd17fc359fa9 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Sat, 6 Aug 2022 17:02:14 +0800 Subject: [PATCH 5/9] fix: pass lint --- examples/functorch/parallel_train_torchopt.py | 4 ++-- tests/unit/low_level/test_low_level_inplace.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/examples/functorch/parallel_train_torchopt.py b/examples/functorch/parallel_train_torchopt.py index 203198cd..06287f70 100644 --- a/examples/functorch/parallel_train_torchopt.py +++ b/examples/functorch/parallel_train_torchopt.py @@ -18,13 +18,13 @@ from collections import namedtuple from typing import Any, NamedTuple -import functorch import torch import torch.nn as nn import torch.nn.functional as F -from functorch import combine_state_for_ensemble, grad_and_value, make_functional, vmap +import functorch import torchopt +from functorch import combine_state_for_ensemble, grad_and_value, make_functional, vmap def make_spirals(n_samples, noise_std=0.0, rotations=1.0): diff --git a/tests/unit/low_level/test_low_level_inplace.py b/tests/unit/low_level/test_low_level_inplace.py index 09f39ec9..aae337ec 100644 --- a/tests/unit/low_level/test_low_level_inplace.py +++ b/tests/unit/low_level/test_low_level_inplace.py @@ -16,13 +16,13 @@ import copy import unittest -import functorch import pytest import torch import torch.nn.functional as F from torch.utils import data from torchvision import models +import functorch import torchopt From acba2e1f8e768b080f51e473efb2f56b54ffe819 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Wed, 10 Aug 2022 23:51:13 +0800 Subject: [PATCH 6/9] fix: update comment --- torchopt/_src/transform.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py index 0a1c542c..d4496206 100644 --- a/torchopt/_src/transform.py +++ b/torchopt/_src/transform.py @@ -113,7 +113,7 @@ def scale_by_schedule(step_size_fn: Schedule) -> base.GradientTransformation: """ def init_fn(params): - zero = pytree.tree_map( # First moment + zero = pytree.tree_map( # Count init lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params ) return ScaleByScheduleState(count=tuple(zero)) @@ -212,7 +212,7 @@ def scale_by_adam( """ def init_fn(params): - zero = pytree.tree_map( # First moment + zero = pytree.tree_map( # Count init lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params ) mu = pytree.tree_map( # First moment @@ -278,7 +278,7 @@ def scale_by_accelerated_adam( from torchopt._src.accelerated_op import AdamOp # pylint: disable=import-outside-toplevel def init_fn(params): - zero = pytree.tree_map( # First moment + zero = pytree.tree_map( # Count init lambda t: torch.zeros(1, dtype=torch.int32, device=t.device), params ) mu = pytree.tree_map( # First moment From fc679adfac03cf9ab440d4b1159aa23814d801f3 Mon Sep 17 00:00:00 2001 From: Benjamin-eecs Date: Thu, 11 Aug 2022 16:57:39 +0800 Subject: [PATCH 7/9] fix: resolve comments --- CHANGELOG.md | 2 + .../parallel_train_torchopt.py | 49 ++++++++++--------- .../unit/low_level/test_low_level_inplace.py | 2 +- torchopt/_src/transform.py | 4 +- 4 files changed, 31 insertions(+), 26 deletions(-) rename examples/{functorch => FuncTorch}/parallel_train_torchopt.py (81%) diff --git a/CHANGELOG.md b/CHANGELOG.md index a7a35900..7e6de847 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -14,6 +14,8 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 
### Added - Add question/help/support issue template [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#43](https://github.com/metaopt/TorchOpt/pull/43). +- Add parallel training on one GPU using functorch.vmap example [@Benjamin-eecs](https://github.com/Benjamin-eecs) in [#32](https://github.com/metaopt/TorchOpt/pull/32). + ### Changed diff --git a/examples/functorch/parallel_train_torchopt.py b/examples/FuncTorch/parallel_train_torchopt.py similarity index 81% rename from examples/functorch/parallel_train_torchopt.py rename to examples/FuncTorch/parallel_train_torchopt.py index 06287f70..640763cb 100644 --- a/examples/functorch/parallel_train_torchopt.py +++ b/examples/FuncTorch/parallel_train_torchopt.py @@ -18,24 +18,23 @@ from collections import namedtuple from typing import Any, NamedTuple +import functorch import torch import torch.nn as nn import torch.nn.functional as F -import functorch import torchopt -from functorch import combine_state_for_ensemble, grad_and_value, make_functional, vmap -def make_spirals(n_samples, noise_std=0.0, rotations=1.0): - ts = torch.linspace(0, 1, n_samples, device=DEVICE) +def make_spirals(n_samples, noise_std=0.0, rotations=1.0, device='cpu'): + ts = torch.linspace(0, 1, n_samples, device=device) rs = ts**0.5 thetas = rs * rotations * 2 * math.pi - signs = torch.randint(0, 2, (n_samples,), device=DEVICE) * 2 - 1 - labels = (signs > 0).to(torch.long).to(DEVICE) + signs = torch.randint(0, 2, (n_samples,), device=device) * 2 - 1 + labels = (signs > 0).to(torch.long).to(device) - xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std - ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=DEVICE) * noise_std + xs = rs * signs * torch.cos(thetas) + torch.randn(n_samples, device=device) * noise_std + ys = rs * signs * torch.sin(thetas) + torch.randn(n_samples, device=device) * noise_std points = torch.stack([xs, ys], dim=1) return points, labels @@ -58,14 +57,15 @@ def forward(self, x): class ParallelTrainFunctorchOriginal: - def __init__(self, loss_fn, lr): + def __init__(self, loss_fn, lr, device): + self.device = device self.loss_fn = loss_fn self.lr = lr - self.func_model, _ = make_functional(MLPClassifier().to(DEVICE)) + self.func_model, _ = functorch.make_functional(MLPClassifier().to(self.device)) def init_fn(self, num_models): - models = [MLPClassifier().to(DEVICE) for _ in range(num_models)] - _, batched_weights, _ = combine_state_for_ensemble(models) + models = [MLPClassifier().to(self.device) for _ in range(num_models)] + _, batched_weights, _ = functorch.combine_state_for_ensemble(models) return batched_weights def train_step_fn(self, weights, batch, targets): @@ -74,7 +74,7 @@ def compute_loss(weights, batch, targets): loss = self.loss_fn(output, targets) return loss - grad_weights, loss = grad_and_value(compute_loss)(weights, batch, targets) + grad_weights, loss = functorch.grad_and_value(compute_loss)(weights, batch, targets) # NB: PyTorch is missing a "functional optimizer API" (possibly coming soon) # so we are going to re-implement SGD here. 
new_weights = [] @@ -91,7 +91,7 @@ def test_train_step_fn(self, weights, points, labels): print(loss) def test_parallel_train_step_fn(self, num_models): - parallel_train_step_fn = vmap(self.train_step_fn, in_dims=(0, None, None)) + parallel_train_step_fn = functorch.vmap(self.train_step_fn, in_dims=(0, None, None)) batched_weights = self.init_fn(num_models=num_models) for i in range(2000): loss, batched_weights = parallel_train_step_fn(batched_weights, points, labels) @@ -100,13 +100,14 @@ def test_parallel_train_step_fn(self, num_models): class ParallelTrainFunctorchTorchOpt: - def __init__(self, loss_fn, optimizer): + def __init__(self, loss_fn, optimizer, device): + self.device = device self.loss_fn = loss_fn self.optimizer = optimizer - self.func_model, _ = make_functional(MLPClassifier().to(DEVICE)) + self.func_model, _ = functorch.make_functional(MLPClassifier().to(self.device)) def init_fn(self, model_idx): - _, weights = make_functional(MLPClassifier().to(DEVICE)) + _, weights = functorch.make_functional(MLPClassifier().to(self.device)) opt_state = self.optimizer.init(weights) return weights, opt_state @@ -118,7 +119,7 @@ def compute_loss(weights, batch, targets): loss = self.loss_fn(output, targets) return loss - grads, loss = grad_and_value(compute_loss)(weights, batch, targets) + grads, loss = functorch.grad_and_value(compute_loss)(weights, batch, targets) # functional optimizer API is here now updates, new_opt_state = optimizer.update(grads, opt_state, inplace=False) new_weights = torchopt.apply_updates(weights, updates, inplace=False) @@ -131,8 +132,8 @@ def test_train_step_fn(self, weights, opt_state, points, labels): print(loss) def test_parallel_train_step_fn(self, num_models): - parallel_init_fn = vmap(self.init_fn, randomness='same') - parallel_train_step_fn = vmap(self.train_step_fn, in_dims=(0, None, None)) + parallel_init_fn = functorch.vmap(self.init_fn, randomness='same') + parallel_train_step_fn = functorch.vmap(self.train_step_fn, in_dims=(0, None, None)) weights, opt_state = parallel_init_fn(torch.ones(num_models, 1)) for i in range(2000): loss, (weights, opt_states) = parallel_train_step_fn( @@ -171,10 +172,10 @@ def test_parallel_train_step_fn(self, num_models): # Step 2: Define two-layer MLP and loss function loss_fn = nn.NLLLoss() # Step 3: Make the model functional(!!) and define a training function. - func_model, weights = make_functional(MLPClassifier().to(DEVICE)) + func_model, weights = functorch.make_functional(MLPClassifier().to(DEVICE)) # original functorch implementation - functorch_original = ParallelTrainFunctorchOriginal(loss_fn=loss_fn, lr=0.2) + functorch_original = ParallelTrainFunctorchOriginal(loss_fn=loss_fn, lr=0.2, device=DEVICE) # Step 4: Let's verify this actually trains. # We should see the loss decrease. functorch_original.test_train_step_fn(weights, points, labels) @@ -186,7 +187,9 @@ def test_parallel_train_step_fn(self, num_models): # functorch + torchopt implementation optimizer = torchopt.adam(lr=0.2) opt_state = optimizer.init(weights) - functorch_original = ParallelTrainFunctorchTorchOpt(loss_fn=loss_fn, optimizer=optimizer) + functorch_original = ParallelTrainFunctorchTorchOpt( + loss_fn=loss_fn, optimizer=optimizer, device=DEVICE + ) # Step 4: Let's verify this actually trains. # We should see the loss decrease. 
functorch_original.test_train_step_fn(weights, opt_state, points, labels) diff --git a/tests/unit/low_level/test_low_level_inplace.py b/tests/unit/low_level/test_low_level_inplace.py index aae337ec..09f39ec9 100644 --- a/tests/unit/low_level/test_low_level_inplace.py +++ b/tests/unit/low_level/test_low_level_inplace.py @@ -16,13 +16,13 @@ import copy import unittest +import functorch import pytest import torch import torch.nn.functional as F from torch.utils import data from torchvision import models -import functorch import torchopt diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py index d4496206..bbecad60 100644 --- a/torchopt/_src/transform.py +++ b/torchopt/_src/transform.py @@ -42,6 +42,7 @@ ScaleState = base.EmptyState +MaxInt32Value = torch.iinfo(torch.int32).max def inc_count(updates, count: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: @@ -50,12 +51,11 @@ def inc_count(updates, count: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, . Returns: A counter incremeted by one, or max_int if the maximum precision is reached. """ - max_int32_value = torch.iinfo(torch.int32).max one = torch.ones(1, dtype=torch.int32, device=count[0].device) def f(c, g): return ( - c + (1 - torch.div(c, max_int32_value, rounding_mode='trunc')) * one + c + (1 - torch.div(c, MaxInt32Value, rounding_mode='trunc')) * one if g is not None else c ) From 84615d2f62029fab52c478008908f54b1e6f4a98 Mon Sep 17 00:00:00 2001 From: Bo Liu Date: Thu, 11 Aug 2022 17:09:18 +0800 Subject: [PATCH 8/9] fix: update torchopt/_src/transform.py Co-authored-by: Xuehai Pan --- torchopt/_src/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py index bbecad60..e53b0c39 100644 --- a/torchopt/_src/transform.py +++ b/torchopt/_src/transform.py @@ -42,7 +42,7 @@ ScaleState = base.EmptyState -MaxInt32Value = torch.iinfo(torch.int32).max +INT32_MAX = torch.iinfo(torch.int32).max def inc_count(updates, count: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, ...]: From aa8efd467cead2f27f8e048ba021177ca53fe921 Mon Sep 17 00:00:00 2001 From: Bo Liu Date: Thu, 11 Aug 2022 17:09:33 +0800 Subject: [PATCH 9/9] fix: update torchopt/_src/transform.py Co-authored-by: Xuehai Pan --- torchopt/_src/transform.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/torchopt/_src/transform.py b/torchopt/_src/transform.py index e53b0c39..1fb3b8f7 100644 --- a/torchopt/_src/transform.py +++ b/torchopt/_src/transform.py @@ -55,7 +55,7 @@ def inc_count(updates, count: Tuple[torch.Tensor, ...]) -> Tuple[torch.Tensor, . def f(c, g): return ( - c + (1 - torch.div(c, MaxInt32Value, rounding_mode='trunc')) * one + c + (1 - torch.div(c, INT32_MAX, rounding_mode='trunc')) * one if g is not None else c )
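The torchopt/_src/transform.py changes in this series exist because the example vmaps over the optimizer state: a plain Python int step count cannot be batched by functorch.vmap, so the Adam/schedule counters become per-parameter int32 tensors that saturate at INT32_MAX instead of overflowing. Below is a standalone sketch of that saturating increment; the helper name is illustrative, and the real implementation is the pytree-mapped inc_count shown in the patch.

import torch

INT32_MAX = torch.iinfo(torch.int32).max


def saturating_inc(count: torch.Tensor) -> torch.Tensor:
    # While count < INT32_MAX the truncated division is 0 and the counter advances
    # by one; once count == INT32_MAX the increment term vanishes, so the counter
    # sticks at INT32_MAX rather than wrapping around to a negative int32.
    one = torch.ones_like(count)
    return count + (1 - torch.div(count, INT32_MAX, rounding_mode='trunc')) * one


count = torch.zeros(1, dtype=torch.int32)
print(saturating_inc(count))                                           # tensor([1], dtype=torch.int32)
print(saturating_inc(torch.full((1,), INT32_MAX, dtype=torch.int32)))  # stays at INT32_MAX

Initializing one zero tensor per parameter leaf (torch.zeros(1, dtype=torch.int32, device=t.device), as in the patched init_fn) also keeps each counter on the same device as the parameter whose bias correction it drives.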