Commit
allow customization of the forwarding type
4imothy committed Oct 16, 2024
1 parent 89cb5bb commit 924e424
Showing 3 changed files with 55 additions and 17 deletions.
34 changes: 34 additions & 0 deletions example/custom_operator.py
@@ -0,0 +1,34 @@
import torch
import torchvision
import ai3


class SpecialConv(torch.nn.Module):
    def __init__(self, orig: torch.nn.Conv2d, algorithm: str):
        super(SpecialConv, self).__init__()
        self.orig = orig
        self.algorithm = algorithm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.orig(x)


def conv2d_selector(orig: torch.nn.Conv2d) -> str:
    in_channels = orig.weight.shape[1]
    if in_channels > 200:
        return 'smm'
    return 'direct'


input_data = torch.randn(1, 3, 224, 224)
vgg16 = torchvision.models.vgg16(
    weights=torchvision.models.VGG16_Weights.DEFAULT)
vgg16.eval()
with torch.inference_mode():
    torch_out = vgg16(input_data)

ai3.swap_conv2d(
    vgg16, conv2d_selector, None, SpecialConv)
sb_out = vgg16(input_data)
assert torch.allclose(
    torch_out, sb_out, atol=1e-4)
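The example above picks an algorithm per layer through conv2d_selector. A fixed algorithm name should also work for the algos argument, assuming AlgorithmicSelector accepts a plain string as well as a callable; the call below is a hypothetical variant, not part of the commit:

# Hypothetical variant: wrap every Conv2d in SpecialConv with a fixed
# algorithm rather than choosing per layer via conv2d_selector.
ai3.swap_conv2d(vgg16, 'direct', swap_with=SpecialConv)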
16 changes: 10 additions & 6 deletions src/ai3/__init__.py
@@ -9,7 +9,7 @@
"""

import inspect
-from typing import Mapping, Optional, Sequence, Type, Union
+from typing import Mapping, Optional, Sequence, Type, Union, TypeVar
from . import _core, utils, layers, _version
from .tensor import Tensor

@@ -105,12 +105,12 @@ def predict(self, input, out_type=None):
        out = Tensor(out)
        return out.to(out_type)


def swap_operation(
        op_type: Union[Type, str],
        module,
        algos: Optional[AlgorithmicSelector] = None,
-        sample_input_shape: Optional[Sequence[int]] = None):
+        sample_input_shape: Optional[Sequence[int]] = None,
+        swap_with = None):
"""
Swaps operations in-place out of the existing *DNN* for an implementation of
the user specified algorithm. After swapping, the same *DNN* can still be trained
Expand All @@ -124,6 +124,9 @@ def swap_operation(
algorithmic selector for the *conv2d* operations
sample_input_shape : the input shape to the *DNN* which is passed to the
function algorithmic selector if present
swap_with : the new type which performs the operation, the original operator
and the algorithm are passed for initialization, if not present,
the operator provided by |name| will be used
Example:
Swaps the first *conv2d* operation for an implementation of direct convolution
@@ -148,7 +151,7 @@ def swap_operation(
    utils.check_callable_params_with_shape(
        {op_str: algos}, sample_input_shape)
    swapper.swap_operation(op_type,
-                          module, algos, sample_input_shape)
+                          module, algos, sample_input_shape, swap_with)


def swap_backend(module,
@@ -216,12 +219,13 @@ def swap_backend(module,

def swap_conv2d(module,
                algos: Optional['AlgorithmicSelector'] = None,
-                sample_input_shape: Optional[Sequence[int]] = None):
+                sample_input_shape: Optional[Sequence[int]] = None,
+                swap_with = None):
"""
Calls
>>> swap_operation('conv2d', module, algos, sample_input_shape) # doctest: +SKIP
"""
swap_operation('conv2d', module, algos, sample_input_shape)
swap_operation('conv2d', module, algos, sample_input_shape, swap_with)


def using_mps_and_metal() -> bool:
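When swap_with is omitted, the operator provided by the package is swapped in instead, as the docstring added above states. A minimal sketch of that default path (the model and algorithm name are illustrative):

# Default path: no swap_with, so the package's own Conv2D
# (defined in swap_torch.py below) replaces each selected torch.nn.Conv2d.
import ai3
import torchvision

model = torchvision.models.vgg16().eval()
ai3.swap_conv2d(model, 'direct')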
22 changes: 11 additions & 11 deletions src/ai3/swap_torch.py
@@ -16,9 +16,8 @@


class Conv2D(nn.Module):
-    def __init__(self, orig: nn.Conv2d, algorithm: str, target: str):
+    def __init__(self, orig: nn.Conv2d, algorithm: str):
        super(Conv2D, self).__init__()
-        self.target = target
        self.algorithm = algorithm

        self.stride = utils.make_2d(orig.stride)
@@ -353,13 +352,14 @@ def swapped_type(op_type) -> Optional[Type]:


def swap_operation(
-        op_type: Type, module: nn.Module, selector: utils.AlgorithmicSelector,
-        sample_input_shape: Optional[Sequence[int]] = None):
-    swapped_op_type = swapped_type(op_type)
+        orig_op_type: Type, module: nn.Module, selector: utils.AlgorithmicSelector,
+        sample_input_shape: Optional[Sequence[int]], swap_with):
+    if not swap_with:
+        swap_with = swapped_type(orig_op_type)
    errors.bail_if(
-        swapped_op_type is None,
-        f'cannot perform inplace algorithmic selection for {op_type}')
-    assert swapped_op_type is not None
+        swap_with is None,
+        f'cannot perform inplace algorithmic selection for {orig_op_type}')
+    assert swap_with is not None
    graph, with_shapes = trace_module(
        module, sample_input_shape)

@@ -371,16 +371,16 @@ def swap_operation(
            node_input_shape = node.meta['tensor_meta'].shape
        if node.op == 'call_module':
            mod = getmodule(module, node.target)
-            if isinstance(mod, (op_type, swapped_op_type)):
+            if isinstance(mod, (orig_op_type, swap_with)):
                algo = get_algo_inc_counter(
                    mod, selector,
                    layer_counters, node_input_shape)
                if algo == 'torch':
                    continue
-                if isinstance(mod, op_type):
+                if isinstance(mod, orig_op_type):
                    module = setmodule(
                        module, node.target,
-                        swapped_op_type(mod, algo, str(node.target)))
+                        swap_with(mod, algo))
                else:
                    setattr(mod, 'algorithm', algo)

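trace_module, getmodule, and setmodule above are this repository's own helpers. For orientation, the traversal they support can be sketched with plain torch.fx; replace_modules below is a hypothetical stand-in for swap_operation, not this package's API:

import torch
import torch.fx


def replace_modules(model, orig_type, new_type, algorithm='direct'):
    # Trace the model so each submodule call appears as a
    # 'call_module' node, mirroring the loop in swap_operation.
    graph = torch.fx.symbolic_trace(model).graph
    for node in graph.nodes:
        if node.op != 'call_module':
            continue
        # node.target is the dotted attribute path, e.g. 'features.0'.
        submodule = model.get_submodule(str(node.target))
        if isinstance(submodule, orig_type):
            parent_path, _, name = str(node.target).rpartition('.')
            # get_submodule('') returns the model itself for top-level names.
            parent = model.get_submodule(parent_path)
            # Swap in place, as setmodule does in the diff above.
            setattr(parent, name, new_type(submodule, algorithm))
    return model

# e.g. replace_modules(vgg16, torch.nn.Conv2d, SpecialConv)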
