Add docs on implementing Pytorch Ops (and CumOp) #837
Conversation
dim = op.axis
mode = op.mode

def cumop(x, dim=dim, mode=mode):
This is not needed; the returned functions are never called by the user.
- def cumop(x, dim=dim, mode=mode):
+ def cumop(x):
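For context, the closure pattern under review looks roughly like this (a minimal sketch; the import paths are assumed from this PR's layout, not confirmed by the diff):

from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.extra_ops import CumOp

@pytorch_funcify.register(CumOp)
def pytorch_funcify_CumOp(op, **kwargs):
    # Static Op attributes are read once, at funcify time...
    dim = op.axis
    mode = op.mode

    def cumop(x):
        # ...and captured by this closure, which is only ever called with
        # the node's runtime inputs, so it needs no keyword defaults.
        ...

    return cumop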
tests/link/pytorch/test_extra_ops.py
Outdated
# Create test value tag for a
a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
No need for test values and tags. We're planning to deprecate that functionality as well
tests/link/pytorch/test_extra_ops.py
Outdated
# For the second mode of CumOp
out = pt.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out])
compare_pytorch_and_py(fgraph, [get_test_value(i) for i in fgraph.inputs])
Here just pass the test values (instead of adding them as tags and then retrieving them)
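A sketch of what that would look like (the `compare_pytorch_and_py` helper and its import path are assumed from the test file under review):

import numpy as np
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from tests.link.pytorch.test_basic import compare_pytorch_and_py

a = pt.matrix("a")
out = pt.cumprod(a, axis=1)
fgraph = FunctionGraph([a], [out])
# Pass the test value directly instead of tagging `a` and retrieving it later
test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
compare_pytorch_and_py(fgraph, [test_value])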
tests/link/pytorch/test_extra_ops.py
Outdated
a.tag.test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))

# Create the output variable
out = pt.cumsum(a, axis=0)
Test `axis=None` and `axis=tuple(...)` if supported by the original Op. If a tuple is allowed, make sure you have more dimensions (say 3) and only ask for a subset (say 2) of them in the axis; this ensures you test something different from `axis=None` or `axis=int`.
The axis can be parametrized (prod and add as well) instead of adding more conditions inside the test.
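A parametrized version might look like the following sketch (assuming, per the discussion below, that only `axis=int` and `axis=None` are supported; the test name is illustrative):

import numpy as np
import pytest
import pytensor.tensor as pt
from pytensor.configdefaults import config
from pytensor.graph.fg import FunctionGraph
from tests.link.pytorch.test_basic import compare_pytorch_and_py

@pytest.mark.parametrize("func", [pt.cumsum, pt.cumprod])
@pytest.mark.parametrize("axis", [0, 1, None])
def test_pytorch_CumOp(func, axis):
    a = pt.matrix("a")
    out = func(a, axis=axis)
    fgraph = FunctionGraph([a], [out])
    test_value = np.arange(9, dtype=config.floatX).reshape((3, 3))
    compare_pytorch_and_py(fgraph, [test_value])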
Tried this on the original Op. `axis=tuple(...)` does not work and gives a `TypeError`; `axis=None` gives the output as a 1-D array.
The Op `__init__` doesn't seem to check explicitly for axes, but it does assume it is either None or an int. Can we add a check and raise an explicit `ValueError` if it's not either?
Checked again: there is no error if we use `axis=(0)`, and PyTorch also returns the same output. The error only comes when there is more than one element in the tuple (even `np.cumsum` gives a `TypeError` in this case).
We could try adding a check and raise, but would that be needed in other Op implementations? Since this would be used as an example, it might be complicated if a check and raise is not needed for other implementations.
`(0)` is `0`, not a tuple with a 0 inside it; it would have to be `(0,)` to be a tuple with a single element inside. Does it work with `(0,)`?
No, it gives a `TypeError`.
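For reference, the NumPy behaviour described above (an illustrative snippet):

import numpy as np

x = np.arange(9).reshape(3, 3)

# axis=None flattens the input first: a 1-D array of 9 running sums
print(np.cumsum(x, axis=None))

# Even a single-element tuple is rejected as an axis
try:
    np.cumsum(x, axis=(0,))
except TypeError as e:
    print(f"TypeError: {e}")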
Which is fine, but it probably gives a `TypeError` in an obscure place. We should raise already in the `__init__` method of the Op to save people time.
Can you extend the example in the documentation page on implementing custom JAX/Numba Ops to mention PyTorch and include this example as well? Perhaps you can use some fancy tab to select among the different modes in the same documentation page. Is that supported, @OriolAbril?
Not here as of now; you'd have to add an extra extension for tabs. If you only want tabs, then it is probably best to use https://sphinx-tabs.readthedocs.io/en/latest/; if using things like grids, dropdowns, icons... somewhere else in addition to tabs seems a future possibility, then https://sphinx-design.readthedocs.io/en/sbt-theme/ is probably best. Both should only require being added as dependencies to the doc env and adding them to the extensions list in conf.py.
Thanks @OriolAbril, either of those seems perfect. Any preference?
Codecov Report
Attention: Patch coverage is

Additional details and impacted files

@@            Coverage Diff             @@
##             main     #837      +/-   ##
==========================================
+ Coverage   80.87%   80.97%   +0.10%
==========================================
  Files         168      170       +2
  Lines       46950    47044      +94
  Branches    11472    11504      +32
==========================================
+ Hits        37972    38096     +124
+ Misses       6766     6734      -32
- Partials     2212     2214       +2
dim = op.axis
mode = op.mode

def cumop(x, dim=dim):
Looks good, just no need for any kwargs; the function will only ever receive the node inputs.
- def cumop(x, dim=dim):
+ def cumop(x):
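Putting the thread together, the finished dispatch might look roughly like this (a sketch; the `axis=None` handling and exact structure are assumptions, not necessarily the merged code):

import torch
from pytensor.link.pytorch.dispatch import pytorch_funcify
from pytensor.tensor.extra_ops import CumOp

@pytorch_funcify.register(CumOp)
def pytorch_funcify_CumOp(op, **kwargs):
    dim = op.axis
    mode = op.mode

    def cumop(x):
        # axis=None means "operate on the flattened input", matching
        # np.cumsum / np.cumprod
        if dim is None:
            x_, d = x.reshape(-1), 0
        else:
            x_, d = x, dim
        if mode == "add":
            return torch.cumsum(x_, dim=d)
        return torch.cumprod(x_, dim=d)

    return cumop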
I use sphinx-design more because I use its other features.
tests/link/pytorch/test_extra_ops.py
Outdated
# Create a symbolic input for the first input of `CumOp`
a = pt.matrix("a")

# Create test value tag for a
- # Create test value tag for a
+ # Create test value
pytensor/tensor/extra_ops.py
Outdated
@@ -283,8 +283,11 @@ class CumOp(COp):
    def __init__(self, axis: int | None = None, mode="add"):
        if mode not in ("add", "mul"):
            raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
        self.axis = axis
        self.mode = mode
        if isinstance(axis, int) or axis is None:
Nitpick: it's more common to just check and raise than to indent the "correct code" and raise otherwise.
- if isinstance(axis, int) or axis is None:
+ if not (isinstance(axis, int) or axis is None):
+     # raise error
+ # usual code
That's how the error check above for the mode is structured as well
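Concretely, the check-and-raise style would read something like this sketch (the exception type and message are assumptions, as is the `COp` import path):

from pytensor.link.c.op import COp

class CumOp(COp):
    def __init__(self, axis: int | None = None, mode="add"):
        if mode not in ("add", "mul"):
            raise ValueError(f'{type(self).__name__}: Unknown mode "{mode}"')
        if not (isinstance(axis, int) or axis is None):
            # Raise here rather than letting an obscure TypeError surface
            # later during compilation
            raise TypeError("axis must be an integer or None.")
        self.axis = axis
        self.mode = mode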
return res if n_outs > 1 else res[0]

.. tab-set::

   .. tab-item:: JAX/Numba
This is not correct for Numba; can you leave it as a separate tab with [in progress] text (and open an issue), or check the source code of the Numba implementation if you want to do it correctly?
This probably applies to all the tabbed sections: there's no reason to combine JAX and Numba, and the pre-existing snippets were JAX-specific.
I was adding a separate tab for Numba and found this comment. Is there anything that should be changed in `numba_funcify_DimShuffle`?
pytensor/pytensor/link/numba/dispatch/elemwise.py, lines 688 to 690 in 7159215:

# FIXME: Numba's `array.reshape` only accepts C arrays.
res_reshape = np.reshape(np.ascontiguousarray(x), new_shape)
Not in the context of this PR, but we should open an issue here to check if that's still a problem in the newer versions of numba. Could you do that?
All tabs look rendered correctly; I only left a comment so that cross-references to other libraries actually work.
function that performs exactly the same computations as the :class:`Op`. For
example, the :class:`Eye` operator has a JAX equivalent: :func:`jax.numpy.eye`
- (see `the documentation <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye>`_).
+ (see `the documentation <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye>`_) and a Pytorch equivalent :func:`torch.eye` (see `documentation <https://pytorch.org/docs/stable/generated/torch.eye.html>`_).
This looks like this: [screenshot of the rendered docs omitted]. That is quite the weird pattern for docs, especially given `jax.numpy.eye` and `torch.eye` are already using the correct cross-referencing syntax. I would remove the manual links and use the cross-references. That is, leaving only this:
- (see `the documentation <https://jax.readthedocs.io/en/latest/_autosummary/jax.numpy.eye.html?highlight=eye>`_) and a Pytorch equivalent :func:`torch.eye` (see `documentation <https://pytorch.org/docs/stable/generated/torch.eye.html>`_).
+ and a Pytorch equivalent :func:`torch.eye`.
Then make two more changes to conf.py. First, add `sphinx.ext.intersphinx` to the list of extensions; it is part of the main Sphinx library, so there is no need to add an extra dependency to the env file. Then add
intersphinx_mapping = {
"jax": ("https://jax.readthedocs.io/en/latest", None),
"numpy": ("https://numpy.org/doc/stable", None),
"torch": ("https://pytorch.org/docs/stable", None),
}
With that, `jax.numpy.eye` and `torch.eye` will still be formatted as monospaced text but will no longer be pink; they'll be blue, clickable links to their respective API pages.
Description
This PR can be used as an example for implementing Ops in PyTorch.

Related Issue
Checklist
Type of change
cc @ricardoV94