Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

adding ORTPipelineForxxx entrypoints #1960

Merged
merged 27 commits into from
Sep 16, 2024
Merged
Show file tree
Hide file tree
Changes from 19 commits
Commits
Show all changes
27 commits
Select commit Hold shift + click to select a range
fcb1690
created auto task mappings
IlyasMoutawwakil Jul 16, 2024
1cbb544
added correct auto classes
IlyasMoutawwakil Jul 18, 2024
cdba70e
created auto task mappings
IlyasMoutawwakil Jul 16, 2024
5bebbd5
added correct auto classes
IlyasMoutawwakil Jul 18, 2024
862e1a4
Merge branch 'auto-diffusion-pipeline' of https://github.com/huggingf…
IlyasMoutawwakil Jul 18, 2024
40b2ac0
added ort/auto diffusion classes
IlyasMoutawwakil Jul 19, 2024
29bfe57
fix ORTPipeline detection
IlyasMoutawwakil Jul 31, 2024
f6df38c
start test refactoring
IlyasMoutawwakil Jul 31, 2024
3123ea5
dynamic dtype
IlyasMoutawwakil Aug 27, 2024
7803ef3
support torch random numbers generator
IlyasMoutawwakil Aug 27, 2024
aa41f42
compact diffusion testing suite
IlyasMoutawwakil Aug 27, 2024
4837828
fix
IlyasMoutawwakil Aug 27, 2024
7504aa3
Merge branch 'main' into auto-diffusion-pipeline
IlyasMoutawwakil Sep 5, 2024
80532b3
test
IlyasMoutawwakil Sep 7, 2024
f99a058
test
IlyasMoutawwakil Sep 7, 2024
781ede7
test
IlyasMoutawwakil Sep 7, 2024
f0e3f2b
use latent-consistency architecture name instead of lcm
IlyasMoutawwakil Sep 7, 2024
80c63d0
fix
IlyasMoutawwakil Sep 7, 2024
a4518f2
add ort diffusion pipeline tests
IlyasMoutawwakil Sep 8, 2024
9f0c7b6
added dummy objects
IlyasMoutawwakil Sep 10, 2024
56d06d4
remove duplicate code
IlyasMoutawwakil Sep 10, 2024
475efdf
support testing without diffusers
IlyasMoutawwakil Sep 11, 2024
e2ad89a
remove unnecessary
IlyasMoutawwakil Sep 11, 2024
7b4b5bd
revert
IlyasMoutawwakil Sep 11, 2024
390d65d
Merge branch 'main' into auto-diffusion-pipeline
IlyasMoutawwakil Sep 11, 2024
036dc46
style
IlyasMoutawwakil Sep 12, 2024
afbb9af
remove model parts from optimum.onnxruntime
IlyasMoutawwakil Sep 12, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion optimum/exporters/tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -308,9 +308,9 @@ class TasksManager:
"image-feature-extraction": "feature-extraction",
# for backward compatibility and testing (where
# model task and model type are still the same)
"lcm": "text-to-image",
"stable-diffusion": "text-to-image",
"stable-diffusion-xl": "text-to-image",
"latent-consistency": "text-to-image",
}

_CUSTOM_CLASSES = {
Expand Down
9 changes: 6 additions & 3 deletions optimum/modeling_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,6 @@ class PreTrainedModel(ABC): # noqa: F811

class OptimizedModel(PreTrainedModel):
config_class = AutoConfig
load_tf_weights = None
IlyasMoutawwakil marked this conversation as resolved.
Show resolved Hide resolved
base_model_prefix = "optimized_model"
config_name = CONFIG_NAME

Expand Down Expand Up @@ -378,10 +377,14 @@ def from_pretrained(
)
model_id, revision = model_id.split("@")

library_name = TasksManager.infer_library_from_model(model_id, subfolder, revision, cache_dir, token=token)
library_name = TasksManager.infer_library_from_model(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
)

if library_name == "timm":
config = PretrainedConfig.from_pretrained(model_id, subfolder, revision)
config = PretrainedConfig.from_pretrained(
model_id, subfolder=subfolder, revision=revision, cache_dir=cache_dir, token=token
)

if config is None:
if os.path.isdir(os.path.join(model_id, subfolder)) and cls.config_name == CONFIG_NAME:
Expand Down
8 changes: 8 additions & 0 deletions optimum/onnxruntime/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -88,6 +88,10 @@
"ORTStableDiffusionXLPipeline",
"ORTStableDiffusionXLImg2ImgPipeline",
"ORTLatentConsistencyModelPipeline",
"ORTPipelineForText2Image",
"ORTPipelineForImage2Image",
"ORTPipelineForInpainting",
"ORTDiffusionPipeline",
]


