Skip to content

Commit

Permalink
cleanup custom operator functionality
Browse files Browse the repository at this point in the history
  • Loading branch information
4imothy committed Oct 20, 2024
1 parent 924e424 commit 535bb52
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 4 deletions.
2 changes: 1 addition & 1 deletion example/custom_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ def conv2d_selector(orig: torch.nn.Conv2d) -> str:
torch_out = vgg16(input_data)

ai3.swap_conv2d(
vgg16, conv2d_selector, None, SpecialConv)
vgg16, conv2d_selector, None, swap_with=SpecialConv)
sb_out = vgg16(input_data)
assert torch.allclose(
torch_out, sb_out, atol=1e-4)
7 changes: 4 additions & 3 deletions src/ai3/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,7 @@
existing *DNN*.
"""

import inspect
from typing import Mapping, Optional, Sequence, Type, Union, TypeVar
from typing import Mapping, Optional, Sequence, Type, Union
from . import _core, utils, layers, _version
from .tensor import Tensor

Expand Down Expand Up @@ -110,6 +109,7 @@ def swap_operation(
module,
algos: Optional[AlgorithmicSelector] = None,
sample_input_shape: Optional[Sequence[int]] = None,
*,
swap_with = None):
"""
Swaps operations in-place out of the existing *DNN* for an implementation of
Expand Down Expand Up @@ -220,12 +220,13 @@ def swap_backend(module,
def swap_conv2d(module,
algos: Optional['AlgorithmicSelector'] = None,
sample_input_shape: Optional[Sequence[int]] = None,
*,
swap_with = None):
"""
Calls
>>> swap_operation('conv2d', module, algos, sample_input_shape) # doctest: +SKIP
"""
swap_operation('conv2d', module, algos, sample_input_shape, swap_with)
swap_operation('conv2d', module, algos, sample_input_shape, swap_with=swap_with)


def using_mps_and_metal() -> bool:
Expand Down

0 comments on commit 535bb52

Please sign in to comment.