update README to include new features and remove outdated msg
ghstack-source-id: 99bca8c74c7d2fc2566f20914745790abba3ffba
Pull Request resolved: pytorch#574
tianyu-l committed Sep 13, 2024
1 parent 2a25f4d commit d2a4904
Showing 1 changed file with 15 additions and 19 deletions.
34 changes: 15 additions & 19 deletions README.md
@@ -3,7 +3,7 @@

# torchtitan

`torchtitan` is currently in a pre-release state and under extensive development.
`torchtitan` is currently in a pre-release state and under extensive development. Currently we showcase pre-training **Llama 3.1**, **Llama 3**, and **Llama 2** LLMs of various sizes from scratch. To use the latest features of `torchtitan`, we recommend using the latest PyTorch nightly.

`torchtitan` is a proof-of-concept for Large-scale LLM training using native PyTorch. It is (and will continue to be) a repo to showcase PyTorch's latest distributed training features in a clean, minimal codebase. torchtitan is complementary to and not a replacement for any of the great large-scale LLM training codebases such as Megatron, Megablocks, LLM Foundry, Deepspeed, etc. Instead, we hope that the features showcased in torchtitan will be adopted by these codebases quickly. torchtitan is unlikely to ever grow a large community around it.

@@ -26,34 +26,30 @@ You may want to see how the model is defined or how parallelism techniques are applied
* [torchtitan/parallelisms/pipeline_llama.py](torchtitan/parallelisms/pipeline_llama.py) - helpers for applying Pipeline Parallel to the model
* [torchtitan/checkpoint.py](torchtitan/checkpoint.py) - utils for saving/loading distributed checkpoints
* [torchtitan/float8.py](torchtitan/float8.py) - utils for applying Float8 techniques
* [torchtitan/models/llama/model.py](torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama2 and Llama3 variants)

## Pre-Release Updates:
#### (4/25/2024): `torchtitan` is now public but in a pre-release state and under development.
Currently we showcase pre-training **Llama 3 and Llama 2** LLMs of various sizes from scratch. `torchtitan` is tested and verified with the PyTorch nightly version `torch-2.4.0.dev20240412`. (We recommend latest PyTorch nightly).
* [torchtitan/models/llama/model.py](torchtitan/models/llama/model.py) - the Llama model definition (shared for Llama 2 and Llama 3 variants)

### Key features available

1. [FSDP2 with per param sharding](docs/fsdp.md)
2. [Tensor Parallel](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) (including async TP)
1. [FSDP2](docs/fsdp.md) with per param sharding
2. [Tensor Parallel](https://pytorch.org/docs/stable/distributed.tensor.parallel.html) (including [async TP](https://discuss.pytorch.org/t/distributed-w-torchtitan-introducing-async-tensor-parallelism-in-pytorch/209487))
3. Selective layer and operator activation checkpointing
4. Distributed checkpointing (including async checkpointing)
5. Checkpointable data-loading, with the C4 dataset pre-configured (144M entries)
6. Loss, GPU memory, tokens-per-second, and MFU displayed and logged via TensorBoard
7. Learning rate scheduler, meta-init, optional Fused RMSNorm
8. [Float8 support](docs/float8.md)
6. Loss, GPU memory, tokens-per-second, and MFU displayed and logged via [TensorBoard](#tensorboard)
7. Learning rate scheduler, meta-init, optional Fused RMSNorm
8. [Float8](https://discuss.pytorch.org/t/distributed-w-torchtitan-enabling-float8-all-gather-in-fsdp2/209323) support ([how-to](docs/float8.md))
9. `torch.compile` support
10. All options easily configured via [toml files](train_configs/)
11. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine tuning
10. DDP and HSDP
11. All options easily configured via [toml files](train_configs/)
12. [Interoperable checkpoints](docs/checkpoint.md) which can be loaded directly into [`torchtune`](https://github.com/pytorch/torchtune) for fine-tuning
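
To make features 1, 2, and 9 above concrete, here is a minimal sketch (not torchtitan's actual training code) that composes Tensor Parallel with FSDP2-style per-parameter sharding on a toy feed-forward model and then compiles it. Everything in it is illustrative: the toy model, the 2x2 mesh on 4 GPUs, the `torchrun` launch, and the private `torch.distributed._composable.fsdp` import location are assumptions, not torchtitan's code — torchtitan drives the same PyTorch APIs from its toml configs rather than hard-coding them like this.

```python
# A minimal sketch, NOT torchtitan's actual code: compose Tensor Parallel with
# FSDP2-style per-parameter sharding on a toy model, then optionally compile it.
# Assumes 4 GPUs and a launch such as `torchrun --nproc_per_node 4 sketch.py`.
import os

import torch
import torch.nn as nn
from torch.distributed._composable.fsdp import fully_shard  # FSDP2 (per-param sharding)
from torch.distributed.device_mesh import init_device_mesh
from torch.distributed.tensor.parallel import (
    ColwiseParallel,
    RowwiseParallel,
    parallelize_module,
)


class ToyBlock(nn.Module):
    """A tiny feed-forward block standing in for a transformer MLP."""

    def __init__(self, dim: int = 256):
        super().__init__()
        self.w1 = nn.Linear(dim, 4 * dim)
        self.w2 = nn.Linear(4 * dim, dim)

    def forward(self, x):
        return self.w2(torch.relu(self.w1(x)))


torch.cuda.set_device(int(os.environ["LOCAL_RANK"]))

# 2-D mesh: outer dim for data parallel (FSDP2), inner dim for tensor parallel.
mesh = init_device_mesh("cuda", (2, 2), mesh_dim_names=("dp", "tp"))

model = nn.Sequential(*[ToyBlock() for _ in range(4)]).cuda()

for block in model:
    # Megatron-style pairing: column-parallel w1 feeds row-parallel w2.
    parallelize_module(
        block, mesh["tp"], {"w1": ColwiseParallel(), "w2": RowwiseParallel()}
    )
    # Shard each block's parameters over the data-parallel mesh dimension.
    fully_shard(block, mesh=mesh["dp"])
fully_shard(model, mesh=mesh["dp"])  # root call groups the remaining parameters

model = torch.compile(model)  # optional: feature 9 above

out = model(torch.randn(8, 256, device="cuda"))
out.sum().backward()
```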
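
Similarly, for features 4 and 12, the sketch below shows the underlying `torch.distributed.checkpoint` (DCP) calls for saving and restoring a sharded model. It is not `torchtitan/checkpoint.py`; it assumes the `model` and process group from the previous sketch, and the checkpoint path is made up.

```python
# Hedged sketch of distributed checkpointing with DCP; this is not
# torchtitan/checkpoint.py, just the PyTorch API it builds on.
import torch.distributed.checkpoint as dcp
from torch.distributed.checkpoint.state_dict import (
    get_model_state_dict,
    set_model_state_dict,
)

# Each rank writes only its own shards into the checkpoint folder.
state = {"model": get_model_state_dict(model)}
dcp.save(state, checkpoint_id="outputs/checkpoint/step-100")  # hypothetical path

# On resume: load into a state dict of matching structure, then apply it.
to_load = {"model": get_model_state_dict(model)}
dcp.load(to_load, checkpoint_id="outputs/checkpoint/step-100")
set_model_state_dict(model, to_load["model"])
```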

We report our [Performance](docs/performance.md) verified on 64 A100 GPUs
We report our [Performance](docs/performance.md) verified on 64/128 GPUs.


### Coming soon

1. Context Parallel
2. Pipeline Parallel (and 3D parallelism)
3. HSDP
- Pipeline Parallel (and 3D parallelism)
- Context Parallel


## Installation
@@ -74,10 +70,10 @@ Once you have confirmed access, you can run the following command to download the tokenizer
```bash
# Get your HF token from https://huggingface.co/settings/tokens

# llama3 or 3.1 tokenizer.model
# Llama 3 or 3.1 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Meta-Llama-3-8B --tokenizer_path "original" --hf_token=...

# llama2 tokenizer.model
# Llama 2 tokenizer.model
python torchtitan/datasets/download_tokenizer.py --repo_id meta-llama/Llama-2-13b-hf --hf_token=...
```
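
As a quick sanity check that the Llama 2 download worked, the `tokenizer.model` file is a standard SentencePiece model and can be loaded directly. This is a sketch, not part of torchtitan, and the path is a placeholder — substitute the location reported by `download_tokenizer.py`.

```python
# Optional sanity check (a sketch, not part of torchtitan): load the Llama 2
# SentencePiece tokenizer and encode a sample string.
import sentencepiece as spm

# Placeholder path -- use the location printed by download_tokenizer.py.
sp = spm.SentencePieceProcessor(model_file="path/to/tokenizer.model")
print(sp.encode("hello torchtitan", out_type=int))
```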
