|            |                                                |
| ---------- | ---------------------------------------------- |
| Authors    | Hameer Abbasi, Edward Z. Yang and Ralf Gommers |
| Status     | Accepted                                       |
| Type       | Proposal                                       |
| Created    | 2020-01-24                                     |
| Resolution | TBD                                            |
This RFC describes changes necessary to allow `__torch_function__` to be used
by methods of `torch.Tensor`, in an attempt to make subclassing more accessible
to the users of the class. This entails making an API for subclass views
public, and a change in the signature of `__torch_function__`.
Quoting [1], [2] and [3], the goals of this proposal are:

1. Support subclassing `torch.Tensor` in Python.
2. Preserve `torch.Tensor` subclasses when calling `torch` functions on them.
3. Use the PyTorch API with `torch.Tensor`-like objects that are not `torch.Tensor` subclasses.
4. Preserve `torch.Tensor` subclasses when calling `torch.Tensor` methods.
5. Propagate subclass instances correctly also with operators, using views/slices/indexing/etc.
6. Preserve subclass attributes when using methods or views/slices/indexing.
7. A way to insert code that operates on both functions and methods uniformly (so we can write a single function that overrides all operators).
8. The ability to give external libraries a way to also define functions/methods that follow the `__torch_function__` protocol.
Goals 1‒6 are explicitly about subclassing, goal 7 is already partially achieved via the `__torch_function__` protocol (which we're proposing to extend to methods), and goal 8 is a by-product required to make overridden `torch.Tensor` subclass methods behave similarly to `torch.Tensor` methods.
Achieving interoperability with NumPy and adopting its array protocols is out of scope for this proposal and we propose to defer it to a later proposal.
We propose to solve this problem with the following changes to PyTorch:
1. Make methods, operators and properties of `torch.Tensor` go through the `__torch_function__` machinery.
2. Add a `types` argument to `__torch_function__`, to make it match NumPy's `__array_function__`.
3. Add a new method to `torch.Tensor`, `as_subclass`, which creates a subtype view into the original object.
4. Make `torch.Tensor` gain a generic implementation of `__torch_function__`.
Once this proposal is merged, users of subclasses of `torch.Tensor` will have
a much more streamlined experience. Namely, the following code example will
work as-is, without the need for any further modification:
```python
class SubTensor(torch.Tensor):
    a = 1

t = SubTensor([1])

s = t.sum()
isinstance(s, SubTensor)  # True
s.a  # 1

i = t[0]
isinstance(i, SubTensor)  # True
i.a  # 1

s2 = t + torch.Tensor(1)
isinstance(s2, SubTensor)  # True
s2.a  # 1

s3 = torch.Tensor(1) + t
isinstance(s3, SubTensor)  # True
s3.a  # 1
```
Additionally, it will give subclass authors the ability to modify the
results of methods, operators and properties in `__torch_function__`, along with
regular function calls: to adapt the result to their specific use-case,
perform logging, or otherwise change the result or the action of the method.
For example:
```python
import logging

class LoggingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        logging.info(f"func: {func.__name__}, args: {args!r}, kwargs: {kwargs!r}")
        return super().__torch_function__(func, types, args, kwargs)
```
Assuming the minimum logging level is set to `logging.INFO`, the following
shows the code being run, with the logging output in the comments.
```python
t = LoggingTensor([1])
t.sum()       # Tensor.sum, (LoggingTensor([1]),), {}
t[0]          # Tensor.__getitem__, (LoggingTensor([1]), 0,), {}
# This is already possible
torch.sum(t)  # sum, (LoggingTensor([1]),), {}
```
To make the protocol operate only on functions rather than methods, one can
check for `func not in type(self).__dict__.values()`. To check for operators
and/or indexing, one can check `func.__name__.endswith("__")`.
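For example, here is a minimal sketch of such a filter (illustrative only; since `__torch_function__` is a classmethod there is no `self`, so we test membership against `torch.Tensor`'s own attributes, which may need adjusting for methods defined on base classes):

```python
import torch

class FunctionsOnlyTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        is_method = func in vars(torch.Tensor).values()
        is_dunder = func.__name__.endswith("__")
        if not is_method and not is_dunder:
            # Only plain torch-level functions reach this branch.
            print(f"function call: {func.__name__}")
        return super().__torch_function__(func, types, args, kwargs)
```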
There are a few requirements for the performance of this proposal, when implemented:

1. No deterioration for function/method calls on `torch.Tensor` objects.
2. No deterioration of the current `__torch_function__` overhead.
3. Sub-µs impact on the performance of subclasses not implementing `__torch_function__`.
Requirement 1 seems unachievable due to the structure of the code at this point, as:

1. In methods defined in C++, `self` is excluded from the argument processing that gathers `Tensor`-likes in C++.
2. Similar to point 1, C++ methods that take only `self` as a `Tensor`-like don't pass through this processing, and they will be required to.
3. For methods defined in Python, the processing for handling `__torch_function__` will need to be added, similar to the original `__torch_function__` PR [5].
We think an overhead of sub-100 ns per method call is feasible.
PyTorch `master` pointed to commit hash
`957a07ffbd13d8a805f4d718e0282efc5d2bff85` at the time of writing. Any classes
implementing `__torch_function__` based on the usage in this commit hash will
break completely, due to the differing signature of the protocol. However, as a
release hasn't been made with `__torch_function__` in it, this is a minor-impact
issue. This change brings the design of `__torch_function__` more in line with
NumPy's `__array_function__`, and anyone familiar with NumPy's protocol could
transition to PyTorch's take on it without too many surprises, with the caveat
that it can also receive methods rather than functions. The first release that
`__torch_function__` will make it into is expected to be PyTorch 1.5.0.
The implementation of this proposal will have no effect on how things interact with NumPy.
Subclasses are an important way to override functionality of classes. Given the
popularity of PyTorch, a number of subclasses have sprung up, both within and
outside PyTorch. It is important that functions operating on `torch.Tensor`, as
well as methods on it, support passing through the appropriate subclasses,
otherwise information about which type was passed into the function is lost.
The same applies equally, if not more so, to operators and indexing.
In addition, there has been interest in adding a "universal hook" that operates
on both functions and methods, perhaps modifying the control flow before
returning the result. Such a hook already exists today in the form of
`__torch_function__`; however, it only operates on functions and not on
methods, and support for subclassed `torch.Tensor` objects in this protocol is
limited.
We propose the following signature change to `__torch_function__`, to make it
match NumPy's protocol, apart from the `@classmethod` decorator: [4]
```python
class SubTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        # Implementation here
        ...
```
Adding `types` to the signature is necessary so we can check that the types
of the incoming `Tensor`-likes are supported, and so that we do not mix
unrelated class trees.
The process followed during a function/method call would be equivalent to:

1. The dispatcher is called to extract the `Tensor`-likes.
2. All `Tensor`-likes are checked for `__torch_function__`. If none exist, the internal implementation is called, and the final result is returned.
3. A collection of types that implement `__torch_function__` is created, with no guaranteed order other than that subclasses come before superclasses.
4. For one instance of each type in `types`, `__torch_function__` is called. The first such function or method to return something other than `NotImplemented` provides the final result. All exceptions are propagated upward.
5. If all `__torch_function__` implementations return `NotImplemented`, a `TypeError` is raised with an appropriate error message.
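For illustration, the steps above could be sketched in pure Python as follows (the name `dispatch` and the `func._implementation` attribute are illustrative; the real machinery lives largely in C++ and contains the fast paths described below):

```python
import torch

def dispatch(func, relevant_args, args, kwargs):
    # Step 1 happened already: `relevant_args` holds the Tensor-likes.
    # Steps 2-3: collect implementing types, subclasses before superclasses.
    overloaded = []
    for arg in relevant_args:
        t = type(arg)
        if t in overloaded or not hasattr(t, "__torch_function__"):
            continue
        index = len(overloaded)
        for i, other in enumerate(overloaded):
            if issubclass(t, other):
                index = i  # place the subclass before its superclass
                break
        overloaded.insert(index, t)
    if not overloaded or overloaded == [torch.Tensor]:
        # Fast path: plain tensors go straight to the internal implementation.
        return func._implementation(*args, **kwargs)
    # Step 4: try each type in order until one handles the call.
    types = tuple(overloaded)
    for t in types:
        result = t.__torch_function__(func, types, args, kwargs)
        if result is not NotImplemented:
            return result
    # Step 5: nobody handled it.
    raise TypeError(f"no implementation of {func.__name__} for {types}")
```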
In practice, for most PyTorch functions the list of tensor-likes is already
available and the dispatcher doesn't need to be called. Additionally, while
equivalent to the process above, if the `Tensor`-likes are all `Tensor` or don't
have a `__torch_function__` implementation, the internal implementation is
called immediately. This is done as a performance optimisation to avoid
overhead for concrete `Tensor` objects.
It will be the job of the dispatcher to extract `Tensor`-like objects from the
argument list; arguments of type `Optional[Tensor]` will also be considered
`Tensor`-like. If the dispatcher gets a compound or dependent type such as
`List[Tensor]`, `Tuple[Tensor, ...]` or `Tuple[Tensor, int]`, it will have the
job of extracting an iterable of objects that could be `Tensor`-like.
`torch.Tensor` will gain a generic `__torch_function__` of the following form:

```python
class Tensor:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        if not all(issubclass(cls, t) for t in types):
            return NotImplemented
        # Defer to the internal implementation
        ret = func._implementation(*args, **kwargs)
        if cls is not Tensor and isinstance(ret, Tensor):
            ret = ret.as_subclass(cls)
        return ret
```
This method has the effect of passing subclasses through all functions/methods as intended.
This corresponds roughly to the implementation `numpy.ndarray` gains in [4],
except that there subclasses are passed through via another internal
mechanism (namely the `__array_finalize__` protocol), and that we are checking
subclassing against `cls` instead of `Tensor`. This has the side-effect of
ensuring unrelated class trees are not merged, which is an inconsistency in
NumPy's own design. Specifically, consider the example of two direct subclasses
of `torch.Tensor`. Both will return `NotImplemented`, and therefore the check
will fail and a `TypeError` will be raised.
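Concretely, once the proposal is implemented, mixing two unrelated subclasses would behave as follows (a sketch):

```python
import torch

class A(torch.Tensor):
    pass

class B(torch.Tensor):
    pass

# Neither A nor B is a subclass of the other, so the generic
# __torch_function__ of each returns NotImplemented:
A([1.0]) + B([1.0])  # raises TypeError
```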
Since subclasses are checked before superclasses in `__torch_function__`, it is
guaranteed that the subclass implementation will be called first. In this
instance, since `cls` is a subclass of all types, the code will continue. Since
`cls` is not `torch.Tensor`, a view into the original data is created and
returned.
This also works for all operators: `__add__`, `__getitem__` and so on, since in
Python these operators are just dunder methods of the corresponding class.
One can check for compatibility with supported classes in the following manner:

```python
class MyTensor:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        if not all(issubclass(t, HANDLED_CLASSES) for t in types):
            return NotImplemented
        # Do further processing here.
        ...

# Defined after the class body so that MyTensor can refer to itself.
HANDLED_CLASSES = (MyTensor, torch.Tensor)
```
One can implement a subset of the API by using a dictionary that maps torch functions to one's own implementations:
```python
import torch

_TORCH_IMPLEMENTATIONS = {}

def implements(torch_function):
    def inner(f):
        _TORCH_IMPLEMENTATIONS[torch_function] = f
        return f
    return inner

@implements(torch.add)
def my_add(self, other):
    # Implementation here
    ...

class MyTensor:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        compatible = ...  # check `types` for compatibility here
        if not compatible:
            return NotImplemented
        if func not in _TORCH_IMPLEMENTATIONS:
            return NotImplemented
        return _TORCH_IMPLEMENTATIONS[func](*args, **kwargs)
```
To invoke the superclass implementation via `super`, one would do the following:
```python
class SubTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        # Pre-processing here
        val = super().__torch_function__(func, types, args, kwargs)
        # Post-processing here
        return val
```
To make the need for `super()` concrete, let's consider the following
scenario:
```python
class SubTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        # Pre-processing
        ret = super().__torch_function__(func, types, args, kwargs)
        # Post-processing
        return ret

class SubSubTensor(SubTensor):
    def __add__(self, other):
        # Pre-processing
        ret = super().__add__(other)
        # Post-processing
        return ret
```
In this instance, processing would follow the `__torch_function__` protocol.
This means that control would end up in `SubSubTensor.__add__`, go to
`Tensor.__add__`, then to `SubTensor.__torch_function__` from there, and then
come to `Tensor.__torch_function__`, from where it would go to the internal
implementation of `Tensor.__add__`, and then back up the stack in the reverse
order. This means that great care needs to be taken when writing
`SubTensor.__torch_function__` to take into account the fact that it has to
handle subclass methods.
In general, control flow will follow this pattern: subclass method → superclass method → subclass `__torch_function__` → superclass `__torch_function__` → internal implementation, and then back up the stack in reverse order.
The reason we use `super().__torch_function__` instead of calling `func` directly is twofold:

1. We do not know if there are other `Tensor`-likes that may need to be handled.
2. Calling `func` directly would dispatch back to `__torch_function__`, leading to an infinite recursion.
We will also recommend that all `Tensor` subclasses make their own methods that
do not exist on `torch.Tensor` go through `__torch_function__` via a decorator
`@torch_function_dispatch`. This decorator was added and then removed for
performance reasons; however, it will be added back to allow external libraries
to interface with the protocol. It will take a single argument: a dispatcher,
i.e. a callable that returns an iterable of all the "duck-`Tensor`s", or
possible candidates for classes that may implement `__torch_function__`. A
sketch of its use follows below.
If a library forgets to add the aforementioned decorator, then the method will
not dispatch at all to any form of `__torch_function__`. In other words, it
will lose support for the protocol. This can lead to confusion, as some
methods of the subclass will pass through `__torch_function__` (the ones
inherited from `torch.Tensor`), and some won't.
Note that subclasses will still be passed through due to the default
implementation of `__torch_function__`, but any `__torch_function__` defined on
the class itself (or any of its subclasses) won't have an effect on its
methods.
This is a design choice that a subclass author will have to make: whether they
prefer their own functions/methods to pass through `__torch_function__` like
PyTorch's implementations, or whether they would ultimately rather not support
the protocol and accept having a mix of overridable and non-overridable
methods.
We do not propose automatic marking of functions with this decorator due to the potential backwards-compatibility break it could cause, as well as the parameters that are needed in order to allow this to happen (namely the dispatcher, which isn't in our control).
To construct the function given its `__name__` and `__module__`, one can do
the following, as an example:

```python
import importlib

def get_function(name, module):
    # import_module returns the named module itself (unlike bare __import__,
    # which returns the top-level package for dotted module names).
    func = importlib.import_module(module)
    for n in name.split('.'):
        func = getattr(func, n)
    return func
```
The `torch.Tensor.as_subclass` method will be added, taking a single non-self
argument: `cls`, the class for which an instance will be created with a view
into the data of the original `Tensor`. It will become public API. This method
will create an object that has the same data pointer as the original object,
which means that modifications to it will be reflected in the original object.
More or less, it will have the same effect as modifying an object's `__class__`
attribute in Python.
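For example, assuming a `SubTensor` subclass of `torch.Tensor`:

```python
import torch

class SubTensor(torch.Tensor):
    pass

t = torch.tensor([1.0, 2.0])
s = t.as_subclass(SubTensor)
isinstance(s, SubTensor)  # True
s[0] = 5.0                # s shares storage with t
t[0]                      # tensor(5.)
```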
This method is already used in external libraries, and they may need it as a
way to e.g. bypass the processing of `torch.Tensor.__torch_function__`
entirely, while still creating `torch.Tensor` subclasses in their own code.
To implement this proposal requires three main steps:

1. Add a `types` argument to `__torch_function__` and make sure that only arguments that are instances of a type in `types` are processed.
2. Make sure that all `Tensor` methods except `__new__` and `__init__` go through `__torch_function__`.
3. Add `Tensor.as_subclass` and `@torch_function_dispatch` as public API.
One can use the dictionary idiom to implement only some methods and not others. A code example follows:
```python
import torch

HANDLED_FUNCTIONS = {}

def implements(func):
    def inner(implementation):
        HANDLED_FUNCTIONS[func] = implementation
        return implementation
    return inner

@implements(torch.add)
def my_add(self, other):
    ...

class TensorLike:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        implementation = HANDLED_FUNCTIONS.get(func, None)
        if implementation is None:
            return NotImplemented
        return implementation(*args, **kwargs)
```
For subclasses, one can also choose to use the fallback implementation if
a specialized implementation isn't available, using `super`, as shown below.
```python
class SubTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        implementation = HANDLED_FUNCTIONS.get(func, None)
        if implementation is None:
            return super().__torch_function__(func, types, args, kwargs)
        return implementation(*args, **kwargs)
```
A call to `super().__torch_function__` can also be used to call the fallback
implementation within any other function.
The examples we have seen here illustrate what we anticipate will be two
common patterns of using `__torch_function__`: `LoggingTensor` is an example
of a global hook, and the two examples above show a way to achieve specialised
implementations of particular functions.
Sometimes it's useful to wrap `torch.Tensor` rather than subclass it.
The following class shows how this is possible in practice:

```python
import functools
import torch

def wrap(f):
    @functools.wraps(f)
    def inner(*args, **kwargs):
        # Call `f` with all-unwrapped args
        args = tuple(a._wrapped if isinstance(a, WrappedTensor) else a for a in args)
        kwargs = {k: v._wrapped if isinstance(v, WrappedTensor) else v for k, v in kwargs.items()}
        ret = f(*args, **kwargs)
        # Possibly wrap back the result before returning
        if isinstance(ret, torch.Tensor):
            ret = WrappedTensor(ret)
        return ret
    return inner

class WrappedTensor:
    def __init__(self, towrap: torch.Tensor):
        self._wrapped = towrap

    def __getattr__(self, name):
        base = getattr(torch.Tensor, name)
        if not callable(base):
            # A property: call its getter on the wrapped tensor
            return wrap(base.__get__)(self)
        # A method: bind it to the wrapped tensor first
        return wrap(base.__get__(self._wrapped))

    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        return wrap(func)(*args, **kwargs)
```
One alternative that has been proposed is to automatically pass through
subclasses à la NumPy and provide a `__torch_finalize__` method that allows for
any post-processing of the result. While this would achieve most goals, it
would miss the one about providing a hook for methods and operators.
Both functions and methods/properties on `torch.Tensor` will be possible
arguments to `__torch_function__`. These are different in subtle but important
ways, and in some cases it is required to handle them differently. For
instance, `torch.Tensor` methods/properties have the following properties:

1. They can only accept `torch.Tensor` instances as the first argument.
2. They may or may not have a `__module__` defined.
Even classes implementing `__torch_function__` that aren't subclasses
can have methods passed in. It is required to treat this case with care.
Consider the following code:

```python
class TensorLike:
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        print(func.__name__)

torch.tensor([5]) + TensorLike()  # prints "add"
```
If, in this case, we were to use the default implementation of `func` and a
`torch.Tensor` instance were not passed in, an error would be raised. To handle
this case, we have provided a utility function,
`torch.overrides.is_tensor_method_or_property`, to determine whether something
is a `torch.Tensor` method or property.
For properties, their `__get__` method is passed in. For example, for
`torch.Tensor.grad`, `torch.Tensor.grad.__get__` is passed in as `func`.
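A minimal sketch of special-casing such a property access inside `__torch_function__` (the class name is illustrative):

```python
import torch

class GradWatchingTensor(torch.Tensor):
    @classmethod
    def __torch_function__(cls, func, types, args, kwargs):
        # Compare with == rather than `is`: each attribute access on the
        # descriptor produces a fresh bound-method object.
        if func == torch.Tensor.grad.__get__:
            # Property accesses such as `t.grad` land here.
            print("accessing .grad")
        return super().__torch_function__(func, types, args, kwargs)
```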