Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
hjchen2 committed Apr 15, 2024
1 parent c974d3a commit e025cb0
Show file tree
Hide file tree
Showing 4 changed files with 3 additions and 10 deletions.
2 changes: 1 addition & 1 deletion src/onediff/infer_compiler/backends/nexfort.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import dataclasses
import torch
from .registry import register_backend
from ..options import CompileOptions


def make_inductor_options(options):
Expand All @@ -19,6 +18,7 @@ def compile(torch_module: torch.nn.Module, *, options=None):
from nexfort.utils.memory_format import apply_memory_format
from nexfort.compilers import nexfort_compile
from ..nexfort.deployable_module import NexfortDeployableModule
from ..utils import CompileOptions

options = options if options is not None else CompileOptions()
nexfort_options = options.nexfort
Expand Down
2 changes: 1 addition & 1 deletion src/onediff/infer_compiler/core/with_onediff_compile.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
def compile(
torch_module: torch.nn.Module, *, backend="nexfort", options=None
) -> DeployableModule:
from .backends.registry import lookup_backend
from ..backends.registry import lookup_backend

backend = lookup_backend(backend)
model = backend(torch_module, options=options)
Expand Down
2 changes: 1 addition & 1 deletion src/onediff/infer_compiler/nexfort/deployable_module.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import torch
from ..deployable_module import DeployableModule
from ..core.deployable_module import DeployableModule


class NexfortDeployableModule(DeployableModule):
Expand Down
7 changes: 0 additions & 7 deletions src/onediff/infer_compiler/utils/env_var.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,6 @@ def set_integer_env_var(env_var: str, val: Optional[int]):


def _set_env_vars(field2env_var, options):
from .utils import (
parse_boolean_from_env,
set_boolean_env_var,
parse_integer_from_env,
set_integer_env_var,
)

for field in dataclasses.fields(options):
field_name = field.name
field_value = getattr(options, field_name)
Expand Down

0 comments on commit e025cb0

Please sign in to comment.