diff --git a/slime/utils/arguments.py b/slime/utils/arguments.py index 4a7f608f1..964bf00d3 100644 --- a/slime/utils/arguments.py +++ b/slime/utils/arguments.py @@ -1575,13 +1575,12 @@ def slime_validate_args(args): ) args.eval_max_context_len = args.rollout_max_context_len - if args.rollout_max_prompt_len is None and args.rollout_max_context_len is not None: - logger.info( - f"args.rollout_max_prompt_len is not set. Use args.rollout_max_context_len - 1 ({args.rollout_max_context_len} - 1) as default value so that there is at least one generated token to compute loss." - ) - args.rollout_max_prompt_len = args.rollout_max_context_len - 1 - - if args.rollout_max_prompt_len is not None and args.rollout_max_context_len is not None: + if args.rollout_max_context_len is not None: + if args.rollout_max_prompt_len is None: + args.rollout_max_prompt_len = args.rollout_max_context_len - 1 + logger.info( + f"args.rollout_max_prompt_len is not set. Use args.rollout_max_context_len - 1 ({args.rollout_max_context_len} - 1) as default value so that there is at least one generated token to compute loss." + ) assert ( args.rollout_max_prompt_len <= args.rollout_max_context_len - 1 ), f"args.rollout_max_prompt_len ({args.rollout_max_prompt_len}) must be smaller than args.rollout_max_context_len ({args.rollout_max_context_len}) so that there is at least one generated token to compute loss."