Skip to content

Commit

Permalink
enforce no register overwrite
Browse files Browse the repository at this point in the history
  • Loading branch information
pmeier committed Jul 27, 2023
1 parent bbaa35c commit ca4ad32
Showing 1 changed file with 12 additions and 2 deletions.
14 changes: 12 additions & 2 deletions torchvision/transforms/v2/functional/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,11 @@ def wrapper(inpt, *args, **kwargs):

return wrapper

registry = _KERNEL_REGISTRY.get(dispatcher, {})
registry = _KERNEL_REGISTRY.setdefault(dispatcher, {})
if datapoint_cls in registry:
raise TypeError(
f"Dispatcher '{dispatcher.__name__}' already has a kernel registered for type '{datapoint_cls.__name__}'."
)

def decorator(kernel):
registry[datapoint_cls] = datapoint_wrapper(kernel) if datapoint_wrapping else kernel
Expand All @@ -34,9 +38,15 @@ def _noop(inpt, *args, **kwargs):


def _get_kernel(dispatcher, datapoint_cls):
registry = _KERNEL_REGISTRY.get(dispatcher, {})
registry = _KERNEL_REGISTRY.get(dispatcher)
if not registry:
raise ValueError(f"No kernel registered for dispatcher '{dispatcher.__name__}'.")

if datapoint_cls in registry:
return registry[datapoint_cls]

for registered_cls, kernel in registry.items():
if issubclass(datapoint_cls, registered_cls):
return kernel

return _noop

0 comments on commit ca4ad32

Please sign in to comment.