Added image-to-image task for ORT Pipeline (#2031)
* Add ORTModelForImageToImage for the image-to-image task (Swin2SR)

* Added image-to-image task to optimum pipeline

* Add tests for ORTModelForImageToImage for the image-to-image task (Swin2SR)

* Use export=True for models from transformers, self._setup and more

* Code Refactor

* Refactor ORTModelForImageToImageIntegrationTest
h3110Fr13nd authored Sep 26, 2024
1 parent 2fb5ea5 commit fd638d2
Showing 5 changed files with 219 additions and 1 deletion.
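
For orientation, here is a minimal sketch of what the new class enables, assuming a hypothetical local image path (it mirrors the docstring example and the tests added below):

```python
from PIL import Image
from transformers import AutoImageProcessor
from optimum.onnxruntime import ORTModelForImageToImage

# export=True converts the PyTorch checkpoint to ONNX on the fly,
# as exercised by the "Use export=True" commit above.
model_id = "caidas/swin2SR-classical-sr-x2-64"
image_processor = AutoImageProcessor.from_pretrained(model_id)
model = ORTModelForImageToImage.from_pretrained(model_id, export=True)

image = Image.open("path/to/image.jpg")  # hypothetical path
inputs = image_processor(images=image, return_tensors="pt")
reconstruction = model(**inputs).reconstruction  # upscaled float tensor (NCHW)
```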
2 changes: 2 additions & 0 deletions optimum/onnxruntime/__init__.py
@@ -44,6 +44,7 @@
"ORTModelForSemanticSegmentation",
"ORTModelForSequenceClassification",
"ORTModelForTokenClassification",
"ORTModelForImageToImage",
],
"modeling_seq2seq": [
"ORTModelForSeq2SeqLM",
@@ -112,6 +113,7 @@
ORTModelForCustomTasks,
ORTModelForFeatureExtraction,
ORTModelForImageClassification,
ORTModelForImageToImage,
ORTModelForMaskedLM,
ORTModelForMultipleChoice,
ORTModelForQuestionAnswering,
73 changes: 73 additions & 0 deletions optimum/onnxruntime/modeling_ort.py
@@ -34,6 +34,7 @@
AutoModelForAudioXVector,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForImageToImage,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
@@ -47,6 +48,7 @@
BaseModelOutput,
CausalLMOutput,
ImageClassifierOutput,
ImageSuperResolutionOutput,
MaskedLMOutput,
ModelOutput,
MultipleChoiceModelOutput,
@@ -2183,6 +2185,77 @@ def forward(
return TokenClassifierOutput(logits=logits)


IMAGE_TO_IMAGE_EXAMPLE = r"""
Example of image-to-image (Super Resolution):
```python
>>> from transformers import {processor_class}
>>> from optimum.onnxruntime import {model_class}
>>> from PIL import Image
>>> image = Image.open("path/to/image.jpg")
>>> image_processor = {processor_class}.from_pretrained("{checkpoint}")
>>> model = {model_class}.from_pretrained("{checkpoint}")
>>> inputs = image_processor(images=image, return_tensors="pt")
>>> outputs = model(**inputs)
>>> reconstruction = outputs.reconstruction
```
"""


@add_end_docstrings(ONNX_MODEL_END_DOCSTRING)
class ORTModelForImageToImage(ORTModel):
"""
ONNX Model for image-to-image tasks. This class officially supports swin2sr.
"""

auto_model_class = AutoModelForImageToImage

@add_start_docstrings_to_model_forward(
ONNX_IMAGE_INPUTS_DOCSTRING.format("batch_size, num_channels, height, width")
+ IMAGE_TO_IMAGE_EXAMPLE.format(
processor_class=_PROCESSOR_FOR_DOC,
model_class="ORTModelForImgageToImage",
checkpoint="caidas/swin2SR-realworld-sr-x4-64-bsrgan-psnr",
)
)
def forward(
self,
pixel_values: Union[torch.Tensor, np.ndarray],
**kwargs,
):
use_torch = isinstance(pixel_values, torch.Tensor)
self.raise_on_numpy_input_io_binding(use_torch)
if self.device.type == "cuda" and self.use_io_binding:
input_shapes = pixel_values.shape
io_binding, output_shapes, output_buffers = self.prepare_io_binding(
pixel_values,
ordered_input_names=self._ordered_input_names,
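# Swin2SR scales height and width by config.upscale, so the
# "reconstruction" output buffer can be pre-allocated at the
# upscaled size before running the session with IO binding.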
known_output_shapes={
"reconstruction": [
input_shapes[0],
input_shapes[1],
input_shapes[2] * self.config.upscale,
input_shapes[3] * self.config.upscale,
]
},
)
io_binding.synchronize_inputs()
self.model.run_with_iobinding(io_binding)
io_binding.synchronize_outputs()
reconstruction = output_buffers["reconstruction"].view(output_shapes["reconstruction"])
else:
model_inputs = {"pixel_values": pixel_values}
onnx_inputs = self._prepare_onnx_inputs(use_torch, **model_inputs)
onnx_outputs = self.model.run(None, onnx_inputs)
model_outputs = self._prepare_onnx_outputs(use_torch, *onnx_outputs)
reconstruction = model_outputs["reconstruction"]
return ImageSuperResolutionOutput(reconstruction=reconstruction)
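
The `reconstruction` tensor is a float image batch in `[0, 1]`. A minimal sketch of turning it into a `PIL.Image` (hypothetical helper; `ImageToImagePipeline` performs an equivalent conversion):

```python
import numpy as np
import torch
from PIL import Image

def reconstruction_to_image(reconstruction: torch.Tensor) -> Image.Image:
    # First image of the batch: clamp to [0, 1], move channels last
    # (CHW -> HWC), then scale to uint8 pixel values.
    array = reconstruction[0].float().cpu().clamp_(0, 1).numpy()
    array = np.moveaxis(array, source=0, destination=-1)
    return Image.fromarray((array * 255.0).round().astype(np.uint8))
```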


CUSTOM_TASKS_EXAMPLE = r"""
Example of custom tasks (e.g. a sentence transformer taking `pooler_output` as output):
8 changes: 8 additions & 0 deletions optimum/pipelines/pipelines_base.py
@@ -24,6 +24,7 @@
FillMaskPipeline,
ImageClassificationPipeline,
ImageSegmentationPipeline,
ImageToImagePipeline,
ImageToTextPipeline,
Pipeline,
PreTrainedTokenizer,
@@ -55,6 +56,7 @@
ORTModelForCausalLM,
ORTModelForFeatureExtraction,
ORTModelForImageClassification,
ORTModelForImageToImage,
ORTModelForMaskedLM,
ORTModelForQuestionAnswering,
ORTModelForSemanticSegmentation,
@@ -157,6 +159,12 @@
"default": "superb/hubert-base-superb-ks",
"type": "audio",
},
"image-to-image": {
"impl": ImageToImagePipeline,
"class": (ORTModelForImageToImage,),
"default": "caidas/swin2SR-classical-sr-x2-64",
"type": "image",
},
}
else:
ORT_SUPPORTED_TASKS = {}
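
With this entry in place, the task is reachable through `optimum.pipelines.pipeline` — a minimal sketch, assuming no explicit model is passed so that the default checkpoint registered above is used:

```python
from optimum.pipelines import pipeline

# "image-to-image" resolves to (ImageToImagePipeline, ORTModelForImageToImage);
# without an explicit model, the default checkpoint above is exported to ONNX.
upscaler = pipeline("image-to-image", accelerator="ort")
upscaled = upscaler("path/to/image.jpg")  # hypothetical path; returns a PIL.Image.Image
```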
136 changes: 135 additions & 1 deletion tests/onnxruntime/test_modeling.py
@@ -42,6 +42,7 @@
AutoModelForCausalLM,
AutoModelForCTC,
AutoModelForImageClassification,
AutoModelForImageToImage,
AutoModelForMaskedLM,
AutoModelForMultipleChoice,
AutoModelForQuestionAnswering,
@@ -57,7 +58,9 @@
PretrainedConfig,
set_seed,
)
from transformers.modeling_outputs import ImageSuperResolutionOutput
from transformers.modeling_utils import no_init_weights
from transformers.models.swin2sr.configuration_swin2sr import Swin2SRConfig
from transformers.onnx.utils import get_preprocessor
from transformers.testing_utils import get_gpu_count, require_torch_gpu, slow
from utils_onnxruntime_tests import MODEL_NAMES, SEED, ORTModelTestMixin
@@ -79,6 +82,7 @@
ORTModelForCustomTasks,
ORTModelForFeatureExtraction,
ORTModelForImageClassification,
ORTModelForImageToImage,
ORTModelForMaskedLM,
ORTModelForMultipleChoice,
ORTModelForPix2Struct,
@@ -4704,6 +4708,136 @@ def test_compare_generation_to_io_binding(
gc.collect()


class ORTModelForImageToImageIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = ["swin2sr"]

ORTMODEL_CLASS = ORTModelForImageToImage

TASK = "image-to-image"

def _get_sample_image(self):
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
return image

def _get_preprocessors(self, model_id):
image_processor = AutoImageProcessor.from_pretrained(model_id)

return image_processor

def test_load_vanilla_transformers_which_is_not_supported(self):
with self.assertRaises(Exception) as context:
_ = ORTModelForImageToImage.from_pretrained(MODEL_NAMES["bert"], export=True)

self.assertIn("only supports the tasks", str(context.exception))

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_compare_to_transformers(self, model_arch: str):
model_args = {"test_name": model_arch, "model_arch": model_arch}
self._setup(model_args)
model_id = MODEL_NAMES[model_arch]
onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
self.assertIsInstance(onnx_model.config, Swin2SRConfig)
set_seed(SEED)

transformers_model = AutoModelForImageToImage.from_pretrained(model_id)
image_processor = self._get_preprocessors(model_id)

data = self._get_sample_image()
features = image_processor(data, return_tensors="pt")

with torch.no_grad():
transformers_outputs = transformers_model(**features)

onnx_outputs = onnx_model(**features)
self.assertIsInstance(onnx_outputs, ImageSuperResolutionOutput)
self.assertTrue("reconstruction" in onnx_outputs)
self.assertIsInstance(onnx_outputs.reconstruction, torch.Tensor)
self.assertTrue(torch.allclose(onnx_outputs.reconstruction, transformers_outputs.reconstruction, atol=1e-4))

gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_generate_utils(self, model_arch: str):
model_args = {"test_name": model_arch, "model_arch": model_arch}
self._setup(model_args)
model_id = MODEL_NAMES[model_arch]
onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
image_processor = self._get_preprocessors(model_id)

data = self._get_sample_image()
features = image_processor(data, return_tensors="pt")

outputs = onnx_model(**features)
self.assertIsInstance(outputs, ImageSuperResolutionOutput)

gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
def test_pipeline_image_to_image(self, model_arch: str):
model_args = {"test_name": model_arch, "model_arch": model_arch}
self._setup(model_args)
model_id = MODEL_NAMES[model_arch]
onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
image_processor = self._get_preprocessors(model_id)
pipe = pipeline(
"image-to-image",
model=onnx_model,
feature_extractor=image_processor,
)
data = self._get_sample_image()
outputs = pipe(data)
self.assertEqual(pipe.device, onnx_model.device)
self.assertIsInstance(outputs, Image.Image)

gc.collect()

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_torch_gpu
@pytest.mark.cuda_ep_test
def test_pipeline_on_gpu(self, model_arch: str):
model_args = {"test_name": model_arch, "model_arch": model_arch}
self._setup(model_args)
model_id = MODEL_NAMES[model_arch]
onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
image_processor = self._get_preprocessors(model_id)
pipe = pipeline(
"image-to-image",
model=onnx_model,
feature_extractor=image_processor,
device=0,
)

data = self._get_sample_image()
outputs = pipe(data)

self.assertEqual(pipe.model.device.type.lower(), "cuda")
self.assertIsInstance(outputs, Image.Image)

@parameterized.expand(SUPPORTED_ARCHITECTURES)
@require_torch_gpu
@require_ort_rocm
@pytest.mark.rocm_ep_test
def test_pipeline_on_rocm(self, model_arch: str):
model_args = {"test_name": model_arch, "model_arch": model_arch}
self._setup(model_args)
model_id = MODEL_NAMES[model_arch]
onnx_model = ORTModelForImageToImage.from_pretrained(self.onnx_model_dirs[model_arch])
image_processor = self._get_preprocessors(model_id)
pipe = pipeline(
"image-to-image",
model=onnx_model,
feature_extractor=image_processor,
device=0,
)

data = self._get_sample_image()
outputs = pipe(data)

self.assertEqual(pipe.model.device.type.lower(), "cuda")
self.assertIsInstance(outputs, Image.Image)
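
As a usage note, assuming a standard pytest setup, this suite can be run selectively with `pytest tests/onnxruntime/test_modeling.py -k ORTModelForImageToImageIntegrationTest`; the CUDA and ROCm variants are additionally gated behind the `cuda_ep_test` and `rocm_ep_test` marks.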


class ORTModelForVision2SeqIntegrationTest(ORTModelTestMixin):
SUPPORTED_ARCHITECTURES = ["vision-encoder-decoder", "trocr", "donut"]

@@ -4831,7 +4965,6 @@ def test_compare_to_transformers(self, test_name: str, model_arch: str, use_cach
len(onnx_outputs["past_key_values"][0]), len(transformers_outputs["past_key_values"][0])
)
for i in range(len(onnx_outputs["past_key_values"])):
print(onnx_outputs["past_key_values"][i])
for ort_pkv, trfs_pkv in zip(
onnx_outputs["past_key_values"][i], transformers_outputs["past_key_values"][i]
):
@@ -5517,6 +5650,7 @@ class TestBothExportersORTModel(unittest.TestCase):
["automatic-speech-recognition", ORTModelForCTCIntegrationTest],
["audio-xvector", ORTModelForAudioXVectorIntegrationTest],
["audio-frame-classification", ORTModelForAudioFrameClassificationIntegrationTest],
["image-to-image", ORTModelForImageToImageIntegrationTest],
]
)
def test_find_untested_architectures(self, task: str, test_class):
1 change: 1 addition & 0 deletions tests/onnxruntime/utils_onnxruntime_tests.py
@@ -144,6 +144,7 @@
"stable-diffusion-xl": "echarlaix/tiny-random-stable-diffusion-xl",
"swin": "hf-internal-testing/tiny-random-SwinModel",
"swin-window": "yujiepan/tiny-random-swin-patch4-window7-224",
"swin2sr": "hf-internal-testing/tiny-random-Swin2SRForImageSuperResolution",
"t5": "hf-internal-testing/tiny-random-t5",
"table-transformer": "hf-internal-testing/tiny-random-TableTransformerModel",
"trocr": "microsoft/trocr-small-handwritten",
