diff --git a/float8_experimental/float8_tensor.py b/float8_experimental/float8_tensor.py index 7287c11..27117d2 100644 --- a/float8_experimental/float8_tensor.py +++ b/float8_experimental/float8_tensor.py @@ -325,6 +325,8 @@ def allowed_subclasses(type): if func in FLOAT8_OPS_TABLE: return FLOAT8_OPS_TABLE[func](func, args, kwargs) + else: + return func.decompose(*args, *kwargs) raise NotImplementedError(f"attempting to run {func}, this is not supported") # Do not force the Float8Tensor type on the returned tensor