Skip to content

Commit

Permalink
reduce parent model usage in model parts
Browse files Browse the repository at this point in the history
  • Loading branch information
IlyasMoutawwakil committed Sep 18, 2024
1 parent 7d50df3 commit cce0ee8
Showing 1 changed file with 10 additions and 5 deletions.
15 changes: 10 additions & 5 deletions optimum/onnxruntime/modeling_diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -468,13 +468,17 @@ def __call__(self, *args, **kwargs):
return self.auto_model_class.__call__(self, *args, **kwargs)


class ORTPipelinePart(ORTModelPart):
class ORTPipelinePart(ORTModelPart, ConfigMixin):
config_name: str = "config.json"

def __init__(self, session: ort.InferenceSession, parent_model: ORTPipeline):
super().__init__(session, parent_model)

config_path = Path(session._model_path).parent / "config.json"
config_dict = parent_model._dict_from_json_file(config_path) if config_path.is_file() else {}
self.config = FrozenDict(config_dict)
config_path = Path(session._model_path).parent / self.config_name
config_dict = self.load_config(config_path) if config_path.is_file() else {}
config_dict = config_dict[0] if isinstance(config_dict, tuple) else config_dict

self._internal_dict = FrozenDict(config_dict)

@property
def input_dtype(self):
Expand Down Expand Up @@ -605,10 +609,11 @@ def forward(

class ORTVaeWrapper(ORTPipelinePart):
def __init__(self, vae_encoder: ORTModelVaeEncoder, vae_decoder: ORTModelVaeDecoder, parent_model: ORTPipeline):
super().__init__(vae_decoder.session, parent_model)
self.vae_encoder = vae_encoder
self.vae_decoder = vae_decoder

super().__init__(vae_decoder.session, parent_model)

def encode(
self,
sample: Union[np.ndarray, torch.Tensor],
Expand Down

0 comments on commit cce0ee8

Please sign in to comment.