🐛 Bug
A script that runs fine when a single network is used to compute the output errors out when the same computation is performed with that single network accessed from a ModuleList of similarly initialized networks, once the ModuleList is sufficiently large.
To Reproduce
- Create a TPU VM (v3-32) with version tpu-vm-pt-1.12
- Run the code below (a ResNet50 outputting features for a simple pseudo-loss) with:
python3 -m torch_xla.distributed.xla_dist --tpu=$TPU_NAME --restart-tpuvm-pod-server --env XLA_USE_BF16=0 -- python3 /export/home/baseline.py
from torchvision.models import mobilenet_v3_small, resnet50
from torch import nn
import torch_xla.distributed.xla_multiprocessing as xmp
import torch_xla.core.xla_model as xm
from itertools import chain
import torch


class MultiBinaryVision(nn.Module):
    def __init__(self, num_models=64):
        super().__init__()
        # num_models is reused as num_classes here, matching the baseline's embed_dim.
        self.model_list = nn.ModuleList([resnet50(num_classes=num_models) for _ in range(num_models)])

    def forward(self, x):
        # Only the first sub-model is ever called; the others just sit in the ModuleList.
        return self.model_list[0](x)


class PseudoLoss(nn.Module):
    def __init__(self):
        super().__init__()

    def forward(self, outputs):
        return (outputs @ outputs.T).sum()


def broadcast_xla_master_model_param(model):
    """
    Broadcast the model parameters from the master process to the other processes.
    """
    parameters_and_buffers = []
    is_master = xm.is_master_ordinal(local=False)
    for p in chain(model.parameters(), model.buffers()):
        # Set all params in non-master devices to zero so that all_reduce is
        # equivalent to broadcasting parameters from master to other devices.
        scale = 1 if is_master else 0
        scale = torch.tensor(scale, dtype=p.data.dtype, device=p.data.device)
        p.data.mul_(scale)
        parameters_and_buffers.append(p.data)
    xm.all_reduce(xm.REDUCE_SUM, parameters_and_buffers)
    xm.mark_step()
    xm.wait_device_ops()
    xm.rendezvous("broadcast_xla_master_model_param")


def main(device_id):
    global_bs = 128
    embed_dim = 16
    device = xm.xla_device()
    model = resnet50(num_classes=embed_dim).to(device)
    # xm.rendezvous("sync1")
    broadcast_xla_master_model_param(model)
    loss_fn = PseudoLoss()
    optimizer = torch.optim.AdamW(model.parameters(), lr=1e-8)
    # xm.rendezvous("sync2")
    print("Starting batches")
    for batch_i in range(5):
        optimizer.zero_grad()
        print("forward")
        # Per-core batch: global batch of 128 split across the 32 cores of a v3-32.
        output = model(torch.randn(global_bs // 32, 3, 224, 224, device=device))
        loss = loss_fn(output)
        print("backward")
        loss.backward()
        print("step")
        xm.reduce_gradients(optimizer)
        optimizer.step()


tpu_cores_per_node = 8
xmp.spawn(main, nprocs=tpu_cores_per_node, start_method='fork')
This runs in ~1 minute without errors.
- Repeat the run with embed_dim = 2 and model = MultiBinaryVision(embed_dim).to(device). This runs more slowly but completes without errors.
- Repeat the above run with embed_dim = 16. This errors and then hangs; the exact changes are sketched below, followed by the log output:
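For clarity, the only lines that change in main() relative to the baseline script for this failing run are the following (my sketch of the substitution; everything else is unchanged):

embed_dim = 16
# a ModuleList of 16 ResNet50s (num_classes=16 each); forward() still only calls model_list[0]
model = MultiBinaryVision(embed_dim).to(device)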
...[normal connection preamble]...
2022-07-13 16:28:18 172.16.96.110 [0] 2022-07-13 16:28:18.016969: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:157] RPC failed with status = "UNAVAILABLE: Socket closed" and grpc_error_string = "{"created":"@1657729698.016766181","description":"Error received from peer ipv4:172.16.96.110:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
2022-07-13 16:28:18 172.16.96.110 [0] 2022-07-13 16:28:18.017066: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:157] RPC failed with status = "UNAVAILABLE: Socket closed" and grpc_error_string = "{"created":"@1657729698.016886252","description":"Error received from peer ipv4:172.16.96.110:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
2022-07-13 16:28:18 172.16.96.110 [0] 2022-07-13 16:28:18.017066: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:157] RPC failed with status = "UNAVAILABLE: Socket closed" and grpc_error_string = "{"created":"@1657729698.016885538","description":"Error received from peer ipv4:172.16.96.110:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
2022-07-13 16:28:18 172.16.96.110 [0] 2022-07-13 16:28:18.017093: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:157] RPC failed with status = "UNAVAILABLE: Socket closed" and grpc_error_string = "{"created":"@1657729698.016951473","description":"Error received from peer ipv4:172.16.96.110:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
2022-07-13 16:28:18 172.16.96.110 [0] 2022-07-13 16:28:18.017101: W tensorflow/core/distributed_runtime/rpc/grpc_remote_master.cc:157] RPC failed with status = "UNAVAILABLE: Socket closed" and grpc_error_string = "{"created":"@1657729698.016952944","description":"Error received from peer ipv4:172.16.96.110:51011","file":"external/com_github_grpc_grpc/src/core/lib/surface/call.cc","file_line":1056,"grpc_message":"Socket closed","grpc_status":14}", maybe retrying the RPC
....[more of the same]....
Expected behavior
This level of scale shouldn't be an issue: 16 ResNet50s have only 16 × 161 = 2576 parameter tensors in total, and the total memory footprint should be fine capacity-wise.
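As a rough sanity check (my own sketch, not part of the repro script above), the fp32 parameter footprint of the failing configuration is only on the order of 1.5 GiB:

# Sketch: rough parameter footprint of the failing configuration
# (16 ResNet50 sub-models with 16 output classes each).
from torch import nn
from torchvision.models import resnet50

ensemble = nn.ModuleList([resnet50(num_classes=16) for _ in range(16)])
n_params = sum(p.numel() for p in ensemble.parameters())
print(f"{n_params / 1e6:.1f}M params, ~{n_params * 4 / 2**30:.2f} GiB in fp32")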
Environment
- torch_xla version: tpu-vm-pt-1.12
Initially detected on tpu-vm-pt-1.10, which I need for my purposes per ronghanghu/moco_v3_tpu#1.
NOTE: using pip3 install https://storage.googleapis.com/cloud-tpu-tpuvm-artifacts/wheels/libtpu-nightly/libtpu_nightly-0.1.dev20211015-py3-none-any.whl per #3214 (comment)
Additional context
This error does not occur with a list of 16 or 32 ResNet18s, but it does occur with a list of 64 ResNet18s or with a list of 32 mobilenet_v3_small models.
Per-model totals (total parameters / parameter tensors):
ResNet50: 23.5M / 161
ResNet18: 11.2M / 62
MobileNetV3-Small: 1.5M / 142
This makes me think it has to do with the maximum-parameter issue from #3453, but I'm not getting that error, and 16 MobileNets only have 2272 parameter tensors, which is well under the cited limit.
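For reference, this is how I counted the parameter tensors above (my own sketch; I also print buffer counts, since BatchNorm running stats presumably add further graph inputs on top of these):

# Sketch: count parameter tensors (and buffers, for comparison) per architecture.
from torchvision.models import mobilenet_v3_small, resnet18, resnet50

for name, ctor in [("ResNet50", resnet50), ("ResNet18", resnet18),
                   ("MobileNetV3-Small", mobilenet_v3_small)]:
    m = ctor(num_classes=16)
    n_param_tensors = sum(1 for _ in m.parameters())  # the per-model counts quoted above
    n_buffers = sum(1 for _ in m.buffers())           # mostly BatchNorm running stats
    print(name, n_param_tensors, n_buffers)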
Ultimately I am hoping to use multiple modules from a ModuleList such as the one above in my computations (similar to an ensembling setup).
In #3453 it's mentioned that you can cut the computation with xm.mark_step(), but in this case the error occurs before any of the print statements execute. Is there currently any way to run such a model/ensemble? Where does the parameter limit actually bottleneck? In this case the code would be amenable to creating a separate optimizer for each model (a rough sketch of what I have in mind is below), but if it's an issue in the graph assembly I assume that wouldn't work?
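For concreteness, this is roughly the per-model split I was imagining. It is an untested sketch: ensemble_train_step and the one-optimizer-per-sub-model setup are hypothetical, it assumes each sub-model can be trained on its own loss, and I'm not certain that calling xm.mark_step() after every sub-model's step like this is actually valid or sufficient:

# Untested sketch: one optimizer per sub-model, cutting the lazy graph with
# xm.mark_step() after each sub-model's step (per the suggestion in #3453),
# so that no single XLA graph takes every sub-model's parameters as inputs.
def ensemble_train_step(model_list, optimizers, x, loss_fn):
    for sub_model, opt in zip(model_list, optimizers):
        opt.zero_grad()
        loss = loss_fn(sub_model(x))
        loss.backward()
        xm.reduce_gradients(opt)
        opt.step()
        xm.mark_step()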