-
The reason I didn't take this seriously 4 months ago was that the "holomorphic" requirement for the complex function seemed like a tight constraint that few functions would satisfy. What changed my mind is that regular autodiff blithely ignores not-actually-differentiable points such as ReLU'(0).

In particular, logarithms have a branch cut where they fail to be holomorphic. What does this method do there?

```python
>>> import numpy as np
>>> from eager_forward import *
>>> x = np.linspace(0, 100, 10000)
>>> da_x = diffarray(x)
>>> da_y = np.log(da_x)
>>> truth = 1/x
<stdin>:1: RuntimeWarning: divide by zero encountered in divide
>>> da_y.tangent
array([1.57079633e+04, 9.99866679e+01, 4.99945835e+01, ...,
       1.00020006e-02, 1.00010002e-02, 1.00000000e-02])
>>> truth
array([           inf, 9.99900000e+01, 4.99950000e+01, ...,
       1.00020006e-02, 1.00010002e-02, 1.00000000e-02])
>>> abs(da_y.tangent - truth)[1:].max()
0.0033321335475875458
>>> abs(da_y.tangent - truth)[1:].argmax()
0
```

Instead of reporting inf at x = 0, the method gives a large finite value. That value is explainable: on the branch cut, log(0 + ih) = log(h) + iπ/2, so the complex-step estimate Im(log(ih))/h is (π/2)/h, which is consistent with the 1.57079633e+04 above for a step size of h = 10⁻⁴ (an inference from the output, not checked against the source). Away from the branch point, the agreement is good.

How about sending the logarithm through the chain rule?

```python
>>> import numpy as np
>>> import matplotlib.pyplot as plt
>>> from eager_forward import *
>>> x = np.linspace(-10, 10, 10000)
>>> da_x = diffarray(x)
>>> da_y = np.log(np.sin(da_x)**2)
>>> truth = 2*np.sin(x)*np.cos(x) / np.sin(x)**2
>>> da_y.tangent
array([-3.08470206, -3.09826064, -3.11190351, ...,  3.11190351,
        3.09826064,  3.08470206])
>>> truth
array([-3.08470209, -3.09826068, -3.11190355, ...,  3.11190355,
        3.09826068,  3.08470209])
>>> plt.plot(x, da_y.tangent)
[<matplotlib.lines.Line2D object at 0x7d67681b0a50>]
>>> plt.plot(x, truth, ls="--")
[<matplotlib.lines.Line2D object at 0x7d67681b1650>]
>>> plt.show()
```
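For reference, the agreement above is what the complex-step identity predicts: Im(f(x + ih))/h recovers f'(x) for any composition of holomorphic pieces, chain rule included. A standalone check with plain NumPy (the step size h here is my own choice, not necessarily what eager_forward uses):

```python
import numpy as np

h = 1e-9  # assumed step size; no subtractive cancellation, so it can be tiny
x = np.linspace(-10, 10, 10000)
x = x[np.abs(np.sin(x)) > 1e-3]               # stay away from the log singularities
estimate = np.imag(np.log(np.sin(x + 1j*h)**2)) / h
truth = 2*np.sin(x)*np.cos(x) / np.sin(x)**2  # analytic chain-rule derivative
# estimate and truth agree to roughly 7 significant digits or better
```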
-
About ReLU'(0): @pfackeldey pointed me to this: https://dl.acm.org/doi/10.5555/3540261.3540297 But from my perspective, the ReLU'(0) case is exactly what changed my mind about considering this technique, which has known errors for functions that aren't perfectly smooth. If we can get away with saying that ReLU'(0) has a value, what else can we get away with?
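To make the ambiguity concrete, here's a sketch (plain NumPy, not eager_forward) of how a complex-step derivative silently assigns ReLU'(0) a value: whichever complex extension of ReLU you pick decides the answer at the kink. The `relu` below, which branches on the real part, is a hypothetical choice, not anything from the linked paper or from eager_forward.py.

```python
import numpy as np

def complex_step(f, x, h=1e-200):
    # f'(x) ≈ Im(f(x + ih)) / h; exact to machine precision for holomorphic f
    return np.imag(f(x + 1j * h)) / h

def relu(z):
    # one possible complex extension of ReLU: branch on the real part
    # (hypothetical; other extensions would behave differently at the kink)
    return np.where(np.real(z) > 0, z, 0.0 * z)

complex_step(relu, 1.0)   # 1.0, as expected away from the kink
complex_step(relu, -1.0)  # 0.0, likewise
complex_step(relu, 0.0)   # 0.0 -- this extension quietly picks ReLU'(0) = 0
```

A different extension (e.g. one that branches on `np.real(z) >= 0`) would just as quietly pick ReLU'(0) = 1, which is the crux of the "what else can we get away with?" question.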
-
As presented at https://indico.cern.ch/event/1387764/
and discussed in https://iris-hep.slack.com/archives/C0155BGPGE4/p1734127343384079
The idea is to replace Awkward Array's JAX backend with a custom autodiff implementation.
This is based on the observation that autodiff is not hard to implement, and all of the troubles we're going through to partially implement autodiff through JAX might be more easily solved by a custom autodiff with a friendlier API.
Here's a complete demonstrator of eager, forward-mode autodiff using the complex-step technique: awkward/studies/autodiff/eager_forward.py, lines 11 to 123 at commit c59a49c.
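As a rough illustration of what that file does, here is my own minimal stand-in, not the linked implementation: the `diffarray` name and `.tangent` attribute mirror the REPL sessions in this thread, the step size h = 10⁻⁴ is inferred from the tangent values shown there, and only NumPy ufuncs are supported.

```python
import numpy as np

H = 1e-4  # assumed step size, inferred from the (pi/2)/h value in the replies

class diffarray(np.lib.mixins.NDArrayOperatorsMixin):
    """Sketch of eager forward-mode autodiff via the complex-step trick."""

    def __init__(self, values, z=None):
        # seed the imaginary part with the tangent direction (all ones)
        self._z = np.asarray(values, float) + 1j * H if z is None else z

    @property
    def tangent(self):
        # the derivative rides along in the imaginary part
        return self._z.imag / H

    def __array_ufunc__(self, ufunc, method, *inputs, **kwargs):
        # run the ufunc on the complex carriers; holomorphic ufuncs propagate
        # the tangent automatically, no per-operation derivative rules needed
        args = [a._z if isinstance(a, diffarray) else a for a in inputs]
        return diffarray(None, z=getattr(ufunc, method)(*args, **kwargs))

da = diffarray(np.array([0.5, 1.0, 2.0]))
np.log(da).tangent  # ≈ 1/x = [2.0, 1.0, 0.5]
```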
However, we'll probably need backpropagation, which will require some sort of DAG to reverse the steps in the calculation, and maybe we'd use our typetracers or dask-awkward to make that DAG. (Our typetracers can be used to make a low-level DAG, with one node per buffer operation, whereas a dask-awkward DAG is high-level. Thinking about this now, I think it has to be a low-level DAG.)
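For concreteness, here is a toy of the kind of record a reverse pass needs: each operation appends its output, inputs, and local derivatives to a tape, and backpropagation walks the tape in reverse. This is a from-scratch illustration of the general idea, not a proposal for the actual typetracer/dask-awkward design; all names here are invented.

```python
import numpy as np

# Toy eager reverse-mode autodiff with an explicit tape.
TAPE = []  # records (output, [(input, local_gradient), ...]) in execution order

class Var:
    def __init__(self, value):
        self.value = value
        self.grad = 0.0

def record(value, parents):
    out = Var(value)
    TAPE.append((out, parents))
    return out

def vmul(a, b):
    # d(a*b)/da = b, d(a*b)/db = a
    return record(a.value * b.value, [(a, b.value), (b, a.value)])

def vlog(a):
    # d(log a)/da = 1/a
    return record(np.log(a.value), [(a, 1.0 / a.value)])

def backward(out):
    out.grad = 1.0
    for node, parents in reversed(TAPE):  # tape order is a valid topological order
        for parent, local in parents:
            parent.grad += local * node.grad

x = Var(2.0)
y = vlog(vmul(x, x))  # y = log(x*x); dy/dx = 2/x = 1.0
backward(y)
x.grad                # → 1.0
```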
I want to ask everyone (@matthewfeickert, @kratsg, @pfackeldey, @alexander-held, @gordonwatts) whether this is sufficient, and generally what your thoughts are on requirements. I'll try answering a few questions that came up in the talk in subsequent comments.