Questions wrt training on TPU Pod #501

Closed
sumanthd17 opened this issue Jul 9, 2022 · 36 comments · Fixed by #1049
Labels
enhancement New feature or request feature request Request for a new feature to be added to Accelerate

@sumanthd17

Hi Accelerate Team,

I'm looking to use run_mlm_no_trainer.py on a TPU v3-128 pod. I have a few questions before I get started.

  1. Can I stream the data directly from a GCP bucket, or do I have to download the data to the VM where I'm training?
  2. Does the Accelerate library support training on a TPU pod, or is it limited to 8 cores? (based on this: Using Accelerate with TPU Pod VM like v3-32 #471)
  3. Should I use a TPU Node or a TPU VM for better performance with the Accelerate library?
  4. Is there a notebook or blog post to help set up the environment and run small tests on a GCP VM for TPU training?
  5. I want to train on a dataset on the order of 1 TB. Will HF Datasets be able to handle this on a machine with 256 GB of RAM? (possibly a question for the datasets repo, but just trying here as well)

Thanks in advance

cc: @sgugger @muellerzr

@sumanthd17 sumanthd17 changed the title Question wrt training on TPU Pod Questions wrt training on TPU Pod Jul 9, 2022
@sgugger
Collaborator

sgugger commented Jul 11, 2022

Hi there!

  1. You can do whatever you want since Accelerate will adapt to your training loop :-)
  2. This is completely untested, so I can't guarantee it will work. Let us know if you run into any issues.
  3. TPU VMs will probably have better performance (for the data pipeline)
  4. Not that I know of
  5. That's a question for the Datasets folks, you should go ask on the forums ;-)

@sumanthd17
Author

Hi @sgugger

I started trying this out, and the first thing that popped up was:
ValueError: The number of devices must be either 1 or 8, got 128 instead

How do I get past this? This will probably be a lengthy thread with all the errors. I would be happy to document all the info once it's fixed.

@muellerzr
Collaborator

@sumanthd17 can you provide the full error it gives for you there? (And nothing wrong with this being lengthy, usually these make great educational material eventually as well 😄 )

@sumanthd17
Author

Traceback (most recent call last):
  File "/home/sumanth/bert/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/home/sumanth/bert/lib/python3.8/site-packages/accelerate/commands/accelerate_cli.py", line 43, in main
    args.func(args)
  File "/home/sumanth/bert/lib/python3.8/site-packages/accelerate/commands/launch.py", line 530, in launch_command
    tpu_launcher(args)
  File "/home/sumanth/bert/lib/python3.8/site-packages/accelerate/commands/launch.py", line 360, in tpu_launcher
    xmp.spawn(PrepareForLaunch(main_function), args=(), nprocs=args.num_processes)
  File "/home/sumanth/bert/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 385, in spawn
    pf_cfg = _pre_fork_setup(nprocs)
  File "/home/sumanth/bert/lib/python3.8/site-packages/torch_xla/distributed/xla_multiprocessing.py", line 200, in _pre_fork_setup
    raise ValueError(
ValueError: The number of devices must be either 1 or 8, got 128 instead
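The traceback shows where the limit comes from: the launcher calls xmp.spawn with nprocs=args.num_processes (128 here), and torch_xla's _pre_fork_setup only accepts 1 or 8 because spawn runs per host. On a pod, each host VM drives its own 8 cores, so 128 cores means 16 hosts running 8 processes each. The sketch below is a hypothetical, pure-Python illustration of that bookkeeping (plan_pod and global_ordinal are our own names, not Accelerate or torch_xla APIs), assuming 8 cores per host as on v3 TPU VMs.

```python
CORES_PER_HOST = 8  # assumption: 8 TPU cores per host VM, as on v3 TPU VMs


def plan_pod(total_cores: int, cores_per_host: int = CORES_PER_HOST) -> dict:
    """Hypothetical helper: split a core count into per-host launches."""
    if total_cores in (1, cores_per_host):
        # Single host: the values xmp.spawn accepts directly.
        return {"num_hosts": 1, "nprocs_per_host": total_cores}
    if total_cores % cores_per_host != 0:
        raise ValueError(f"{total_cores} cores is not a multiple of {cores_per_host}")
    # Pod: one launcher per host, each spawning cores_per_host processes.
    return {"num_hosts": total_cores // cores_per_host, "nprocs_per_host": cores_per_host}


def global_ordinal(host_index: int, local_ordinal: int,
                   cores_per_host: int = CORES_PER_HOST) -> int:
    """Rank of a process across the whole pod, from its host and local index."""
    return host_index * cores_per_host + local_ordinal


if __name__ == "__main__":
    print(plan_pod(128))  # {'num_hosts': 16, 'nprocs_per_host': 8}
    print(global_ordinal(host_index=15, local_ordinal=7))  # 127
```

This is only the arithmetic; the actual multi-host launch discussed later in this thread goes through torch_xla.distributed.xla_dist.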

@sumanthd17
Author

I created a new TPU VM and installed the following libraries:

torch==1.11+cu10.2
accelerate==0.9.0
datasets==2.3.1
transformers==4.11.3

After this, I tried to run accelerate config and got the following error. Even accelerate env gives the same error:

2022-07-21 16:01:30.382616: F ./tensorflow/core/tpu/tpu_executor_init_fns.inc:148] TpuCompiler_DefaultDeviceShapeRepresentation not available in this library.
https://symbolize.stripped_domain/r/?trace=7f472976803b,7f47297680bf,7f462438764c,7f46248012b3,7f4729b2ab89&map= 
*** SIGABRT received by PID 6948 (TID 6948) on cpu 47 from PID 6948; stack trace: ***
PC: @     0x7f472976803b  (unknown)  raise
    @     0x7f464811dcda        992  (unknown)
    @     0x7f47297680c0  251878112  (unknown)
    @     0x7f462438764d        416  tensorflow::tpu::(anonymous namespace)::(anonymous namespace)::SetExecutorStructFn()
    @     0x7f46248012b4        544  tensorflow::tpu::(anonymous namespace)::FindAndLoadTpuLibrary()
    @     0x7f4729b2ab8a  (unknown)  (unknown)
https://symbolize.stripped_domain/r/?trace=7f472976803b,7f464811dcd9,7f47297680bf,7f462438764c,7f46248012b3,7f4729b2ab89&map=50c831e765011c7eb7163b7f3cb5c4b6:7f4639973000-7f464848bf00 
E0721 16:01:30.590137    6948 coredump_hook.cc:365] RAW: Remote crash data gathering hook invoked.
E0721 16:01:30.590152    6948 coredump_hook.cc:411] RAW: Skipping coredump since rlimit was 0 at process start.
E0721 16:01:30.590165    6948 client.cc:222] RAW: Coroner client retries enabled (b/136286901), will retry for up to 30 sec.
E0721 16:01:30.590169    6948 coredump_hook.cc:473] RAW: Sending fingerprint to remote end.
E0721 16:01:30.590178    6948 coredump_socket.cc:124] RAW: Stat failed errno=2 on socket /var/google/services/logmanagerd/remote_coredump.socket
E0721 16:01:30.590195    6948 coredump_hook.cc:477] RAW: Cannot send fingerprint to Coroner: [NOT_FOUND] Missing crash reporting socket. Is the listener running?
E0721 16:01:30.590200    6948 coredump_hook.cc:550] RAW: Discarding core.
E0721 16:01:30.594149    6948 process_state.cc:771] RAW: Raising signal 6 with default behavior
Aborted (core dumped)

@sgugger @muellerzr

@muellerzr muellerzr self-assigned this Jul 22, 2022
@muellerzr muellerzr added enhancement New feature or request feature request Request for a new feature to be added to Accelerate labels Jul 22, 2022
@sumanthd17
Author

sumanthd17 commented Jul 22, 2022

I was able to get past the above issue and start training on 8 cores (changed versions, attached below). But more cores are still an issue. Any workaround for this?

- `Accelerate` version: 0.11.0
- Platform: Linux-5.13.0-1019-gcp-x86_64-with-glibc2.29
- Python version: 3.8.10
- Numpy version: 1.23.1
- PyTorch version (GPU?): 1.11.0+cpu (False)
- `Accelerate` default config:
        - compute_environment: LOCAL_MACHINE
        - distributed_type: TPU
        - mixed_precision: no
        - use_cpu: False
        - num_processes: 8
        - machine_rank: 0
        - num_machines: 1
        - main_process_ip: None
        - main_process_port: None
        - main_training_function: main
        - deepspeed_config: {}
        - fsdp_config: {}

@muellerzr
Collaborator

We're actively working on this, give us a bit please 😃

@muellerzr
Collaborator

For now, though, consider this unsupported.

@Ontopic

Ontopic commented Jul 22, 2022

Is "for now" more like days, weeks, or months ^^?

@Ontopic

Ontopic commented Aug 4, 2022

Quoting @sumanthd17 above: "I was able to get past the above issue and start training on 8 cores (changed versions, attached below). But more cores are still an issue. Any workaround for this?" (the same environment details as in that comment)

Since we were gonna make this a lengthy, educational thread... ;)

I'm facing an issue similar to the one described in your earlier post and in huunguyen10's post.

Now, I've been seeing this error quite a bit while working with pods and code that was originally meant to run on a single TPU VM. Which of the settings do you think is the important one? I'm thinking xla/pytorch might need to be at least 1.11 to fix this issue, but I'm not completely sure yet.

Also would love to hear where the HuggingFace team is at, and whether they consider this as a target, or more of a nice-to-have if things happen to fall into place, so I can adjust my expectations and own efforts accordingly.

Thanks for all the insights offered here so far! Hoping for more lengthy posts (or copy-pasteable-solutions will work as well 😇)

Edit, things to consider:

  • Should weights be sharded perhaps?
  • XLA/Torch version 1.11 at least?
  • tensorflow-cpu instead of tf-nightly
  • pinning jax==??? with matching jaxlib==???
  • The TPU version of the popular "timm" repo, with a PyTorch/XLA TPU-pod setup that works. Also see this response with extra instructions for making all of them play nice
  • The latest bloom-jax-inference repo by HF is also definitely worth a look for those interested in making this work [and post lengthy posts about it ;)]. Check the branches.

@jianguoz

jianguoz commented Sep 23, 2022

@muellerzr @sgugger Thanks for the great accelerate library, it is super reliable on 8 cores! May I ask when accelerate will support training on TPU VMs with more than 8 cores? We are eager to try accelerate on more TPUs :)

@sgugger
Collaborator

sgugger commented Sep 23, 2022

@jianguoz It's not a priority for now, as we have no means of testing the solution (our request to get access to a free small TPU pod to maintain Accelerate was denied). Of course, if lots of users show interest, we'll reconsider!

@jianguoz

@sgugger Thanks very much for your quick update :). We have several colleagues interested in deploying Accelerate on more cores. Looking forward to the future release :)

@muellerzr
Collaborator

muellerzr commented Sep 23, 2022

To help us properly gauge the need for this feature, if you are actively trying to train on a TPU pod with PyTorch could you react with a 👍 to this message? 😄

Thanks!

@jianguoz

jianguoz commented Sep 23, 2022

@Ontopic @sumanthd17 Hi there, please react to the message above with a 👍🏻 if you want to train models on more than 8 TPU cores in the future

@dhruvrnaik

@sherlock42 check this out for training on TPU VMs with accelerate

@sherlock42

sherlock42 commented Oct 3, 2022

@jianguoz What modifications did you make to run accelerate on 8 Google Cloud TPU cores? Do you have any code you could share?

@jianguoz

jianguoz commented Oct 4, 2022

@sherlock42

You can take a look at https://huggingface.co/docs/accelerate/index and especially the examples in https://github.com/huggingface/transformers/tree/main/examples/pytorch. It is pretty simple to run accelerate on TPUs.

@muellerzr
Collaborator

We also have a doc specifically on TPU best practices: https://huggingface.co/docs/accelerate/concept_guides/training_tpu

@jianguoz

jianguoz commented Oct 5, 2022

@muellerzr The message above has received several likes from people interested in training with Accelerate on TPU VMs with more than 8 cores, and I think many more people may want to scale up their training with Accelerate but have not yet noticed this GitHub issue. Do you have any plans to prioritize this request? We could provide feedback on a v3-32 TPU VM. Thanks :)

@muellerzr
Collaborator

@jianguoz in about two weeks or so I'll be able to look at this, as yes this seems to be a quite desired feature 😄 Thanks for your patience everyone!

@muellerzr
Collaborator

Hi @jianguoz et al., I am happy to say we're at a place where you can beta-test the new pod launcher!

Currently it only supports GCP-based TPU pods, as this is what we can test on currently.

Here are the steps to try this out:

(Assume all commands are run solely from the main ssh instance/worker you are working off of, unless specified otherwise)

  1. Either install torch_xla from their latest nightly, or, where torch_xla is installed on the main instance (run pip show torch_xla to find it), put this file in place of torch_xla.distributed.xla_dist (e.g. this could look like: wget https://raw.githubusercontent.com/pytorch/xla/master/torch_xla/distributed/xla_dist.py -O /usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py)
  2. Install accelerate with sudo pip3 install git+https://github.com/huggingface/accelerate@tpu-pod-launch (sudo is needed for the commands to be available; sudo is scary, I know!)
  3. Run accelerate config and answer the new prompts, or modify your existing default_config.yaml (which can be found in .cache/huggingface/accelerate/) to include:
  • tpu_name: SOME_TPU_NAME. This TPU name should match how it is registered in GCP, i.e. what you pass in when calling gcloud compute tpus tpu-vm ssh {SOME_TPU_NAME}.
  • tpu_zone: SOME_TPU_ZONE. This is the zone your TPU pod lives in, such as europe-west4-a
  • tpu_cluster: true. This makes sure you're enabling the cluster launcher
  4. Make sure your pod is configured so the workers can ssh into each other, by running gcloud compute config-ssh. If you can't do this, you may need to log in with gcloud auth login first
  5. Using the new accelerate tpu-config command, download a script you wish to run and store it in /usr/share/. For example, I did: accelerate tpu-config --command "sudo wget https://gist.githubusercontent.com/muellerzr/a85c9692101d47a9264a27fb5478225a/raw/bbdfff6868cbf61fcc0dcff8b76fe64b06fe43ab/xla_script.py" (I'll look into whether there is a better way to upload a file to all of them without using wget, but for now this is what the API is :) ). This command also starts in /usr/share, which is why there is no need to cd. Alternatively, make sure the script you wish to run is available on every pod however you see fit; just make sure the file is there and available!
  6. Accelerate needs to be installed on each pod, so run accelerate tpu-config --command "sudo pip3 install git+https://github.com/huggingface/accelerate@tpu-pod-launch". In the future this will just be accelerate tpu-config --install_accelerate
  7. From there, just run accelerate launch /usr/share/{script_name} and it should launch on the pod for you! E.g. accelerate launch /usr/share/xla_script.py if you are following the above.
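If you edit default_config.yaml by hand instead of answering the prompts, the three extra keys from step 3 are plain key: value lines. The sketch below is a hypothetical helper (render_tpu_pod_keys is our own name) that only renders those lines as text so they can be pasted into the file; it assumes nothing about the rest of the config and does not write to disk. The TPU name and zone used are placeholders.

```python
def render_tpu_pod_keys(tpu_name: str, tpu_zone: str, tpu_cluster: bool = True) -> str:
    """Render the step 3 keys as YAML lines (hypothetical helper, not part of Accelerate)."""
    if not tpu_name or not tpu_zone:
        raise ValueError("tpu_name and tpu_zone must both be set")
    values = {
        # Must match the name used with `gcloud compute tpus tpu-vm ssh {SOME_TPU_NAME}`.
        "tpu_name": tpu_name,
        # Zone the pod lives in, e.g. europe-west4-a.
        "tpu_zone": tpu_zone,
        # YAML boolean that enables the cluster launcher.
        "tpu_cluster": "true" if tpu_cluster else "false",
    }
    return "\n".join(f"{key}: {value}" for key, value in values.items()) + "\n"


if __name__ == "__main__":
    # Placeholder values; substitute your own TPU name and zone.
    print(render_tpu_pod_keys("my-tpu-pod", "europe-west4-a"), end="")
```

Running it prints the three lines tpu_name: my-tpu-pod, tpu_zone: europe-west4-a and tpu_cluster: true.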

Please let me know how this experience works for you, and what feedback you may have on it!

@jianguoz

jianguoz commented Nov 16, 2022 via email

@jianguoz

jianguoz commented Nov 22, 2022

@muellerzr Hi Zachary, sorry for the late reply (we just regained access to TPUs). When I run accelerate tpu-config in Step 5, it returns the error: Failed to execute command on multiple workers. This may have happened if you have not added your SSH key to your ssh-agent using "ssh-add ~/.ssh/google_compute_engine". Meanwhile, accelerate (without @tpu-pod-launch) can connect to and work on v3-32 pods. Do I need any extra settings to use accelerate@tpu-pod-launch? Thanks :)

@muellerzr
Collaborator

muellerzr commented Nov 28, 2022

Hi @jianguoz! Sorry, you responded just as I went on vacation that week. As the directions in the error message state, you should run ssh-add ~/.ssh/google_compute_engine on the machine to get accelerate tpu-config working, since it wraps a separate set of GCP commands.

Though I would have thought gcloud compute config-ssh would have enabled this.

@jianguoz

@muellerzr Thanks for your instructions! We have tried the above steps, and most commands, such as python3 -m torch_xla.distributed.xla_dist, work across the v3-32 pod. Only accelerate tpu-config does not work: it shows connection errors to different IPs and fails to execute on multiple workers. Could you go through the steps again to check whether something is missing? Thanks :)

@muellerzr
Collaborator

Hey @jianguoz, glad to hear accelerate launch is doing its job, setting things up correctly and starting training! I'll look into accelerate tpu-config tomorrow and see if I missed a step I need to write down or make clearer. Thanks for the feedback!! :)

@muellerzr
Collaborator

muellerzr commented Dec 14, 2022

Hi @jianguoz, apologies that this has taken a month to get to; I understand that can be quite frustrating 😢 I was indeed able to recreate your issue on a new TPU pod, so let's try some different instructions that worked for me:

  1. For simplicity, put the following code into a file called startup.sh:
#!/bin/bash
wget https://raw.githubusercontent.com/pytorch/xla/master/torch_xla/distributed/xla_dist.py -O /usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py
sudo pip3 install git+https://github.com/huggingface/accelerate@tpu-pod-launch
  2. Set up a new instance with a command such as the following:
gcloud compute tpus tpu-vm create {{INSERT_NAME_HERE}} --zone {{INSERT_ZONE_HERE}} --version tpu-vm-pt-1.12 --project={{INSERT_PROJECT_HERE}} --accelerator-type v3-32 --metadata startup-script="$(cat startup.sh)"
  3. Afterwards, ssh into the machine and run gcloud auth application-default login. This should solve the SSH issues you were seeing earlier
  4. Next, run accelerate config as before
  5. Run accelerate tpu-config --command "sudo wget https://gist.githubusercontent.com/muellerzr/a85c9692101d47a9264a27fb5478225a/raw/5e120fae8290f30bf1eeca086286cf5e40c66bd8/xla_script.py"
  6. Run accelerate launch /usr/share/xla_script.py
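Steps 1 and 2 can also be driven from Python if you create pods repeatedly. The sketch below only builds the argument list for the step 2 create command, passing the startup.sh contents the way "$(cat startup.sh)" does, and prints it with shlex.join (Python 3.8+); it does not run anything. create_pod_argv and the name, zone and project values are hypothetical placeholders, and only the gcloud flags already used in step 2 appear.

```python
import shlex

# Contents of startup.sh from step 1.
STARTUP_SH = """#!/bin/bash
wget https://raw.githubusercontent.com/pytorch/xla/master/torch_xla/distributed/xla_dist.py -O /usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py
sudo pip3 install git+https://github.com/huggingface/accelerate@tpu-pod-launch
"""


def create_pod_argv(name, zone, project, startup_script,
                    accelerator_type="v3-32", version="tpu-vm-pt-1.12"):
    """Build (but do not run) the step 2 `gcloud compute tpus tpu-vm create` call."""
    return [
        "gcloud", "compute", "tpus", "tpu-vm", "create", name,
        "--zone", zone,
        "--version", version,
        f"--project={project}",
        "--accelerator-type", accelerator_type,
        # Like --metadata startup-script="$(cat startup.sh)"; $(...) drops trailing newlines.
        "--metadata", "startup-script=" + startup_script.rstrip("\n"),
    ]


if __name__ == "__main__":
    # Placeholder name/zone/project values; substitute your own.
    argv = create_pod_argv("my-tpu-pod", "europe-west4-a", "my-project", STARTUP_SH)
    print(shlex.join(argv))
```

Keeping the command as a list means it could later be handed to subprocess.run without a local shell re-parsing the multi-line startup script.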

Can you confirm if this works for you and you can recreate it? I'm also looking into a better way to send the script across the pods today, since each accelerate launch needs access to the training script.

Thanks for your patience 🙏

If you don't want to set up a new instance and want to follow the old directions, replace gcloud compute config-ssh with the new command from step 3

@jianguoz

jianguoz commented Dec 16, 2022

Hi @muellerzr, thanks for your detailed instructions. Whether we create a new pod following the above process or start from Step 3, Step 5 still does not work, as it cannot connect to the pod workers; i.e., Step 3 does not help. Could you try it again to see if there are any potential issues? Thanks so much :). If you have time, we could also schedule a quick meeting to accelerate the process.

@muellerzr
Collaborator

muellerzr commented Dec 16, 2022

@jianguoz I did it from a fresh instance when I posted those instructions and did not face any issues. Is it still that "fail to execute" issue?

@muellerzr
Collaborator

Hi @jianguoz thanks for the wonderful debugging session!

Let's try running through this all again, please follow these steps:

  1. For simplicity, put the following code into a file called startup.sh:
#!/bin/bash
wget https://raw.githubusercontent.com/pytorch/xla/master/torch_xla/distributed/xla_dist.py -O /usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py
sudo pip3 install git+https://github.com/huggingface/accelerate@tpu-pod-launch
  2. Set up a new instance with a command such as the following:
gcloud compute tpus tpu-vm create {{INSERT_NAME_HERE}} --zone {{INSERT_ZONE_HERE}} --version tpu-vm-pt-1.12 --project={{INSERT_PROJECT_HERE}} --accelerator-type v3-32 --metadata startup-script="$(cat startup.sh)"
  3. Afterwards, ssh into the machine and run gcloud auth application-default login. This should solve the SSH issues you were seeing earlier
  4. Next, run accelerate config, and this time, when prompted whether commands on the TPU should be run with sudo, select yes
  5. Run gcloud alpha compute tpus tpu-vm ssh {TPU_NAME} --zone {TPU_ZONE} --command "sudo wget https://gist.githubusercontent.com/muellerzr/a85c9692101d47a9264a27fb5478225a/raw/5e120fae8290f30bf1eeca086286cf5e40c66bd8/xla_script.py -O {WHERE YOU WOULD LIKE TO SAVE THE FILE}/xla_script.py" --worker all
  6. Run sudo accelerate launch {WHERE YOU WOULD LIKE TO SAVE THE FILE}/xla_script.py
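Step 5 uses --worker all so that the same command runs on every host in the pod. As a minimal sketch, assuming gcloud (with its alpha component) is installed and authenticated, the hypothetical helper below builds that call for any remote command and only executes it when dry_run=False. The helper name and the pod name/zone values are placeholders; the flags are the ones from step 5.

```python
import shlex
import subprocess


def ssh_all_workers(tpu_name: str, tpu_zone: str, command: str, dry_run: bool = True) -> str:
    """Run `command` on every worker of a TPU pod via gcloud (hypothetical helper).

    With dry_run=True (the default) nothing is executed; the shell-quoted
    command line is returned so it can be inspected or copy-pasted.
    """
    argv = [
        "gcloud", "alpha", "compute", "tpus", "tpu-vm", "ssh", tpu_name,
        "--zone", tpu_zone,
        "--command", command,
        "--worker", "all",  # fan the command out to every host in the pod
    ]
    if not dry_run:
        # check=True raises CalledProcessError if gcloud exits with a non-zero status.
        subprocess.run(argv, check=True)
    return shlex.join(argv)


if __name__ == "__main__":
    print(ssh_all_workers(
        "my-tpu-pod", "europe-west4-a",
        "sudo wget https://gist.githubusercontent.com/muellerzr/a85c9692101d47a9264a27fb5478225a/raw/5e120fae8290f30bf1eeca086286cf5e40c66bd8/xla_script.py -O /usr/share/xla_script.py",
    ))
```

Passing argv as a list to subprocess.run avoids a second round of local shell quoting around the remote command.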

Let me know if that works for you or if you face trouble!

@jianguoz

jianguoz commented Dec 19, 2022

Hi @muellerzr, thanks for the debugging session! When I run Step 6 of the new instructions, I get the errors below:

 File "/usr/local/bin/accelerate", line 8, in <module>
    sys.exit(main())
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/accelerate_cli.py", line 45, in main
    args.func(args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 1182, in launch_command
    tpu_pod_launcher(args)
  File "/usr/local/lib/python3.8/dist-packages/accelerate/commands/launch.py", line 884, in tpu_pod_launcher
    new_args.positional = new_args
AttributeError: 'list' object has no attribute 'positional'

Did you forget to modify the launch file accordingly, or did I miss something? Thanks :)

@muellerzr
Collaborator

muellerzr commented Dec 19, 2022

Thanks @jianguoz, please try again by downloading the latest commit I pushed (just wget that launch.py file like we did before).

Turned out I had forgotten one more thing: name duplication 😆
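The AttributeError above ('list' object has no attribute 'positional') is the typical symptom of name duplication: a variable that should hold an argparse.Namespace gets rebound to a list before an attribute is set on it. The snippet below is a hypothetical, self-contained reproduction of that pattern and its fix; buggy_prepare, fixed_prepare and training_script are illustrative names, and this is not the actual code from launch.py.

```python
import argparse
import copy


def buggy_prepare(args: argparse.Namespace, script_args: list) -> argparse.Namespace:
    new_args = copy.deepcopy(args)
    # Bug: the same name is reused for the list of CLI tokens...
    new_args = [args.training_script, *script_args]
    # ...so this line tries to set an attribute on a list -> AttributeError.
    new_args.positional = new_args
    return new_args


def fixed_prepare(args: argparse.Namespace, script_args: list) -> argparse.Namespace:
    new_args = copy.deepcopy(args)
    positional = [args.training_script, *script_args]  # distinct name, no shadowing
    new_args.positional = positional
    return new_args


if __name__ == "__main__":
    ns = argparse.Namespace(training_script="/usr/share/xla_script.py")
    print(fixed_prepare(ns, ["--seed", "42"]).positional)
    try:
        buggy_prepare(ns, [])
    except AttributeError as err:
        print(err)  # 'list' object has no attribute 'positional'
```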

@jianguoz

jianguoz commented Dec 19, 2022

Thanks @muellerzr! It raises another issue.
Below is the output on the original pod (I only replaced launch.py; the other files are okay):

2022-12-19 23:26:54 172.16.96.185 [2] Traceback (most recent call last):
2022-12-19 23:26:54 172.16.96.185 [2]   File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 330, in _mp_start_fn
2022-12-19 23:26:54 172.16.96.185 [2]     _start_fn(index, pf_cfg, fn, args)
2022-12-19 23:26:54 172.16.96.185 [2]   File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_multiprocessing.py", line 324, in _start_fn
2022-12-19 23:26:54 172.16.96.185 [2]     fn(gindex, *args)
2022-12-19 23:26:54 172.16.96.185 [2]   File "/usr/local/lib/python3.8/dist-packages/accelerate/utils/launch.py", line 115, in __call__
2022-12-19 23:26:54 172.16.96.185 [2]     self.launcher(*args)
2022-12-19 23:26:54 172.16.96.185 [2]   File "/export/home/xla_script.py", line 8, in main
2022-12-19 23:26:54 172.16.96.185 [2]     xm.rendezvous("checking_out")
2022-12-19 23:26:54 172.16.96.185 [2]   File "/usr/local/lib/python3.8/dist-packages/torch_xla/core/xla_model.py", line 1058, in rendezvous
2022-12-19 23:26:54 172.16.96.185 [2]     return torch_xla._XLAC._xla_rendezvous(get_ordinal(), tag, payload, replicas)
2022-12-19 23:26:54 172.16.96.185 [2] RuntimeError: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'checking_out': Duplicate ordinal: 19 (3)
2022-12-19 23:26:54 172.16.96.185 [2] Exception in device=TPU:17: tensorflow/compiler/xla/xla_client/mesh_service.cc:377 : Failed to meet rendezvous 'checking_out': Duplicate ordinal: 19 (3)
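xm.rendezvous requires every participating process to report a distinct ordinal, so "Duplicate ordinal: 19" means two processes claimed the same global rank. One way this can happen (an assumption for illustration, not a diagnosis of this pod) is when two hosts are launched with the same host index, so host_index * 8 + local_ordinal collides. The pure-Python check below is a hypothetical sketch of that arithmetic, assuming 8 cores per v3 host; find_duplicate_ordinals is our own name.

```python
from collections import Counter

CORES_PER_HOST = 8  # assumption: v3 TPU VM hosts drive 8 cores each


def find_duplicate_ordinals(host_indices, cores_per_host=CORES_PER_HOST):
    """Return global ordinals claimed by more than one process.

    host_indices[i] is the host index that worker VM i believes it has.
    """
    counts = Counter(
        host * cores_per_host + local
        for host in host_indices
        for local in range(cores_per_host)
    )
    return sorted(ordinal for ordinal, n in counts.items() if n > 1)


if __name__ == "__main__":
    print(find_duplicate_ordinals([0, 1, 2, 3]))  # [] -> all 32 ordinals unique
    print(find_duplicate_ordinals([0, 1, 2, 2]))  # ordinals 16..23 collide
```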

Below is the output on the new pod:

2022-12-19 23:18:54 172.16.80.51 [3] Terminated
2022-12-19 23:18:54 172.16.80.48 [2] Terminated
2022-12-19 23:18:54 172.16.80.49 [1] Terminated
Process Process-1:
Traceback (most recent call last):
  File "/usr/lib/python3.8/multiprocessing/process.py", line 315, in _bootstrap
    self.run()
  File "/usr/lib/python3.8/multiprocessing/process.py", line 108, in run
    self._target(*self._args, **self._kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py", line 608, in _run_cmd
    self._start_run(script_map)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py", line 602, in _start_run
    xu.parallel_work(
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/utils/utils.py", line 293, in parallel_work
    return [res for res in results]  # Iterating to re-raise any exceptions
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/utils/utils.py", line 293, in <listcomp>
    return [res for res in results]  # Iterating to re-raise any exceptions
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 619, in result_iterator
    yield fs.pop().result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 444, in result
    return self.__get_result()
  File "/usr/lib/python3.8/concurrent/futures/_base.py", line 389, in __get_result
    raise self._exception
  File "/usr/lib/python3.8/concurrent/futures/thread.py", line 57, in run
    result = self.fn(*self.args, **self.kwargs)
  File "/usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py", line 588, in _run_script
    raise RuntimeError(
RuntimeError: Remote command exitted with code: 143

I tested the above commands and found that step 1 is okay. However, after I replaced launch.py, even my previously working pod (the one we debugged together this morning) does not work with Step 9. Could you check that file again?
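On the new pod, "Remote command exitted with code: 143" together with the "Terminated" lines is consistent with the remote processes being stopped by SIGTERM: POSIX shells report death by signal N as exit status 128 + N, and SIGTERM is 15. A small sketch for decoding such statuses, relying on that convention and Python's signal module (describe_exit_code is a hypothetical name):

```python
import signal


def describe_exit_code(code: int) -> str:
    """Interpret a shell exit status using the POSIX 128 + signal convention."""
    if code > 128:
        try:
            return f"killed by {signal.Signals(code - 128).name}"
        except ValueError:
            pass  # not a known signal number on this platform
    return f"exited with status {code}"


print(describe_exit_code(143))  # killed by SIGTERM
print(describe_exit_code(1))    # exited with status 1
```

This only says how the process ended, not why; the SIGTERM may have been sent by the launcher after another worker failed.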

@muellerzr
Collaborator

Thanks @jianguoz, I'll take a peek tomorrow and see if I can solve it!

@muellerzr
Collaborator

This has now been introduced in #1049. Please follow the new accelerate config command to set this up. Below are some directions:

  1. Install accelerate via pip install git+https://github.com/huggingface/accelerate (and ensure each node has it installed as well)
  2. Very important: Either torch_xla needs to be installed from git, or you need to run wget https://raw.githubusercontent.com/pytorch/xla/master/torch_xla/distributed/xla_dist.py -O /usr/local/lib/python3.8/dist-packages/torch_xla/distributed/xla_dist.py. I believe doing this on the host node only should be all that is needed; if not, use the tpu-config option or add it to the startup command (we rely on that refactor of xla_dist to launch)
  3. Run accelerate config on the host node and configure it accordingly
  4. Depending on how the system is set up, you may need to use sudo pip install. If so, answer True to the corresponding prompt in accelerate config, and run it as sudo accelerate config. (I hit some permissions issues; this has been my workaround for now)
  5. Download the script you wish to run to /usr/share/some_script.py
  6. Run accelerate launch /usr/share/some_script.py

The example script I use is located here:
https://gist.githubusercontent.com/muellerzr/a85c9692101d47a9264a27fb5478225a/raw/bbdfff6868cbf61fcc0dcff8b76fe64b06fe43ab/xla_script.py

We have also introduced a tpu-config command, which runs commands across the pods, so instead of using a startup script to install everything, you could run:
accelerate tpu-config --command "sudo wget https://gist.githubusercontent.com/muellerzr/a85c9692101d47a9264a27fb5478225a/raw/bbdfff6868cbf61fcc0dcff8b76fe64b06fe43ab/xla_script.py -O /usr/share/xla_script.py"
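The tpu-config command takes the remote command as a single quoted string via --command, so quoting is the main thing to get right when scripting it. A minimal sketch, assuming accelerate is installed and on PATH, that builds the same download call as above from its parts and only runs it when asked; tpu_config_argv and download_script_everywhere are hypothetical helper names, and only the --command flag shown in this thread is used.

```python
import shlex
import subprocess

SCRIPT_URL = "https://gist.githubusercontent.com/muellerzr/a85c9692101d47a9264a27fb5478225a/raw/bbdfff6868cbf61fcc0dcff8b76fe64b06fe43ab/xla_script.py"


def tpu_config_argv(remote_command: str) -> list:
    """argv for `accelerate tpu-config --command ...` (hypothetical helper)."""
    return ["accelerate", "tpu-config", "--command", remote_command]


def download_script_everywhere(url: str, dest: str, run: bool = False) -> str:
    # Same wget invocation as above; quote each part for the remote shell.
    remote = f"sudo wget {shlex.quote(url)} -O {shlex.quote(dest)}"
    argv = tpu_config_argv(remote)
    if run:
        # No local shell is involved, so the list needs no extra quoting.
        subprocess.run(argv, check=True)
    return shlex.join(argv)  # copy-pasteable form for a terminal


if __name__ == "__main__":
    print(download_script_everywhere(SCRIPT_URL, "/usr/share/xla_script.py"))
```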

I did not notice your issue @jianguoz, so do let me know if it is still present after this
