Skip to content

Commit b51cd9a

Browse files
authored
Merge pull request #6 from pytorch-labs/fmassa/fix_compute_cost
Fix computation cost to take number of arguments into account
2 parents 56bd39c + f5632cc commit b51cd9a

File tree

1 file changed

+1
-1
lines changed

1 file changed

+1
-1
lines changed

autoparallel/optimize_sharding.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -100,7 +100,7 @@ def build_ds(self):
100100
)
101101
ds[(s_i, argi, ss, ii)] = {
102102
"va": va,
103-
"cost": comm_cost + compute_cost,
103+
"cost": comm_cost + compute_cost / num_args[s_i],
104104
"full_strat": ssi,
105105
"out_strat": ssi.output_specs,
106106
"inp_strat": ssi.input_specs[argi],

0 commit comments

Comments
 (0)