
[FSDP] wrap modules #4911

Merged
merged 1 commit into main from conda_parlai_torch12 on Dec 6, 2022

Conversation

@jxmsML (Contributor) commented on Dec 6, 2022

Patch description
When running distributed_train, one might run into the following error:

```
Traceback (most recent call last):
  File "/private/home/jingxu23/.conda/envs/parlai_py39_pyt113/lib/python3.9/runpy.py", line 197, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/private/home/jingxu23/.conda/envs/parlai_py39_pyt113/lib/python3.9/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/private/home/jingxu23/ParlAI/parlai/scripts/distributed_train.py", line 61, in <module>
    DistributedTrain.main()
  File "/private/home/jingxu23/ParlAI/parlai/core/script.py", line 129, in main
    return cls._run_args(None)
  File "/private/home/jingxu23/ParlAI/parlai/core/script.py", line 101, in _run_args
    return cls._run_from_parser_and_opt(opt, parser)
  File "/private/home/jingxu23/ParlAI/parlai/core/script.py", line 108, in _run_from_parser_and_opt
    return script.run()
  File "/private/home/jingxu23/ParlAI/parlai/scripts/distributed_train.py", line 57, in run
    return self.train_loop.train()
  File "/private/home/jingxu23/ParlAI/parlai/scripts/train_model.py", line 1010, in train
    for _train_log in self.train_steps():
  File "/private/home/jingxu23/ParlAI/parlai/scripts/train_model.py", line 917, in train_steps
    world.parley()
  File "/private/home/jingxu23/ParlAI/parlai/core/worlds.py", line 700, in parley
    self.worlds[self.world_idx].parley()
  File "/private/home/jingxu23/ParlAI/parlai/core/worlds.py", line 370, in parley
    acts[1] = agents[1].act()
  File "/private/home/jingxu23/ParlAI/parlai/core/torch_agent.py", line 2157, in act
    response = self.batch_act([self.observation])[0]
  File "/private/home/jingxu23/ParlAI/parlai/agents/fid/fid.py", line 389, in batch_act
    batch_reply = super().batch_act(observations)
  File "/private/home/jingxu23/ParlAI/parlai/core/torch_agent.py", line 2248, in batch_act
    output = self.train_step(batch)
  File "/private/home/jingxu23/ParlAI/parlai/core/torch_generator_agent.py", line 791, in train_step
    raise e
  File "/private/home/jingxu23/ParlAI/parlai/core/torch_generator_agent.py", line 777, in train_step
    self.backward(loss)
  File "/private/home/jingxu23/ParlAI/parlai/core/torch_agent.py", line 2334, in backward
    self.optimizer.backward(loss, update_main_grads=False, **kwargs)
  File "/private/home/jingxu23/ParlAI/parlai/utils/fp16.py", line 194, in backward
    loss.backward(retain_graph=retain_graph)
  File "/private/home/jingxu23/.conda/envs/parlai_py39_pyt113/lib/python3.9/site-packages/torch/_tensor.py", line 487, in backward
    torch.autograd.backward(
  File "/private/home/jingxu23/.conda/envs/parlai_py39_pyt113/lib/python3.9/site-packages/torch/autograd/__init__.py", line 197, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
  File "/private/home/jingxu23/.conda/envs/parlai_py39_pyt113/lib/python3.9/site-packages/torch/autograd/function.py", line 267, in apply
    return user_fn(self, *args)
  File "/private/home/jingxu23/.conda/envs/parlai_py39_pyt113/lib/python3.9/site-packages/fairscale/nn/checkpoint/checkpoint_activations.py", line 332, in backward
    outputs = ctx.run_function(*unpacked_args, **unpacked_kwargs)
  File "/private/home/jingxu23/ParlAI/parlai/agents/transformer/modules/decoder.py", line 529, in forward
    x = self.norm1(x)
  File "/private/home/jingxu23/.conda/envs/parlai_py39_pyt113/lib/python3.9/site-packages/torch/nn/modules/module.py", line 1190, in _call_impl
    return forward_call(*input, **kwargs)
  File "/private/home/jingxu23/.conda/envs/parlai_py39_pyt113/lib/python3.9/site-packages/torch/nn/modules/normalization.py", line 190, in forward
    return F.layer_norm(
  File "/private/home/jingxu23/.conda/envs/parlai_py39_pyt113/lib/python3.9/site-packages/torch/nn/functional.py", line 2515, in layer_norm
    return torch.layer_norm(input, normalized_shape, weight, bias, eps, torch.backends.cudnn.enabled)
RuntimeError: The tensor has a non-zero number of elements, but its data is not allocated yet. Caffe2 uses a lazy allocation, so you will need to call mutable_data() or raw_mutable_data() to actually allocate memory.
```

This PR fixes it by wrapping the affected modules for FSDP.
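
For context beyond the original PR text: this crash pattern typically shows up when fairscale's activation checkpointing reruns a layer's forward during the backward pass while FSDP has already freed that layer's flattened parameter storage, which is what the unallocated-data RuntimeError above is complaining about. Below is a minimal sketch of the "wrap modules" pattern the PR title points at, using fairscale's enable_wrap/wrap API; it is not the actual diff from this PR, and the names TinyDecoder, ffn, and the single-process setup are illustrative only.

```python
# Hypothetical sketch, NOT the actual change in this PR: it illustrates the
# general "wrap modules" pattern with fairscale's FSDP. TinyDecoder is a
# made-up stand-in for ParlAI's transformer decoder layer.
import os

import torch
import torch.distributed as dist
import torch.nn as nn
from fairscale.nn import FullyShardedDataParallel as FSDP
from fairscale.nn.wrap import enable_wrap, wrap


class TinyDecoder(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.ffn = nn.Linear(dim, dim)
        self.norm1 = nn.LayerNorm(dim)

    def forward(self, x):
        return self.norm1(self.ffn(x))


if torch.cuda.is_available():
    # Single-process group so the sketch is self-contained; a real run
    # launches one process per GPU (e.g. via parlai's distributed_train).
    os.environ.setdefault("MASTER_ADDR", "localhost")
    os.environ.setdefault("MASTER_PORT", "29500")
    dist.init_process_group("nccl", rank=0, world_size=1)

    decoder = TinyDecoder(dim=16).cuda()
    with enable_wrap(wrapper_cls=FSDP):
        # Wrapping the submodules individually (not only the root module)
        # lets FSDP re-materialize their flattened parameters when
        # activation checkpointing reruns forward() during backward,
        # instead of touching storage that FSDP has already freed.
        decoder.ffn = wrap(decoder.ffn)
        decoder.norm1 = wrap(decoder.norm1)
        model = FSDP(decoder)

    out = model(torch.randn(2, 16, device="cuda"))
    out.sum().backward()
```

ParlAI keeps its FSDP helpers in parlai/utils/fsdp.py; for the real change, see commit 804b10b referenced below.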

Testing steps

Other information

@jxmsML jxmsML merged commit 804b10b into main Dec 6, 2022
@jxmsML jxmsML deleted the conda_parlai_torch12 branch December 6, 2022 22:57