diff --git a/torch2jax/__init__.py b/torch2jax/__init__.py
index 905a8ba..02ba559 100644
--- a/torch2jax/__init__.py
+++ b/torch2jax/__init__.py
@@ -136,11 +136,11 @@ def div_(self, other):
 coerce = lambda x: Torchish(x).value
 
 
-def implements(torch_function, JAXishify_output=True):
+def implements(torch_function, Torchishify_output=True):
     """Register a torch function override for Torchish"""
 
     def decorator(func):
-        func1 = (lambda *args, **kwargs: Torchish(func(*args, **kwargs))) if JAXishify_output else func
+        func1 = (lambda *args, **kwargs: Torchish(func(*args, **kwargs))) if Torchishify_output else func
         functools.update_wrapper(func1, torch_function)
         HANDLED_FUNCTIONS[torch_function] = func1
         return func1
@@ -571,7 +571,7 @@ def scaled_dot_product_attention(query, key, value, attn_mask=None, dropout_p=0.
 
 # NOTE: the "torch.Tensor" type annotations here are a lie, or at least an approximation: In reality, they can be
 # anything coerce-able.
-@implements(torch.nn.functional.multi_head_attention_forward, JAXishify_output=False)
+@implements(torch.nn.functional.multi_head_attention_forward, Torchishify_output=False)
 def multi_head_attention_forward(
     query: torch.Tensor,
     key: torch.Tensor,
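For context, here is a minimal, self-contained sketch of the decorator pattern this rename touches. Only `implements`, `Torchish`, `HANDLED_FUNCTIONS`, `coerce`, and the renamed `Torchishify_output` flag come from the diff; the toy `Torchish` stub and the `torch.add` registration below are illustrative assumptions, not torch2jax's actual definitions.

# Sketch mirroring the hunk above; stub bodies and the torch.add example are
# illustrative assumptions, not the library's own code.
import functools

import jax.numpy as jnp
import torch

HANDLED_FUNCTIONS = {}


class Torchish:
    """Toy stand-in: wraps a JAX array so it can impersonate a torch.Tensor."""

    def __init__(self, value):
        self.value = value.value if isinstance(value, Torchish) else jnp.asarray(value)


coerce = lambda x: Torchish(x).value


def implements(torch_function, Torchishify_output=True):
    """Register a torch function override for Torchish"""

    def decorator(func):
        # Torchishify_output=True (the default): the override may return a raw JAX
        # value, and it is re-wrapped into a Torchish here.
        # Torchishify_output=False (as for multi_head_attention_forward above):
        # the override is responsible for returning Torchish outputs itself.
        func1 = (lambda *args, **kwargs: Torchish(func(*args, **kwargs))) if Torchishify_output else func
        functools.update_wrapper(func1, torch_function)
        HANDLED_FUNCTIONS[torch_function] = func1
        return func1

    return decorator  # not shown in the hunk, but implied by the @implements(...) usage


@implements(torch.add)
def add(input, other):
    return coerce(input) + coerce(other)


# The registered override consumes and produces Torchish values.
out = HANDLED_FUNCTIONS[torch.add](Torchish(jnp.ones(3)), Torchish(2.0 * jnp.ones(3)))
print(out.value)  # [3. 3. 3.]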