diff --git a/README.md b/README.md index 73e4371..afe6c2d 100644 --- a/README.md +++ b/README.md @@ -70,7 +70,6 @@ Before you can use BMTrain, you need to initialize it at the beginning of your c import bmtrain as bmt bmt.init_distributed( seed=0, - zero_level=3, # support 2 and 3 now # ... ) ``` @@ -120,9 +119,9 @@ class MyModule(bmt.DistributedModule): # changed here super().__init__() self.param = bmt.DistributedParameter(torch.empty(1024)) # changed here self.module_list = torch.nn.ModuleList([ - bmt.Block(SomeTransformerBlock()), # changed here - bmt.Block(SomeTransformerBlock()), # changed here - bmt.Block(SomeTransformerBlock()) # changed here + bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now + bmt.Block(SomeTransformerBlock(), zero_level=3), # changed here, support 2 and 3 now + bmt.Block(SomeTransformerBlock(), zero_level=3) # changed here, support 2 and 3 now ]) def forward(self):