Expand Down Expand Up @@ -146,7 +150,11 @@
)
else:
from .modeling_diffusion import (
ORTDiffusionPipeline,
ORTLatentConsistencyModelPipeline,
ORTPipelineForImage2Image,
IlyasMoutawwakil marked this conversation as resolved.
Show resolved Hide resolved
ORTPipelineForInpainting,
ORTPipelineForText2Image,
ORTStableDiffusionImg2ImgPipeline,
ORTStableDiffusionInpaintPipeline,
ORTStableDiffusionPipeline,
Expand Down
50 changes: 23 additions & 27 deletions optimum/onnxruntime/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,17 +41,11 @@ class ORTModelPart:
_prepare_onnx_inputs = ORTModel._prepare_onnx_inputs
_prepare_onnx_outputs = ORTModel._prepare_onnx_outputs

def __init__(
self,
session: InferenceSession,
parent_model: "ORTModel",
):
def __init__(self, session: InferenceSession, parent_model: "ORTModel"):
self.session = session
self.parent_model = parent_model
self.normalized_config = NormalizedConfigManager.get_normalized_config_class(
self.parent_model.config.model_type
)(self.parent_model.config)
IlyasMoutawwakil marked this conversation as resolved.
Show resolved Hide resolved
self.main_input_name = self.parent_model.main_input_name

self.input_names = {input_key.name: idx for idx, input_key in enumerate(self.session.get_inputs())}
self.output_names = {output_key.name: idx for idx, output_key in enumerate(self.session.get_outputs())}
self.input_dtypes = {input_key.name: input_key.type for input_key in session.get_inputs()}
Expand Down Expand Up @@ -90,12 +84,18 @@ class ORTEncoder(ORTModelPart):
Encoder part of the encoder-decoder model for ONNX Runtime inference.
"""

def forward(
self,
input_ids: torch.LongTensor,
attention_mask: torch.LongTensor,
**kwargs,
) -> BaseModelOutput:
def __init__(self, session: InferenceSession, parent_model: "ORTModel"):
super().__init__(session, parent_model)

config = (
self.parent_model.config.encoder
if hasattr(self.parent_model.config, "encoder")
else self.parent_model.config
)

self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

def forward(self, input_ids: torch.LongTensor, attention_mask: torch.LongTensor, **kwargs) -> BaseModelOutput:
use_torch = isinstance(input_ids, torch.Tensor)
self.parent_model.raise_on_numpy_input_io_binding(use_torch)

Expand Down Expand Up @@ -138,6 +138,14 @@ def __init__(
):
super().__init__(session, parent_model)

config = (
self.parent_model.config.decoder
if hasattr(self.parent_model.config, "decoder")
else self.parent_model.config
)

self.normalized_config = NormalizedConfigManager.get_normalized_config_class(config.model_type)(config)

# TODO: make this less hacky.
self.key_value_input_names = [key for key in self.input_names if (".key" in key) or (".value" in key)]
self.key_value_output_names = [key for key in self.output_names if (".key" in key) or (".value" in key)]
Expand All @@ -153,11 +161,7 @@ def __init__(

self.use_past_in_outputs = len(self.key_value_output_names) > 0
self.use_past_in_inputs = len(self.key_value_input_names) > 0
self.use_fp16 = False
for inp in session.get_inputs():
if "past_key_values" in inp.name and inp.type == "tensor(float16)":
self.use_fp16 = True
break
self.use_fp16 = self.dtype == torch.float16
IlyasMoutawwakil marked this conversation as resolved.
Show resolved Hide resolved

# We may use ORTDecoderForSeq2Seq for vision-encoder-decoder models, where models as gpt2
# can be used but do not support KV caching for the cross-attention key/values, see:
Expand Down Expand Up @@ -461,11 +465,3 @@ def prepare_inputs_for_merged(
cache_position = cache_position.to(self.device)

return use_cache_branch_tensor, past_key_values, cache_position


class ORTDecoder(ORTDecoderForSeq2Seq):
def __init__(self, *args, **kwargs):
logger.warning(
"The class `ORTDecoder` is deprecated and will be removed in optimum v1.15.0, please use `ORTDecoderForSeq2Seq` instead."
)
super().__init__(*args, **kwargs)
Loading
Loading