Skip to content

Commit

Permalink
Fix parsing of pp split points from toml file
Browse files Browse the repository at this point in the history
Previously, string_list formatting was only applied to pp split points
arg when it came from the cmd line.  The string_list formatting needs to
be applied when loading via toml too.

ghstack-source-id: 1adf6716382884efbdeedae8601a2f210c4fb860
Pull Request resolved: #450
  • Loading branch information
wconstab committed Oct 30, 2024
1 parent dbb0520 commit 53d0f69
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions torchtitan/config_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,19 @@ def parse_args(self, args_list: list = sys.argv[1:]):
logger.exception(f"Error details: {str(e)}")
raise e

# if split-points came from 'args' (from cmd line) it would have already been parsed into a list by that parser
if (
"experimental" in args_dict
and "pipeline_parallel_split_points" in args_dict["experimental"]
and isinstance(
args_dict["experimental"]["pipeline_parallel_split_points"], str
)
):
exp = args_dict["experimental"]
exp["pipeline_parallel_split_points"] = string_list(
exp["pipeline_parallel_split_points"]
)

# override args dict with cmd_args
cmd_args_dict = self._args_to_two_level_dict(cmd_args)
for section, section_args in cmd_args_dict.items():
Expand Down

0 comments on commit 53d0f69

Please sign in to comment.