Skip to content

Commit

Permalink
dev_optimize_mock_torch (#410)
Browse files Browse the repository at this point in the history
关联 oneflow:  siliconflow/oneflow#55


删除 与**文件拷贝,文件创建相关**的所有 操作。

---------

Co-authored-by: Xiaoyu Xu <xuxiaoyu2048@foxmail.com>
  • Loading branch information
ccssu and strint authored Dec 19, 2023
1 parent 516ea1f commit 2cf4a2f
Show file tree
Hide file tree
Showing 9 changed files with 48 additions and 346 deletions.
152 changes: 0 additions & 152 deletions src/onediff/infer_compiler/import_tools/copier.py

This file was deleted.

62 changes: 0 additions & 62 deletions src/onediff/infer_compiler/import_tools/copy_utils.py

This file was deleted.

1 change: 0 additions & 1 deletion src/onediff/infer_compiler/import_tools/format_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ def format(self, entity: Union[str, type, FunctionType]) -> str:
return self._format_full_class_name(entity)

def unformat(self, mock_entity_name: str) -> str:

if "." in mock_entity_name:
pkg_name, cls_name = mock_entity_name.split(".", 1)
return f"{self._reverse_pkg_name(pkg_name)}.{cls_name}"
Expand Down
94 changes: 26 additions & 68 deletions src/onediff/infer_compiler/import_tools/importer.py
Original file line number Diff line number Diff line change
@@ -1,52 +1,30 @@
import inspect
import os
import sys
import importlib
import shutil
from typing import Any, Optional, Union
from typing import Optional, Union
from types import FunctionType, ModuleType
from oneflow.mock_torch import DynamicMockModule
from pathlib import Path
from ..utils.log_utils import logger
from .copier import PackageCopier
from .mock_torch_context import onediff_mock_torch
from importlib.metadata import requires
from .format_utils import MockEntityNameFormatter

__all__ = ["import_module_from_path", "LazyMocker"]
__all__ = ["import_module_from_path", "LazyMocker", "is_need_mock"]


class MockEntity:
def __init__(self, obj_entity: ModuleType = None):
self._obj_entity = obj_entity # ModuleType or _LazyModule

@classmethod
def from_package(cls, package: str):
with onediff_mock_torch():
return cls(importlib.import_module(package))

def _get_module(self, _name: str):
# Fix Lazy import
# https://github.com/huggingface/diffusers/blob/main/src/diffusers/__init__.py#L728-L734
module_name = f"{self._obj_entity.__name__}.{_name}"
try:
return importlib.import_module(module_name)
except Exception as e:
raise RuntimeError(
f"Failed to import {module_name} because of the following error (look up to see its"
f" traceback):\n{e}"
) from e

def __getattr__(self, name: str):
with onediff_mock_torch():
obj_entity = getattr(self._obj_entity, name, None)
if obj_entity is None:
obj_entity = self._get_module(name)

if inspect.ismodule(obj_entity):
return MockEntity(obj_entity)
return obj_entity

def entity(self):
return self._obj_entity
def is_need_mock(cls) -> bool:
assert isinstance(cls, (type, str))
main_pkg = cls.__module__.split(".")[0]
try:
pkgs = requires(main_pkg)
except Exception as e:
return True
if pkgs:
for pkg in pkgs:
pkg = pkg.split(" ")[0]
if pkg == "torch":
return True
return False
return True


def import_module_from_path(module_path: Union[str, Path]) -> ModuleType:
Expand Down Expand Up @@ -81,22 +59,10 @@ def __init__(self, prefix: str, suffix: str, tmp_dir: Optional[Union[str, Path]]
self.cleanup_list = []

def mock_package(self, package: str):
# TODO Mock the package in memory
with onediff_mock_torch():
copier = PackageCopier(
package,
prefix=self.prefix,
suffix=self.suffix,
target_directory=self.tmp_dir,
)
copier.mock()
self.mocked_packages.add(copier.new_pkg_name)
self.cleanup_list.append(copier.new_pkg_path)
pass

def cleanup(self):
for path in self.cleanup_list:
logger.debug(f"Removing {path=}")
shutil.rmtree(path, ignore_errors=True)
pass

def get_mock_entity_name(self, entity: Union[str, type, FunctionType]):
formatter = MockEntityNameFormatter(prefix=self.prefix, suffix=self.suffix)
Expand All @@ -105,7 +71,7 @@ def get_mock_entity_name(self, entity: Union[str, type, FunctionType]):

def mock_entity(self, entity: Union[str, type, FunctionType]):
"""Mock the entity and return the mocked entity
Example:
>>> mocker = LazyMocker(prefix="mock_", suffix="_of", tmp_dir="tmp")
>>> mocker.mock_entity("models.DemoModel")
Expand All @@ -120,16 +86,8 @@ def load_entity_with_mock(self, entity: Union[str, type, FunctionType]):
formatter = MockEntityNameFormatter(prefix=self.prefix, suffix=self.suffix)
full_obj_name = formatter.format(entity)
attrs = full_obj_name.split(".")
if attrs[0] in self.mocked_packages:
obj_entity = MockEntity.from_package(attrs[0])
for name in attrs[1:]:
obj_entity = getattr(obj_entity, name)
return obj_entity
else:
pkg_name = formatter.unformat(attrs[0])
pkg = importlib.import_module(pkg_name)
if pkg is None:
RuntimeError(f'Importing package "{pkg_name}" failed')
# https://docs.python.org/3/reference/import.html#path__
self.mock_package(pkg.__path__[0])
return self.load_entity_with_mock(entity)
self.mocked_packages.add(attrs[0])
mock_pkg = DynamicMockModule.from_package(attrs[0], verbose=False)
for name in attrs[1:]:
mock_pkg = getattr(mock_pkg, name)
return mock_pkg
22 changes: 0 additions & 22 deletions src/onediff/infer_compiler/import_tools/mock_torch_context.py

This file was deleted.

5 changes: 4 additions & 1 deletion src/onediff/infer_compiler/transform/builtin_transform.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,8 @@
from .manager import transform_mgr
from ..utils.log_utils import logger
from ..utils.patch_for_diffusers import diffusers_checker
from ..import_tools.importer import is_need_mock
from functools import singledispatch

__all__ = [
"proxy_class",
Expand All @@ -24,7 +26,6 @@
"torch2oflow",
"default_converter",
]
from functools import singledispatch


def singledispatch_proxy(func):
Expand Down Expand Up @@ -150,6 +151,8 @@ def torch2oflow(mod, *args, **kwargs):


def default_converter(obj, verbose=False, *, proxy_cls=None):
if not is_need_mock(type(obj)):
return obj
try:
new_obj_cls = proxy_class(type(obj)) if proxy_cls is None else proxy_cls

Expand Down
Loading

0 comments on commit 2cf4a2f

Please sign in to comment.