Commit 50dcaac

Update base for Update on "introduce batch sharding strategy"

(Split out the large PR from #46)

Introduce the batch sharding strategy:

```python
import functools

from torch.distributed.tensor._op_schema import RuntimeSchemaInfo
from autoparallel.dtensor_util.utils import batch_shard_strategy
from autoparallel.dtensor_util import strategy_pool

# Create a strategy with input tensor 1 replicated and input tensor 2
# sharded on dim 0; the output tensor is sharded on dim 0.
custom_shard_strategy = functools.partial(
    batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
)

# Register the strategy:
strategy_pool.register_op_strategy(new_op)(custom_shard_strategy)
```

For details, see the function description in autoparallel/dtensor_util/utils.py and the example usage in tests/test_dtensor.py.

[ghstack-poisoned]
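The registration pattern above relies on `functools.partial` to pre-bind the sharding dims before handing the strategy to the pool. A minimal standalone sketch of that pattern, using a hypothetical stand-in for `batch_shard_strategy` (not the real autoparallel implementation):

```python
import functools

# Hypothetical stand-in for batch_shard_strategy: it receives the op plus
# the sharding configuration as keyword arguments and returns a plain
# description of how each tensor is placed.
def batch_shard_strategy(op_schema, input_shard_dim=None, output_shard_dim=None):
    return {
        "op": op_schema,
        # None means the input stays replicated; an int is the shard dim.
        "inputs": ["replicate" if d is None else f"shard(dim={d})"
                   for d in input_shard_dim],
        "outputs": [f"shard(dim={d})" for d in output_shard_dim],
    }

# Pre-bind the dims exactly as in the commit message: input 1 replicated,
# input 2 sharded on dim 0, output sharded on dim 0.
custom_shard_strategy = functools.partial(
    batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
)

# The resulting callable now only needs the op, which is what a
# register_op_strategy-style hook would pass in.
print(custom_shard_strategy("my_op"))
```

The pre-bound callable has the single-argument signature a strategy registry expects, while the sharding configuration stays fixed at registration time.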
2 parents 65250eb + 22e663f commit 50dcaac

13 files changed: +1053 −681 lines

.pre-commit-config.yaml

Lines changed: 7 additions & 2 deletions
```diff
@@ -34,7 +34,12 @@ repos:
       additional_dependencies: [toml]
       args: ["--profile", "black"]

-  - repo: https://github.com/pre-commit/mirrors-mypy
-    rev: 'v1.10.0'
+  - repo: local
     hooks:
       - id: mypy
+        name: mypy
+        entry: mypy
+        language: system
+        types: [python]
+        exclude: (docs|examples)
+        args: ["--ignore-missing-imports", "--scripts-are-modules", "--pretty"]
```
