Update README #196

Closed · wants to merge 2 commits
14 changes: 8 additions & 6 deletions README.md
@@ -28,6 +28,8 @@
</div>

## What's New
- 2024/02/26 **BMTrain** [1.0.0](https://github.com/OpenBMB/BMTrain/releases/tag/1.0.0) released. Code refactoring and tensor parallel support. See the details in the [update log](docs/UPDATE_1.0.0.md).
- 2023/08/17 **BMTrain** [0.2.3](https://github.com/OpenBMB/BMTrain/releases/tag/0.2.3) released. See the [update log](docs/UPDATE_0.2.3.md).
- 2022/12/15 **BMTrain** [0.2.0](https://github.com/OpenBMB/BMTrain/releases/tag/0.2.0) released. See the [update log](docs/UPDATE_0.2.0.md).
- 2022/06/14 **BMTrain** [0.1.7](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.7) released. ZeRO-2 optimization is supported!
- 2022/03/30 **BMTrain** [0.1.2](https://github.com/OpenBMB/BMTrain/releases/tag/0.1.2) released. Adapted to [OpenPrompt](https://github.com/thunlp/OpenPrompt) and [OpenDelta](https://github.com/thunlp/OpenDelta).
@@ -51,7 +53,7 @@ Our [documentation](https://bmtrain.readthedocs.io/en/latest/index.html) provide

- From pip (recommended): ``pip install bmtrain``

- From source code: download the package and run ``python setup.py install``
- From source code: download the package and run ``pip install .``

Installing BMTrain may take from a few minutes up to ten minutes, as it compiles C/CUDA source code during installation.
We recommend compiling BMTrain directly in the training environment to avoid potential problems caused by environment differences.
@@ -68,7 +70,6 @@ Before you can use BMTrain, you need to initialize it at the beginning of your c
import bmtrain as bmt
bmt.init_distributed(
    seed=0,
    zero_level=3, # support 2 and 3 now
    # ...
)
```
@@ -118,9 +119,9 @@ class MyModule(bmt.DistributedModule): # changed here
        super().__init__()
        self.param = bmt.DistributedParameter(torch.empty(1024)) # changed here
        self.module_list = torch.nn.ModuleList([
            bmt.Block(SomeTransformerBlock()), # changed here
            bmt.Block(SomeTransformerBlock()), # changed here
            bmt.Block(SomeTransformerBlock()) # changed here
            bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now
            bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now
            bmt.Block(SomeTransformerBlock(), zero_level=3) # changed here, support 2 and 3 now
        ])

    def forward(self):
@@ -181,7 +182,8 @@ class MyModule(bmt.DistributedModule):

    def forward(self):
        x = self.param
        x = self.module_list(x, 1, 2, 3) # changed here
        for module in self.module_list:
            x = module(x, 1, 2, 3)
        return x

```
26 changes: 26 additions & 0 deletions docs/UPDATE_0.2.3.md
@@ -0,0 +1,26 @@
# Update Log 0.2.3

**Full Changelog**: https://github.com/OpenBMB/BMTrain/compare/0.2.0...0.2.3


## What's New

### 1. Get rid of the torch C++ extension when compiling

Before 0.2.3, installing BMTrain required the torch C++ extension, which was not friendly to some users (it requires a CUDA runtime that matches torch). Since 0.2.3, BMTrain no longer needs the torch C++ extension at compile time, which makes installing BMTrain from source more convenient.
Just run `pip install .` to install BMTrain from source.

### 2. CI/CD

In 0.2.3, we bring GitHub Actions CI/CD to BMTrain. The pipeline runs on GitHub to ensure code quality: it runs the test cases and compiles the source code into wheel packages.

### 3. Loss scale management

In 0.2.3, we add minimum and maximum bounds to the loss scale manager. The manager still adjusts the loss scale dynamically, but now clamps it between the configured min and max values, which helps users avoid the loss scale becoming too large or too small.
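
As a rough illustration of the idea (a generic dynamic loss-scaling sketch with min/max clamping, not BMTrain's internal implementation; all names here are illustrative):

```python
class DynamicLossScaler:
    """Sketch of dynamic loss scaling with configurable min/max bounds."""
    def __init__(self, loss_scale=2.0**20, factor=2.0, growth_interval=1024,
                 min_loss_scale=1.0, max_loss_scale=2.0**24):
        self.loss_scale = loss_scale
        self.factor = factor                    # multiply/divide step on adjustment
        self.growth_interval = growth_interval  # grow after this many overflow-free steps
        self.min_loss_scale = min_loss_scale
        self.max_loss_scale = max_loss_scale
        self._good_steps = 0

    def update(self, found_overflow: bool):
        if found_overflow:
            # Shrink on overflow, but never below the configured minimum.
            self.loss_scale = max(self.loss_scale / self.factor, self.min_loss_scale)
            self._good_steps = 0
        else:
            self._good_steps += 1
            if self._good_steps >= self.growth_interval:
                # Grow after a stretch of stable steps, but never above the maximum.
                self.loss_scale = min(self.loss_scale * self.factor, self.max_loss_scale)
                self._good_steps = 0
```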


### 4. Others

* Fix `bmt.load(model)` OOM with torch >= 1.12
* `AdamOffloadOptimizer` can choose the AVX flag automatically at runtime
* BMTrain is now fully compatible with torch 2.0
72 changes: 72 additions & 0 deletions docs/UPDATE_1.0.0.md
@@ -0,0 +1,72 @@
# Update Log 1.0.0

**Full Changelog**: https://github.com/OpenBMB/BMTrain/compare/0.2.3...1.0.0

## What's New

### 1. Using PyTorch's hook mechanism to refactor the ZeRO, checkpointing, pipeline, and communication implementation

Now users can specify the ZeRO level of each `bmt.CheckpointBlock`.

**======= Before 1.0.0 =======**

```python
import bmtrain as bmt
bmt.init_distributed(zero_level=3)

```

The ZeRO level can only be set globally, and computation checkpointing cannot be disabled.
A `bmt.TransformerBlockList` has to be called through its block list forward method rather than iterated in a loop.

**======= After 1.0.0 =======**

```python
import bmtrain as bmt
bmt.init_distributed()

# construct blocks
class Transformer(bmt.DistributedModule):
    def __init__(self, num_layers: int) -> None:
        super().__init__()

        self.transformers = bmt.TransformerBlockList([
            bmt.Block(
                TransformerEncoder(
                    dim_model, dim_head, num_heads, dim_ff, bias, dtype
                ), use_checkpoint=True, zero_level=3
            )
            for _ in range(num_layers)
        ])

    def forward(self, x):
        # return self.transformers(x)  # in v0.2.3, only this blocklist forward was possible
        for block in self.transformers:
            x = block(x)
        return x

```

You can specify the ZeRO level of each `bmt.CheckpointBlock` (`bmt.Block` is an alias), and computation checkpointing can be disabled by setting `use_checkpoint=False`. A `bmt.TransformerBlockList` can now be iterated in a loop.


### 2. Add BF16 support

Now BMTrain supports BF16 training. Simply pass `dtype=torch.bfloat16` in your model construction and BMTrain will handle the rest.
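
For instance, a minimal sketch (the `Linear` module below is illustrative, not part of BMTrain; only the `dtype=torch.bfloat16` argument comes from this release note):

```python
import torch
import torch.nn.functional as F
import bmtrain as bmt

class Linear(bmt.DistributedModule):
    # An illustrative BF16 layer: parameters are simply created with dtype=torch.bfloat16.
    # Inputs passed to forward are expected in the same dtype.
    def __init__(self, dim_in: int, dim_out: int, dtype=torch.bfloat16):
        super().__init__()
        self.weight = bmt.DistributedParameter(torch.empty(dim_out, dim_in, dtype=dtype))
        self.bias = bmt.DistributedParameter(torch.zeros(dim_out, dtype=dtype))

    def forward(self, x):
        return F.linear(x, self.weight, self.bias)
```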

### 3. Tensor parallel implementation

For this part, BMTrain provides a series of parallel ops for tensor parallel implementations, including `bmt.nn.OpParallelLinear` and `bmt.nn.VPEmbedding`. We also provide a tensor parallel training example among our training examples. You can simply use `bmt.init_distributed(tp_size=4)` to enable 4-way tensor parallel training.
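
To make the partitioning idea concrete, here is a conceptual sketch of a column-parallel linear layer in plain PyTorch (this is not BMTrain's `OpParallelLinear` API; it only illustrates the technique):

```python
import torch
import torch.distributed as dist

def column_parallel_linear(x, w_local, tp_size, group=None):
    """Each rank holds a [out_features // tp_size, in_features] slice of the weight,
    computes its slice of the output columns, and the slices are all-gathered.
    Assumes torch.distributed has already been initialized."""
    y_local = torch.nn.functional.linear(x, w_local)        # partial output columns
    gathered = [torch.empty_like(y_local) for _ in range(tp_size)]
    dist.all_gather(gathered, y_local, group=group)          # one tensor per rank, in rank order
    return torch.cat(gathered, dim=-1)                       # full output
```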

### 4. `AdamOffloadOptimizer` can save whole gathered state

Now `AdamOffloadOptimizer` can save the whole gathered optimizer state, which can be used to resume training from the saved checkpoint. For better performance, the state dict can be saved asynchronously to overlap I/O with computation.
```python
import bmtrain as bmt
# You can enable this feature through either the OptimManager's or the optimizer's interface:
global_ckpt = optim_manager.state_dict(gather_opt=True)  # optim_manager: a bmt.optim.OptimManager instance
global_ckpt = optimizer.state_dict(gather=True)
```
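
As a generic illustration of overlapping checkpoint I/O with computation (a background-thread sketch, not BMTrain's actual async saving interface; `save_state_dict_async` and its helper are illustrative names):

```python
import threading
import torch

def _to_cpu(obj):
    # Recursively snapshot tensors to CPU so later parameter updates don't race with the write.
    if torch.is_tensor(obj):
        return obj.detach().cpu().clone()
    if isinstance(obj, dict):
        return {k: _to_cpu(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(_to_cpu(v) for v in obj)
    return obj

def save_state_dict_async(state_dict, path):
    snapshot = _to_cpu(state_dict)
    thread = threading.Thread(target=torch.save, args=(snapshot, path))
    thread.start()
    return thread  # join() before the next save or at program exit
```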
### 5. Others

* New tests for the new version of BMTrain