diff --git a/example/custom_operator.py b/example/custom_operator.py
new file mode 100644
index 0000000..92a38e1
--- /dev/null
+++ b/example/custom_operator.py
@@ -0,0 +1,37 @@
+import torch
+import torchvision
+import ai3
+
+
+# custom operator swapped in for the conv2d layers selected below
+class SpecialConv(torch.nn.Module):
+    def __init__(self, orig: torch.nn.Conv2d, algorithm: str):
+        super(SpecialConv, self).__init__()
+        self.orig = orig
+        self.algorithm = algorithm
+
+    def forward(self, x: torch.Tensor) -> torch.Tensor:
+        # a real operator would dispatch on self.algorithm; this one delegates
+        return self.orig(x)
+
+
+# per-layer algorithmic selector keyed on the layer's input channels
+def conv2d_selector(orig: torch.nn.Conv2d) -> str:
+    in_channels = orig.weight.shape[1]
+    if in_channels > 200:
+        return 'smm'
+    return 'direct'
+
+
+input_data = torch.randn(1, 3, 224, 224)
+vgg16 = torchvision.models.vgg16(
+    weights=torchvision.models.VGG16_Weights.DEFAULT)
+vgg16.eval()
+with torch.inference_mode():
+    torch_out = vgg16(input_data)
+
+    ai3.swap_conv2d(
+        vgg16, conv2d_selector, None, SpecialConv)
+    sb_out = vgg16(input_data)
+    assert torch.allclose(
+        torch_out, sb_out, atol=1e-4)
diff --git a/src/ai3/__init__.py b/src/ai3/__init__.py
index 7fa02e1..c447924 100644
--- a/src/ai3/__init__.py
+++ b/src/ai3/__init__.py
@@ -105,12 +105,12 @@ def predict(self, input, out_type=None):
         out = Tensor(out)
         return out.to(out_type)
 
-
 def swap_operation(
         op_type: Union[Type, str], module,
         algos: Optional[AlgorithmicSelector] = None,
-        sample_input_shape: Optional[Sequence[int]] = None):
+        sample_input_shape: Optional[Sequence[int]] = None,
+        swap_with: Optional[Type] = None):
     """
     Swaps operations in-place out of the existing *DNN* for an
     implementation of the user specified algorithm.
     After swapping, the same *DNN* can still be trained
@@ -124,6 +124,9 @@ def swap_operation(
         algorithmic selector for the *conv2d* operations
     sample_input_shape : the input shape to the *DNN* which is
         passed to the function algorithmic selector if present
+    swap_with : the type to swap the operation with; the original
+        operator and the selected algorithm are passed to its initializer;
+        if not given, the operator provided by |name| is used
 
     Example:
         Swaps the first *conv2d* operation for an implementation of direct convolution
@@ -148,7 +151,7 @@ def swap_operation(
     utils.check_callable_params_with_shape(
         {op_str: algos}, sample_input_shape)
     swapper.swap_operation(op_type,
-                           module, algos, sample_input_shape)
+                           module, algos, sample_input_shape, swap_with)
 
 
 def swap_backend(module,
@@ -216,12 +219,13 @@ def swap_backend(module,
 
 def swap_conv2d(module,
                 algos: Optional['AlgorithmicSelector'] = None,
-                sample_input_shape: Optional[Sequence[int]] = None):
+                sample_input_shape: Optional[Sequence[int]] = None,
+                swap_with: Optional[Type] = None):
     """
     Calls
-    >>> swap_operation('conv2d', module, algos, sample_input_shape)  # doctest: +SKIP
+    >>> swap_operation('conv2d', module, algos, sample_input_shape, swap_with)  # doctest: +SKIP
     """
-    swap_operation('conv2d', module, algos, sample_input_shape)
+    swap_operation('conv2d', module, algos, sample_input_shape, swap_with)
 
 
 def using_mps_and_metal() -> bool:
diff --git a/src/ai3/swap_torch.py b/src/ai3/swap_torch.py
index 0c851a9..984336d 100644
--- a/src/ai3/swap_torch.py
+++ b/src/ai3/swap_torch.py
@@ -16,9 +16,8 @@
 
 
 class Conv2D(nn.Module):
-    def __init__(self, orig: nn.Conv2d, algorithm: str, target: str):
+    def __init__(self, orig: nn.Conv2d, algorithm: str):
         super(Conv2D, self).__init__()
-        self.target = target
         self.algorithm = algorithm
 
         self.stride = utils.make_2d(orig.stride)
@@ -353,13 +352,14 @@ def swapped_type(op_type) -> Optional[Type]:
 
 
 def swap_operation(
-        op_type: Type, module: nn.Module, selector: utils.AlgorithmicSelector,
-        sample_input_shape: Optional[Sequence[int]] = None):
-    swapped_op_type = swapped_type(op_type)
+        orig_op_type: Type, module: nn.Module, selector: utils.AlgorithmicSelector,
+        sample_input_shape: Optional[Sequence[int]], swap_with: Optional[Type]):
+    if swap_with is None:
+        swap_with = swapped_type(orig_op_type)
     errors.bail_if(
-        swapped_op_type is None,
-        f'cannot perform inplace algorithmic selection for {op_type}')
-    assert swapped_op_type is not None
+        swap_with is None,
+        f'cannot perform inplace algorithmic selection for {orig_op_type}')
+    assert swap_with is not None
 
     graph, with_shapes = trace_module(
         module, sample_input_shape)
@@ -371,16 +371,16 @@ def swap_operation(
         node_input_shape = node.meta['tensor_meta'].shape
         if node.op == 'call_module':
             mod = getmodule(module, node.target)
-            if isinstance(mod, (op_type, swapped_op_type)):
+            if isinstance(mod, (orig_op_type, swap_with)):
                 algo = get_algo_inc_counter(
                     mod, selector, layer_counters,
                     node_input_shape)
                 if algo == 'torch':
                     continue
-                if isinstance(mod, op_type):
+                if isinstance(mod, orig_op_type):
                     module = setmodule(
                         module, node.target,
-                        swapped_op_type(mod, algo, str(node.target)))
+                        swap_with(mod, algo))
                 else:
                     setattr(mod, 'algorithm', algo)
 
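For reference, a swap_with class receives the original module and the algorithm chosen by the selector, so it can branch on that choice instead of always delegating the way SpecialConv above does. Below is a minimal sketch of such a class, assuming groups=1 convolutions and tuple padding; UnfoldConv and its unfold-plus-matmul path are illustrative assumptions, not the actual 'smm' kernel shipped with ai3.

import torch
import torch.nn.functional as F


class UnfoldConv(torch.nn.Module):
    # hypothetical swap_with target that acts on the selected algorithm
    def __init__(self, orig: torch.nn.Conv2d, algorithm: str):
        super(UnfoldConv, self).__init__()
        assert orig.groups == 1, 'sketch only handles groups=1'
        self.orig = orig
        self.algorithm = algorithm

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if self.algorithm != 'smm':
            # any other algorithm: keep the original PyTorch kernel
            return self.orig(x)
        # 'smm' here: im2col (unfold) followed by a single matmul
        w = self.orig.weight
        cols = F.unfold(x, w.shape[2:], dilation=self.orig.dilation,
                        padding=self.orig.padding, stride=self.orig.stride)
        out = w.flatten(1) @ cols
        if self.orig.bias is not None:
            out = out + self.orig.bias.view(1, -1, 1)
        # recover the output height from the convolution shape formula
        oh = ((x.shape[2] + 2 * self.orig.padding[0]
               - self.orig.dilation[0] * (w.shape[2] - 1) - 1)
              // self.orig.stride[0] + 1)
        return out.view(x.shape[0], w.shape[0], oh, -1)

Such a class drops in through the same call as the example above: ai3.swap_conv2d(vgg16, conv2d_selector, None, UnfoldConv).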