Introduce batch sharding strategy
(Split out from the large PR #46.)
Introduce the batch sharding strategy:
```python
import functools

from autoparallel.dtensor_util import strategy_pool
from autoparallel.dtensor_util.utils import batch_shard_strategy

# Create a strategy that replicates input tensor 1 (None) and shards input
# tensor 2 on dim 0; the output tensor is sharded on dim 0.
custom_shard_strategy = functools.partial(
    batch_shard_strategy, input_shard_dim=[None, 0], output_shard_dim=[0]
)

# Register the strategy for the target op (`new_op` is the op overload to cover):
strategy_pool.register_op_strategy(new_op)(custom_shard_strategy)
```
For details, see the function docstring in autoparallel/dtensor_util/utils.py and the example usage in tests/test_dtensor.py.
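As a concrete illustration, here is a minimal sketch that registers a batch sharding strategy for a real op; the choice of `torch.ops.aten.bmm.default` and the shard dims are assumptions for illustration, not part of this PR:

```python
import functools

import torch
from autoparallel.dtensor_util import strategy_pool
from autoparallel.dtensor_util.utils import batch_shard_strategy

# Hypothetical example: shard both bmm inputs and the output on the batch
# dim (0). The op and dims here are illustrative assumptions; see
# tests/test_dtensor.py for the usage actually exercised by this PR.
bmm_batch_strategy = functools.partial(
    batch_shard_strategy, input_shard_dim=[0, 0], output_shard_dim=[0]
)
strategy_pool.register_op_strategy(torch.ops.aten.bmm.default)(bmm_batch_strategy)
```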