From 976653e59c87c371d5bcb79f4f5a423e4240e634 Mon Sep 17 00:00:00 2001 From: Francisco Massa Date: Fri, 4 Jul 2025 13:27:17 +0000 Subject: [PATCH] Pass kwargs as well to function This should be merged directly in main --- autoparallel/optimize_sharding.py | 5 ++++- autoparallel/utils.py | 8 ++++---- 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/autoparallel/optimize_sharding.py b/autoparallel/optimize_sharding.py index bac9d6c0..56051360 100644 --- a/autoparallel/optimize_sharding.py +++ b/autoparallel/optimize_sharding.py @@ -63,8 +63,11 @@ def build_sharding_metadata(self): user_args = tree_map_only( torch.fx.Node, lambda x: x.meta["val"], node.args ) + user_kwargs = tree_map_only( + torch.fx.Node, lambda x: x.meta["val"], node.kwargs + ) strat = get_placement_options( - self.mesh, node.target, user_strats, user_args + self.mesh, node.target, user_strats, user_args, user_kwargs ) strats[node] = strat elif node.op == "output": diff --git a/autoparallel/utils.py b/autoparallel/utils.py index 83ef918f..15fca4c0 100644 --- a/autoparallel/utils.py +++ b/autoparallel/utils.py @@ -13,8 +13,8 @@ from .propagation_rules import _op_partial_rules, _op_rules, remove_invalid_configs -def propagate_tensor_meta(op, user_args, out_strat): - out_t = op(*user_args) +def propagate_tensor_meta(op, user_args, user_kwargs, out_strat): + out_t = op(*user_args, **user_kwargs) if isinstance(out_t, torch.Tensor): new_tensor_meta = TensorMeta(out_t.shape, out_t.stride(), out_t.dtype) @@ -85,7 +85,7 @@ def fill_missing_redistribute_cost(op, specs, out_strat): strat.redistribute_cost = redistribute_costs -def get_placement_options(mesh, op, specs, user_args): +def get_placement_options(mesh, op, specs, user_args, user_kwargs): # print(op) if op in _op_rules: @@ -118,7 +118,7 @@ def get_placement_options(mesh, op, specs, user_args): op_schema ) - propagate_tensor_meta(op, user_args, out_strat) + propagate_tensor_meta(op, user_args, user_kwargs, out_strat) fill_missing_redistribute_cost(op, specs, out_strat) out_strat = remove_invalid_configs(out_strat, mesh)