This repository has been archived by the owner on Aug 7, 2024. It is now read-only.

Try dispatching to the decomposed OpOverload to account for inference mode #251

Open · wants to merge 1 commit into base: main
float8_experimental/float8_tensor.py (2 changes: 2 additions & 0 deletions)

@@ -325,6 +325,8 @@ def allowed_subclasses(type):

     if func in FLOAT8_OPS_TABLE:
         return FLOAT8_OPS_TABLE[func](func, args, kwargs)
+    else:
+        return func.decompose(*args, **kwargs)

[Reviewer] I'm not sure this is what you want; it directly invokes the CompositeImplicitAutograd implementation.

[Reviewer] I think this is exactly what they want :D
What happens if there is no CompositeImplicitAutograd impl?

[Reviewer] It returns NotImplemented.
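
A quick illustration (an editor's aside, not from the thread; op registrations can vary across PyTorch versions, but on recent builds aten.matmul carries a CompositeImplicitAutograd decomposition while aten.add.Tensor is a primitive with only backend kernels):

    import torch

    a = torch.randn(2, 3)
    # matmul decomposes (to mm/bmm and friends), so decompose() runs it:
    print(torch.ops.aten.matmul.default.decompose(a, a.t()).shape)  # torch.Size([2, 2])
    # add.Tensor has no CompositeImplicitAutograd kernel, so we get the sentinel back:
    print(torch.ops.aten.add.Tensor.decompose(a, a) is NotImplemented)  # True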

[Contributor Author] Yeah, I figure we can no-op (except for NotImplemented) and slide our way down by trying to decompose the op (see the sketch after this list):

  • Either it won't be decomposable and will return NotImplemented; cool, same place as after this.
  • Or it is decomposable:
    • The decomposed ops are supported and run. Great!
    • Or they're not supported and we ultimately end back up with NotImplemented.
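
A sketch of that fallthrough with an explicit check (an editor's illustration, not code from the PR; as written, the diff returns func.decompose(...) directly, so a NotImplemented result is handled by the dispatcher rather than by the raise below):

    if func in FLOAT8_OPS_TABLE:
        return FLOAT8_OPS_TABLE[func](func, args, kwargs)
    out = func.decompose(*args, **kwargs)
    if out is not NotImplemented:
        # The decomposition re-enters __torch_dispatch__, so supported
        # sub-ops hit FLOAT8_OPS_TABLE and unsupported ones eventually
        # reach the raise on a later re-entry.
        return out
    raise NotImplementedError(f"attempting to run {func}, this is not supported")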

     raise NotImplementedError(f"attempting to run {func}, this is not supported")

# Do not force the Float8Tensor type on the returned tensor