Compute Loss inside the training step. #686

Merged

Conversation

AdamLouly
Contributor

This PR is the suggested solution for the feature requested in issue #671.

It completes this PR: https://github.com/pengwa/optimum/pull/1

Motivation:
Modifying the Optimum Trainer to compute the loss inside the training step saves memory allocated during training.

The first solution was to create a module wrapper named ModuleWithLoss, which overrides the forward function so the loss is computed inside the training step. However, we hit limitations when running eval and predict: the wrapped forward expects its input as a single dictionary, but in some eval and predict paths the inputs are passed with **, i.e. as unpacked keyword arguments rather than one dictionary, so the wrapper breaks.

To solve this, we maintain two models: _training_model, which is the wrapped ModuleWithLoss instance, and _inference_model, which is the default model, and we switch between them in the train and eval/predict steps, roughly as sketched below.
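A minimal sketch of the idea (simplified, not the exact PR code; it assumes a compute_loss with the Hugging Face Trainer signature):

import torch
from typing import Any, Dict, Union

class ModuleWithLoss(torch.nn.Module):
    """Wrap a model so the loss is computed inside forward (and thus inside the graph ORTModule sees)."""

    def __init__(self, model, trainer):
        super().__init__()
        self._original_model = model
        self._trainer = trainer  # gives access to compute_loss / label_smoother / args

    def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]]):
        # Delegate to the trainer's compute_loss so the two code paths stay in sync.
        return self._trainer.compute_loss(self._original_model, inputs, return_outputs=False)

# Inside the trainer, keep two handles and switch depending on the phase:
#   self._training_model  = ModuleWithLoss(model, self)   # used by the training step
#   self._inference_model = model                         # used by evaluate / predict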

These are the logs from running two experiments.
Specs: Tesla V100 - 4 GPU - 16 GB Memory Each.
python -m torch.distributed.launch --nproc_per_node=4 examples/onnxruntime/training/language-modeling/run_clm.py --model_name_or_path bigscience/bloom-560m --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --label_smoothing 0.1 --do_eval --output_dir /tmp/ --overwrite_output_dir --max_steps 400 --logging_steps 20 --fp16 --block_size 256 --per_device_train_batch_size 8 --deepspeed aml_ds_config_zero_1.json

Using default optimum trainer:

[screenshot: training logs / memory usage with the default optimum trainer]

Using This Wrapper Solution:
[screenshot: training logs / memory usage with the wrapper solution]

We see a difference of almost 2 GB between the two runs using the same recipe.

@AdamLouly AdamLouly changed the title Adamlouly/compute loss training Compute Loss inside the training step. Jan 11, 2023
@AdamLouly
Contributor Author

@JingyaHuang, is there a way to run the CI pipelines to make sure this change does not break anything?
Is it expected to run do_train and do_predict in the same run?

@JingyaHuang
Collaborator

Hi @AdamLouly, the CI for ORT training is scheduled every morning. To run the tests directly, you can use the following:

python -m unittest tests/onnxruntime/nightly_test_trainer.py

do_train, do_predict and do_eval are intended for training/evaluation scripts. You can definitely run them at the same time if they are in your script (e.g. all three are added in examples/onnxruntime/training/question-answering/run_qa.py, but do_predict was not added in the previous contribution for clm, #248).

@JingyaHuang
Collaborator

Hi @AdamLouly, could you rebase your branch? I would like to test the compatibility of the wrapper with other components in ORTTrainer. Thanks.

@HuggingFaceDocBuilderDev

HuggingFaceDocBuilderDev commented Jan 17, 2023

The documentation is not available anymore as the PR was closed or merged.

@AdamLouly
Contributor Author

@JingyaHuang Done, any updates on the compatibility tests?

@@ -443,7 +544,7 @@ def _inner_training_loop(
RuntimeWarning,
)

self.model = model
self.model = self._training_model

If deepspeed is not used, should we also override here?

@pengwa

pengwa commented Feb 14, 2023

As we synced offline, for this PR we are curious about: 1) memory savings, 2) possible batch size increase, 3) any perf impact.

@JingyaHuang JingyaHuang self-assigned this Feb 14, 2023
@JingyaHuang JingyaHuang added the onnxruntime Related to ONNX Runtime label Feb 14, 2023
@AdamLouly
Contributor Author

AdamLouly commented Feb 14, 2023

Adding some extra testing results for reference here:
Specs: Tesla V100 - 4 GPU - 16 GB Memory Each.
python -m torch.distributed.launch --nproc_per_node=4 examples/onnxruntime/training/language-modeling/run_clm.py --model_name_or_path bigscience/bloom-560m --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --output_dir /tmp/ --overwrite_output_dir --max_steps 400 --logging_steps 2 --block_size 2560 --fp16

[screenshot: training logs for the max-batch-size experiments]

The max batch size we could run with the current trainer is 11, while with this change we can bring the max batch size up to 14.

The percentage of memory saved varies from one experiment to another depending on batch size, block size, etc.

Perf report:

This shows the runs for a batch size of 2, with and without the label smoother.
[screenshot: perf comparison at batch size 2, with and without label smoother]

Using this wrapper also makes a bigger batch size possible, which translates into a perf improvement: in this example we ran the max batch size for both versions, and the wrapper was 7.31% faster than the default version.

[screenshot: perf comparison at max batch size]

@AdamLouly
Contributor Author

@JingyaHuang we added a flag "--loss_in_train" to keep these modifications optional:
if users want to use the ModuleWithLoss wrapper, they can run the recipes with the --loss_in_train flag; otherwise everything runs exactly as it does right now.

we are planning to remove this in the future.
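For illustration, enabling the wrapper is just a matter of appending the flag to an existing recipe, roughly (same script and arguments as the runs above):

python -m torch.distributed.launch --nproc_per_node=4 examples/onnxruntime/training/language-modeling/run_clm.py --model_name_or_path bigscience/bloom-560m --dataset_name wikitext --dataset_config_name wikitext-2-raw-v1 --do_train --output_dir /tmp/ --overwrite_output_dir --fp16 --loss_in_train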

is there any documentation where we can add this flag so users can be aware of it?
Thank you

@JingyaHuang JingyaHuang added gpu-test trigger GPU tests and removed gpu-test trigger GPU tests labels Feb 22, 2023


# Integrations must be imported before ML frameworks:
from transformers.integrations import hp_params, is_fairscale_available # isort: split

useless references?

Contributor Author


This came from main when I resolved the conflict.


But ideally, after the merge, these two lines should not appear in the diff.

Collaborator


Will remove this to pass the code quality check.

Collaborator


Ok, I don't have push access. Can you use ruff styling with the command make style and remove the redundant dependencies? Thx!

Contributor Author

@AdamLouly AdamLouly Mar 2, 2023


Ok, I don't have push access. Can you use ruff styling with the command make style and remove the redundant dependencies? Thx!

I did make style and it formatted trainer.py, but the CI still says it should be formatted.
Is there a specific configuration for this?

It seems the CI uses the latest black version every time, so we should always upgrade black before formatting;
otherwise it will reformat other files that were previously formatted with a different version.

Collaborator


Hi @AdamLouly, yes, the CI always uses the latest formatting tools. And whenever the team observes a failure of the code quality check CI, we fix it.

Your previous formatting issue could come from the fact that we recently switched from isort to ruff (#760). If you want to be more cautious, you can update your formatter with pip install -U .[quality] before make style.
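In other words, roughly:

pip install -U .[quality]
make style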

@@ -801,6 +867,8 @@ def evaluate(
dictionary also contains the epoch number which comes from the training state.
"""
# memory metrics - must set up as early as possible
# TODO: We need to enable evaluation using ORT backend.

This TODO is only meant for the --loss_in_train flag, right?


@pengwa pengwa left a comment


LGTM, thanks a lot!

self.label_smoother = None

def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs):
return self.hf_trainer.compute_loss(self._original_model, inputs, return_outputs=False)

I think this is not correct. Done this way, self.label_smoother won't be used by compute_loss. What I suggested earlier is to bind the compute_loss of the HF Trainer to ModuleWithLoss. Here is an example:


class A:
    def __init__(self) -> None:
        self.prop = "A's prop"

    def f(self, x: int) -> int:
        print("A>>f is called, the prop used is: ", self.prop)
        return x + 1

class B:
    def __init__(self) -> None:
        self.prop = "B's prop"

    def main(self, x: int) -> int:
        print("B>>main is called")
        self.f(x)
        return x + 2

    def f(self, x: int) -> int:
        raise NotImplementedError()


b_instance = B()
import types

b_instance.f = types.MethodType(getattr(A,'f'), b_instance)

b_instance.main(1)


B>>main is called
A>>f is called, the prop used is: B's prop

Collaborator


Here compute_loss() will use self.hf_trainer.label_smoother. Although this way the loss computation with the label smoother happens inside a forward pass and is intercepted by onnxruntime, the self.label_smoother defined in __init__ will not be used.

It's good that we can reuse the compute_loss function, but in terms of code clarity I would prefer to override the forward pass of the pretrained model instead of involving the Trainer.

(As discussed internally with the transformers team, it would be nice to have a wrapper directly in the transformers package that includes the loss computation in the forward pass when the label smoother is used. But let's do that in Optimum first, have this PR merged, test it, and once it is mature migrate it to transformers. After that, maintaining ORTTrainer would be easier.)

Contributor Author


FYI @JingyaHuang, in this case the code in the Trainer has to be kept in sync: if compute_loss in the HF Trainer changes, the forward pass of ModuleWithLoss has to be changed as well.

Collaborator


Hi @AdamLouly, sorry for the back and forth. I proposed rewriting the code because I was considering opening a PR in Transformers to put the label smoother inside forward; in that case we would not need a wrapper in Optimum. But as @pengwa explained, a PR in Transformers won't be enough (we can't limit unnecessary outputs in Transformers, for flexibility reasons), so we will always need this wrapper in Optimum.

Given that, I agree that we should inherit compute_loss() (as you did before) to ease maintenance.


def compute_loss(self, model_with_loss, inputs, return_outputs=False):
# Run model forward + loss compute.
if self.args.loss_in_train and self.model == self._training_model:

Have you run PyTorch with loss_in_train enabled? Is that working?


# Only Wrap the model if we pass --loss_in_train flag.
if args.loss_in_train:
self._training_model = ModuleWithLoss(model, args)

nit: maybe we can pass self into ModuleWithLoss in its constructor. Then you can get self.model, self.args, and even self.label_smoother inside the ModuleWithLoss class.


@pengwa pengwa Feb 26, 2023


Here is an example:

import types

def __init__(...):
    ....
    if args.loss_in_train:
       self._training_model = self.create_model_with_loss()
    ...

def create_model_with_loss(self):
    class ModuleWithLoss(nn.Module):
        def __init__(self, model, args, label_smoother):
            super().__init__()
            self._original_model = model
            self.args = args
            self.label_smoother = label_smoother

        def forward(self, inputs: Dict[str, Union[torch.Tensor, Any]], return_outputs):
            # The compute_model_plus_loss_internal is assigned once the class is instantiated.
            # It should have the same signature as Trainer.compute_loss().
            # We do this to avoid potentially un-synced states if we duplicated the compute-loss code.
            return self.compute_model_plus_loss_internal(self._original_model, inputs, return_outputs)

        @property
        def config(self):
            return self._original_model.config

    model_with_loss = ModuleWithLoss(self.model, self.args, self.label_smoother)
    # Bind the parent Trainer.compute_loss to the wrapper instance (cf. the A/B example above).
    model_with_loss.compute_model_plus_loss_internal = types.MethodType(Trainer.compute_loss, model_with_loss)

    return model_with_loss

Collaborator

@JingyaHuang JingyaHuang left a comment


Hi @AdamLouly and @pengwa, thanks a lot for the contribution, and sorry for the delay in reviewing.

I just left some comments and questions I have on the PR.


@@ -134,6 +139,30 @@
SCALER_NAME = "scaler.pt"


class ModuleWithLoss(nn.Module):
Collaborator


ModuleWithLoss is a wrapper around a torch.nn.Module subclass; can you add a module property so that unwrap_model stays compatible? I believe ORTModule does the same.
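For illustration, such a property could look roughly like this (a sketch, not the actual PR code):

import torch

class ModuleWithLoss(torch.nn.Module):
    def __init__(self, model):
        super().__init__()
        self._original_model = model

    @property
    def module(self):
        # transformers' unwrap_model() follows the `.module` attribute,
        # so exposing it lets such utilities reach the original model.
        return self._original_model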


self.label_smoother = LabelSmoother(epsilon=self.args.label_smoothing_factor)
else:
self.label_smoother = None
Collaborator


Why is there a case for not using the label smoother? In transformers, unless the label smoother is used, the loss is already calculated in the forward pass. Cf. gpt2.

Collaborator


And with the wrapper, a model not using the label smoother in the first place should not see any memory benefit, right?


@pengwa pengwa Mar 1, 2023


And with the wrapper, a model not using the label smoother in the first place should not see any memory benefit, right?

This is a good question; I can help answer it. In short, whether or not the label smoothing factor argument is given, we see an improvement in memory. The key reason is that, after the change, the target model that ORTModule wraps has a single loss output.

If the label smoothing factor argument is not given, then the CrossEntropyLoss is indeed computed inside the model's forward pass. But there is a subtlety: the loss is returned together with lm_logits and other intermediate states. If ORTModule wraps and operates on this model, the exported model has a few outputs besides the loss; those outputs are not used during training, but the exporter doesn't know that. In the ORT training implementation, even though those outputs are not used, we still fill them with zeros and carry them through the whole backward propagation phase.

For this case, if we wrap model+loss together, the final output of model+loss (which ORTModule wraps) is just the loss.
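To make this concrete, here is a toy sketch (hypothetical ToyLM stand-in, not the PR code) of why wrapping model+loss shrinks the exported graph's outputs:

import torch

class ToyLM(torch.nn.Module):
    """Stand-in for an HF-style model whose forward returns (loss, logits)."""
    def __init__(self, hidden=8, vocab=10):
        super().__init__()
        self.head = torch.nn.Linear(hidden, vocab)

    def forward(self, hidden_states, labels):
        logits = self.head(hidden_states)
        loss = torch.nn.functional.cross_entropy(
            logits.view(-1, logits.size(-1)), labels.view(-1)
        )
        return loss, logits  # the exporter sees (and ORT carries) both outputs

class ToyLMWithLoss(torch.nn.Module):
    """Model+loss wrapped together: only the loss is exposed to the exporter."""
    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, hidden_states, labels):
        loss, _ = self.model(hidden_states, labels)
        return loss  # single graph output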

Collaborator


Thanks for the explanation @pengwa, that's very clear.

if args.loss_in_train:
self._training_model = ModuleWithLoss(model, args)
else:
self._training_model = model
Collaborator


Do we still need to distinguish the training model from the inference model, if we have the module property on the wrapper, given that we unwrap self.model for inference here:

self.model = unwrap_model(deepspeed_engine)

@JingyaHuang
Collaborator

is there any documentation where we can add this flag so users can be aware of it?

@AdamLouly, sure! It's important for visibility.

We can highlight the flag both in the trainer's docs and in the trainer examples' READMEs.

@AdamLouly
Contributor Author

The flag was changed to --use_module_with_loss.

outputs = model_with_loss(inputs, return_outputs)
return outputs
else:
if self.label_smoother is not None and "labels" in inputs:

Why do we revert this super().compute_loss() in the else branch?

use_module_with_loss: Optional[bool] = field(
default=False,
metadata={
"help": "Use ModuleWithLoss Wrapper to compute loss inside the training loop, when label smoother is NOT none having this will help save memory for ORTMOdule Runs."

Shall we remove this "when label smoother is NOT none"?

use_module_with_loss: Optional[bool] = field(
default=False,
metadata={
"help": "Use ModuleWithLoss Wrapper to compute loss inside the training loop, having this will help save memory for ORTMOdule Runs."

nit: ORTMOdule -> "ORTModule "

@AdamLouly
Contributor Author

@JingyaHuang any updates on reviewing this?

@pengwa

pengwa commented Mar 22, 2023

@AdamLouly can you fix the CI errors:

"optimum/onnxruntime/trainer.py:17:1: I001 [*] Import block is un-sorted or un-formatted"

Collaborator

@JingyaHuang JingyaHuang left a comment


All trainer tests passed
====================================== test session starts =======================================
platform linux -- Python 3.8.10, pytest-7.2.2, pluggy-1.0.0
rootdir: /workspace, configfile: pyproject.toml
collected 32 items                                                                               

tests/onnxruntime/nightly_test_trainer.py s...s....................sssss..                 [100%]

======================================== warnings summary ========================================
../usr/lib/python3/dist-packages/requests/__init__.py:89
  /usr/lib/python3/dist-packages/requests/__init__.py:89: RequestsDependencyWarning: urllib3 (1.26.15) or chardet (3.0.4) doesn't match a supported version!
    warnings.warn("urllib3 ({}) or chardet ({}) doesn't match a supported "

../usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_validation.py:114
  /usr/local/lib/python3.8/dist-packages/onnxruntime/capi/onnxruntime_validation.py:114: UserWarning: WARNING: failed to get cudart_version from onnxruntime build info.
    warnings.warn("WARNING: failed to get cudart_version from onnxruntime build info.")

tests/onnxruntime/nightly_test_trainer.py: 25 warnings
  /usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py:1794: FutureWarning: The first argument to symbolic functions is deprecated in 1.13 and will be removed in the future. Please annotate treat the first argument (g) as GraphContext and use context information from the object instead.
    warnings.warn(

tests/onnxruntime/nightly_test_trainer.py: 25 warnings
  /usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_utils.py:250: UserWarning: User Module's attribute name float collides with ORTModule's attribute name. User Module's method may not be called upon invocation through ORTModule.
    warnings.warn(

tests/onnxruntime/nightly_test_trainer.py: 25 warnings
  /usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_utils.py:250: UserWarning: User Module's attribute name half collides with ORTModule's attribute name. User Module's method may not be called upon invocation through ORTModule.
    warnings.warn(

tests/onnxruntime/nightly_test_trainer.py: 25 warnings
  /usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_utils.py:250: UserWarning: User Module's attribute name to collides with ORTModule's attribute name. User Module's method may not be called upon invocation through ORTModule.
    warnings.warn(

tests/onnxruntime/nightly_test_trainer.py: 25 warnings
  /usr/local/lib/python3.8/dist-packages/transformers/optimization.py:391: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning
    warnings.warn(

tests/onnxruntime/nightly_test_trainer.py: 25 warnings
  /usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_training_manager.py:192: UserWarning: Fast path enabled - skipping checks. Rebuild graph: True, Execution agent: True, Device check: True
    warnings.warn(

tests/onnxruntime/nightly_test_trainer.py: 12 warnings
  /usr/local/lib/python3.8/dist-packages/transformers/models/distilbert/modeling_distilbert.py:223: TracerWarning: torch.tensor results are registered as constants in the trace. You can safely ignore this warning if you use this function to create tensors out of constant variables that would be the same every time you call this function. In any other case, this might cause the trace to be incorrect.
    mask, torch.tensor(torch.finfo(scores.dtype).min)

tests/onnxruntime/nightly_test_trainer.py: 25 warnings
  /usr/local/lib/python3.8/dist-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
    _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)

tests/onnxruntime/nightly_test_trainer.py: 15 warnings
  /usr/local/lib/python3.8/dist-packages/torch/onnx/_internal/jit_utils.py:258: UserWarning: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
    _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)

tests/onnxruntime/nightly_test_trainer.py: 25 warnings
  /usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_io.py:66: FutureWarning: 'torch.onnx._patch_torch._graph_op' is deprecated in version 1.13 and will be removed in version 1.14. Please note 'g.op()' is to be removed from torch.Graph. Please open a GitHub issue if you need this functionality..
    return g.op("Identity", self)

tests/onnxruntime/nightly_test_trainer.py: 25 warnings
  /usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py:687: UserWarning: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
    _C._jit_pass_onnx_graph_shape_type_inference(

tests/onnxruntime/nightly_test_trainer.py: 15 warnings
  /usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py:687: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
    _C._jit_pass_onnx_graph_shape_type_inference(

tests/onnxruntime/nightly_test_trainer.py: 15 warnings
  /usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py:687: UserWarning: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
    _C._jit_pass_onnx_graph_shape_type_inference(

tests/onnxruntime/nightly_test_trainer.py: 25 warnings
  /usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py:1178: UserWarning: The shape inference of org.pytorch.aten::ATen type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
    _C._jit_pass_onnx_graph_shape_type_inference(

tests/onnxruntime/nightly_test_trainer.py: 15 warnings
  /usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py:1178: UserWarning: The shape inference of prim::Constant type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
    _C._jit_pass_onnx_graph_shape_type_inference(

tests/onnxruntime/nightly_test_trainer.py: 15 warnings
  /usr/local/lib/python3.8/dist-packages/torch/onnx/utils.py:1178: UserWarning: The shape inference of com.microsoft::SoftmaxCrossEntropyLossInternal type is missing, so it may result in wrong shape inference for the exported graph. Please consider adding it in symbolic function. (Triggered internally at ../torch/csrc/jit/passes/onnx/shape_type_inference.cpp:1884.)
    _C._jit_pass_onnx_graph_shape_type_inference(

tests/onnxruntime/nightly_test_trainer.py: 13 warnings
  /usr/local/lib/python3.8/dist-packages/transformers/models/t5/tokenization_t5_fast.py:155: FutureWarning: This tokenizer was incorrectly instantiated with a model max length of 512 which will be corrected in Transformers v5.
  For now, this behavior is kept to avoid breaking backwards compatibility when padding/encoding with `truncation is True`.
  - Be aware that you SHOULD NOT rely on t5-small automatically truncating your input to 512 when padding/encoding.
  - If you want to encode/pad to sequences longer than 512 you can either instantiate this tokenizer with `model_max_length` or pass `max_length` when encoding/padding.
  - To avoid this warning, please instantiate this tokenizer with `model_max_length` set to your preferred value.
    warnings.warn(

tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp16_ort_inference_2_t5_seq2seq_lm
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp16_pt_inference_2_t5_seq2seq_lm
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp16_pt_inference_3_t5_seq2seq_lm_with_past
  /usr/local/lib/python3.8/dist-packages/transformers/tokenization_utils_base.py:3587: UserWarning: `as_target_tokenizer` is deprecated and will be removed in v5 of Transformers. You can tokenize your labels by using the argument `text_target` of the regular `__call__` method (either in the same call as your input texts if you use the same keyword arguments, or in a separate call.
    warnings.warn(

tests/onnxruntime/nightly_test_trainer.py: 13 warnings
  /usr/local/lib/python3.8/dist-packages/transformers/modeling_utils.py:793: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    if causal_mask.shape[1] < attention_mask.shape[1]:

tests/onnxruntime/nightly_test_trainer.py: 55 warnings
  /usr/local/lib/python3.8/dist-packages/torch/onnx/symbolic_opset9.py:3674: DeprecationWarning: an integer is required (got type float).  Implicit conversion to integers using __int__ is deprecated, and may be removed in a future version of Python.
    value_t=torch.tensor([fill_value], dtype=scalar_type.dtype()),

tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp16_ort_inference_2_t5_seq2seq_lm
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_5_t5_seq2seq_lm
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_7_t5_seq2seq_lm_with_past
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_with_label_smoothing_5_t5_seq2seq_lm
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_with_label_smoothing_7_t5_seq2seq_lm_with_past
tests/onnxruntime/nightly_test_trainer.py::ORTSeq2SeqTrainerSpecificIntegrationTest::test_predict_with_generate_ort_0_t5_seq2seq_lm
tests/onnxruntime/nightly_test_trainer.py::ORTSeq2SeqTrainerSpecificIntegrationTest::test_predict_with_generate_ort_1_t5_seq2seq_lm_with_past
  /usr/local/lib/python3.8/dist-packages/transformers/models/t5/modeling_t5.py:507: TracerWarning: Converting a tensor to a Python boolean might cause the trace to be incorrect. We can't record the data flow of Python values, so this value will be treated as a constant in the future. This means that the trace might not generalize to other inputs!
    elif past_key_value.shape[2] != key_value_states.shape[1]:

tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_0_distilbert_sequence_classification
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_1_distilbert_sequence_classification
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_2_distilbert_token_classification
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_3_distilbert_token_classification
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_4_t5_seq2seq_lm
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_5_t5_seq2seq_lm
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_6_t5_seq2seq_lm_with_past
tests/onnxruntime/nightly_test_trainer.py::ORTTrainerIntegrationTest::test_trainer_fp32_7_t5_seq2seq_lm_with_past
  /usr/local/lib/python3.8/dist-packages/onnxruntime/training/ortmodule/_custom_op_symbolic_registry.py:125: UserWarning: Unsupported diverged input and output types for logits when export cross_entropy_loss.logits type: Float, loss type: Float
    warnings.warn(

-- Docs: https://docs.pytest.org/en/stable/how-to/capture-warnings.html
==================== 25 passed, 7 skipped, 438 warnings in 589.63s (0:09:49) =====================

LGTM, let's get it merged. Thank you both for the awesome contribution @AdamLouly @pengwa, and for helping me understand the memory issue!

@JingyaHuang JingyaHuang merged commit 4f79ebd into huggingface:main Mar 24, 2023