Add-support-for-examples-scripts-to-run-on-sagemaker #9367

Closed · wants to merge 6 commits
60 changes: 54 additions & 6 deletions examples/text-classification/run_glue.py
@@ -67,8 +67,7 @@ class DataTrainingArguments:
"""

task_name: Optional[str] = field(
default=None,
metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
default=None, metadata={"help": "The name of the task to train on: " + ", ".join(task_to_keys.keys())},
)
max_seq_length: int = field(
default=128,
@@ -133,12 +132,50 @@ class ModelArguments:
)


# Should be moved to transformers/src/transformers/file_utils.py
def is_run_on_sagemaker():
    if "SM_OUTPUT_DATA_DIR" in os.environ and "SM_MODEL_DIR" in os.environ:
        return True
    else:
        return False


# Should be moved to transformers/src/transformers/file_utils.py or transformers/src/transformers/hf_argparser.py
def parse_sagemaker_env_into_args(argv):
    # add output_dir
    argv.extend(["--output_dir", os.environ["SM_OUTPUT_DATA_DIR"]])
    # if data files are passed via SageMaker channels, add them as args
    for key, value in os.environ.items():
        # get all passed-in S3 data paths
        if key.startswith("SM_CHANNEL_"):
            # extract the channel name, e.g. SM_CHANNEL_TRAIN -> "train"
            key_type = key.split("_")[-1].lower()
            # check whether the channel value is a *.csv or *.json file or a directory
            if value.endswith(".csv") or value.endswith(".json"):
                # if it is a file, add the argument --{channel}_file with the value
                argv.extend([f"--{key_type}_file", value])
            else:  # the file name was passed as a hyperparameter and the channel is a directory
                # get the index of the value following --{channel}_file
                index_of_file_from_key_type = argv.index(f"--{key_type}_file") + 1
                # create the new path from the channel directory and the file name
                new_path = os.path.join(value, argv[index_of_file_from_key_type])
                # overwrite the existing argument for --{channel}_file
                argv[index_of_file_from_key_type] = new_path
    # remove literal "True" values so that the argument parser treats the flags correctly
    argv = [args for args in argv if args != "True"]
    return argv


def main():
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.

    ##### $custom ####
    if is_run_on_sagemaker():
        sys.argv = parse_sagemaker_env_into_args(sys.argv)
    #### $custom end ####
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))

    if len(sys.argv) == 2 and sys.argv[1].endswith(".json"):
        # If we pass only one argument to the script and it's the path to a json file,
        # let's parse it to get our arguments.
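
For illustration only (not part of the diff): a minimal sketch of what parse_sagemaker_env_into_args does to the argument list, assuming it is run from the same module; the environment values, paths, and hyperparameters below are made up.

import os
import sys

# Hypothetical SageMaker environment; a real training job sets these automatically.
os.environ["SM_OUTPUT_DATA_DIR"] = "/opt/ml/output/data"
os.environ["SM_MODEL_DIR"] = "/opt/ml/model"
os.environ["SM_CHANNEL_TRAIN"] = "/opt/ml/input/data/train"  # a directory, not a file

# Hyperparameters as SageMaker would pass them to the script.
sys.argv = [
    "run_glue.py",
    "--model_name_or_path", "bert-base-uncased",
    "--train_file", "train.csv",
    "--do_train", "True",
]

sys.argv = parse_sagemaker_env_into_args(sys.argv)
print(sys.argv)
# ['run_glue.py', '--model_name_or_path', 'bert-base-uncased',
#  '--train_file', '/opt/ml/input/data/train/train.csv', '--do_train',
#  '--output_dir', '/opt/ml/output/data']

The --train_file value now points into the channel directory, --do_train lost its "True" value so argparse treats it as a flag, and --output_dir was appended from SM_OUTPUT_DATA_DIR.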
@@ -355,7 +392,12 @@ def compute_metrics(p: EvalPrediction):
)
metrics = train_result.metrics

- trainer.save_model()  # Saves the tokenizer too for easy upload
+ ##### $custom ####
+ if is_run_on_sagemaker():
+     trainer.save_model(os.environ["SM_MODEL_DIR"])  # Saves the tokenizer too for easy upload
+ else:
+     trainer.save_model()  # Saves the tokenizer too for easy upload
+ #### $custom end ####
Comment on lines +395 to +400
Member: same here, we shouldn't need this if, no?

output_train_file = os.path.join(training_args.output_dir, "train_results.txt")
if trainer.is_world_process_zero():
@@ -365,8 +407,14 @@ def compute_metrics(p: EvalPrediction):
logger.info(f" {key} = {value}")
writer.write(f"{key} = {value}\n")

- # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
- trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
+ ##### $custom ####
+ if is_run_on_sagemaker():
+     # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
+     trainer.state.save_to_json(os.path.join(os.environ["SM_MODEL_DIR"], "trainer_state.json"))
+ else:
+     # Need to save the state, since Trainer.save_model saves only the tokenizer with the model
+     trainer.state.save_to_json(os.path.join(training_args.output_dir, "trainer_state.json"))
+ #### $custom end ####
Comment on lines +411 to +417
Member: this if clause should not be needed if we tweak .output_dir above, right?

Member Author: Yes, but we have to be careful, since SageMaker has two output directories: one for logs and files (e.g. eval_result.txt) and one for the model and associated files (config etc.). SM_MODEL_DIR represents the model directory and SM_OUTPUT_DATA_DIR represents the directory for logs and files.

Member: Yes, that is something we can probably add "upstream" in TrainingArguments if we don't already have it (the possibility of having different output dirs for model and logs).
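
The reviewers' suggestion could look roughly like the following sketch (not part of the PR): redirect the output directories once, before the parser runs, so the SageMaker-specific if clauses around save_model and save_to_json are no longer needed. The helper name and the use of a --logging_dir-style argument for the log/result directory are assumptions, not the PR's or the library's actual solution.

import os
import sys

def apply_sagemaker_dirs(argv):
    # Point --output_dir at the model directory once, so trainer.save_model()
    # and trainer.state.save_to_json() need no SageMaker-specific branches.
    if "SM_MODEL_DIR" in os.environ and "SM_OUTPUT_DATA_DIR" in os.environ:
        argv.extend(["--output_dir", os.environ["SM_MODEL_DIR"]])
        # Logs and result files would go to a separate directory; here a
        # --logging_dir-style argument (or a new TrainingArguments field) is
        # assumed to carry it.
        argv.extend(["--logging_dir", os.environ["SM_OUTPUT_DATA_DIR"]])
    return argv

sys.argv = apply_sagemaker_dirs(sys.argv)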


# Evaluation
eval_results = {}