
Commit 88b5db9

Update base for Update on "introduce batch sharding strategy"
(Split out the large PR from #46)

Introduce the batch sharding strategy:

```python
import functools

from torch.distributed.tensor._op_schema import RuntimeSchemaInfo
from autoparallel.dtensor_util.utils import batch_shard_strategy
from autoparallel.dtensor_util import strategy_pool

# Create a strategy with input tensor 1 replicated and input tensor 2 sharded
# on dim 0; the output tensor is sharded on dim 0.
custom_shard_strategy = functools.partial(
    batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
)

# Register the strategy for the op.
strategy_pool.register_op_strategy(new_op)(custom_shard_strategy)
```

For details, see the function docstring in autoparallel/dtensor_util/utils.py and the example usage in tests/test_dtensor.py.

[ghstack-poisoned]
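As an additional usage sketch (not part of the commit message): when every input and the output of an op carry the batch on dim 0, the same helper can be registered with all shard dims set to 0. The op name `torch.ops.mylib.batched_add.default` below is hypothetical; only `batch_shard_strategy`, `strategy_pool`, and `register_op_strategy` come from the snippet above.

```python
# A minimal sketch, assuming a hypothetical custom op
# `torch.ops.mylib.batched_add.default` with two inputs and one output,
# all batched on dim 0.
import functools

import torch
from autoparallel.dtensor_util.utils import batch_shard_strategy
from autoparallel.dtensor_util import strategy_pool

# Both inputs and the single output are sharded on dim 0 (the batch dim).
fully_batched_strategy = functools.partial(
    batch_shard_strategy,
    input_shard_dim=[0, 0],
    output_shard_dim=[0],
)

# Register the strategy so DTensor sharding propagation can pick it up
# for the hypothetical op.
strategy_pool.register_op_strategy(torch.ops.mylib.batched_add.default)(
    fully_batched_strategy
)
```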
1 parent 17aaea5 commit 88b5db9

0 files changed (+0, -0 lines changed)

