Skip to content

Commit

Permalink
Fix training artifacts for 2GB+ models and MSELoss (#22414)
Browse files Browse the repository at this point in the history
  • Loading branch information
jkbeavers authored Oct 15, 2024
1 parent 6407d81 commit a5e85a9
Show file tree
Hide file tree
Showing 2 changed files with 6 additions and 3 deletions.
4 changes: 3 additions & 1 deletion orttraining/orttraining/python/training/onnxblock/blocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,10 @@ def __call__(self, *args, **kwargs):
output = self.build(*args, **kwargs)

if accessor._GLOBAL_ACCESSOR.has_path:
# `save` will destructively access any external data
copied_model = copy.deepcopy(accessor._GLOBAL_ACCESSOR.model)
onnx.save(
accessor._GLOBAL_ACCESSOR.model,
copied_model,
self.temp_onnx_file_path,
save_as_external_data=True,
all_tensors_to_one_file=True,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1159,7 +1159,8 @@ def test_generate_artifacts_external_data_one_file():
assert os.path.exists(os.path.join(temp_dir, "checkpoint"))


def test_generate_artifacts_external_data_separate_files():
@pytest.mark.parametrize("loss", [loss_t for loss_t in artifacts.LossType])
def test_generate_artifacts_external_data_separate_files(loss):
with tempfile.TemporaryDirectory() as temp_dir:
_, simple_net = _get_models("cpu", 32, 28, 10, 10)

Expand All @@ -1176,7 +1177,7 @@ def test_generate_artifacts_external_data_separate_files():
artifacts.generate_artifacts(
os.path.join(temp_dir, "simple_net.onnx"),
requires_grad=requires_grad_params,
loss=artifacts.LossType.CrossEntropyLoss,
loss=loss,
optimizer=artifacts.OptimType.AdamW,
artifact_directory=temp_dir,
)
Expand Down

0 comments on commit a5e85a9

Please sign in to comment.