ReverseDiff gradients #93
Comments
I don't remember having tried reverse diff. I'll take a look at it asap. |
Here is the situation. Currently, you cannot easily reverse-differentiate through the whole construction of the cell lists, but you can bypass all that and differentiate only the computation of the objective function, if the coordinates are provided (redundantly) through a closure. For example, consider the following simple function, which sums the squared distances between particles:
sum_sqr(d2, s) = s += d2
It can be mapped over all pairs of particles (here the coordinates are a 3×1000 matrix):
coordinates = rand(3,1000)
box = Box([1,1,1], 0.05)
cl = CellList(coordinates,box)
map_pairwise( (_, _, _, _, d2, s) -> sum_sqr(d2, s), 0.0, box, cl)
This can be forward-differentiated as shown in the manual, but reverse differentiation does not work, basically because the construction of the cell lists requires mutation of arrays, and the current reverse-mode infrastructure does not support that (maybe Enzyme could do it, but there is a simpler alternative). The trick is to define a function that receives only the indices of the particles and computes the property of interest using the coordinates provided in a closure. That is, the above function would be defined as:
sum_sqr(i, j, s, coordinates) = s += sum(abs2, @views(coordinates[:,i] - coordinates[:,j]))
coordinates = rand(3,1000)
box = Box([1,1,1], 0.05)
cl = CellList(coordinates, box)
map_pairwise( (_, _, i, j, _, s) -> sum_sqr(i, j, s, coordinates), 0.0, box, cl)
Now we wrap the computation in a function that takes the coordinates as an argument:
julia> function sum_sqr(coordinates, box, cl)
sum_sqr = map_pairwise!(
(_, _, i, j, _, sum_sqr) -> sum_sqr += sum(abs2, @views(coordinates[:,i] - coordinates[:,j])),
zero(eltype(coordinates)), box, cl,
)
return sum_sqr
end
sum_sqr (generic function with 3 methods)
And this can be both forward- and reverse-differentiated:
julia> using ForwardDiff, ReverseDiff
julia> coordinates = rand(3,1000);
julia> box = Box([1,1,1], 0.05);
julia> cl = CellList(coordinates, box);
julia> gr = ReverseDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
3×1000 Matrix{Float64}:
-0.0875518 0.0 0.0 -0.0848635 0.0 0.0 0.0 … 0.0 0.0 0.00620111 -0.0335964 0.0 0.0234671
-0.0442765 0.0 0.0 -0.0575914 0.0 0.0 0.0 0.0 0.0 -0.014573 0.0649681 0.0 0.0607623
0.000192838 0.0 0.0 -0.0623673 0.0 0.0 0.0 0.0 0.0 0.0316766 0.0451708 0.0 -0.00635728
julia> gr = ForwardDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
3×1000 Matrix{Float64}:
-0.0875518 0.0 0.0 -0.0848635 0.0 0.0 0.0 … 0.0 0.0 0.00620111 -0.0335964 0.0 0.0234671
-0.0442765 0.0 0.0 -0.0575914 0.0 0.0 0.0 0.0 0.0 -0.014573 0.0649681 0.0 0.0607623
0.000192838 0.0 0.0 -0.0623673 0.0 0.0 0.0 0.0 0.0 0.0316766 0.0451708 0.0 -0.00635728
As expected, reverse differentiation is much faster here:
julia> revg(coordinates, box, cl) = ReverseDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
revg (generic function with 1 method)
julia> forg(coordinates, box, cl) = ForwardDiff.gradient( (x) -> sum_sqr(x,box,cl), coordinates)
forg (generic function with 1 method)
julia> @btime revg($coordinates, $box, $cl);
550.999 μs (11188 allocations: 567.33 KiB)
julia> @btime forg($coordinates, $box, $cl);
60.827 ms (73754 allocations: 24.44 MiB) |
Hi, thanks a lot for your answer! I've been able to replicate your example for some similar data. But sometimes I get ...
main at julia (unknown line)
__libc_start_main at /lib64/libc.so.6 (unknown line)
unknown function (ip: 0x401098)
Allocations: 282841011 (Pool: 282818657; Big: 22354); GC: 103
Segmentation fault (core dumped)
My particular application is a histogram (or two-point function) that I would like to differentiate through, but I keep getting the segfault. Could it be because of the size of my data (600k coordinates)? Though that would still not explain why it segfaults with 1k coordinates. |
Segfaults are usually related to some corrupted memory access, not to the size of the data. When the data is too big to fit in memory you get a different error. Without further details, I can't speculate on what may be going on there. One thing that may be related is that, when running CellListMap in parallel, there is machinery to avoid concurrency problems among threads, and I don't know if the differentiation routines can handle that properly. One test is to run the calculations without parallelization. I've made a small test here, and the results in that simple example are the same. But in your case you are probably updating a shared array.
julia> function sum_sqr(coordinates, box, cl; parallel=true)
sum_sqr = map_pairwise!(
(_, _, i, j, _, sum_sqr) -> sum_sqr += sum(abs2, @views(coordinates[:,i] - coordinates[:,j])),
zero(eltype(coordinates)), box, cl; parallel=parallel
)
return sum_sqr
end
sum_sqr (generic function with 1 method)
julia> coordinates = rand(3,5000);
julia> box = Box([1,1,1], 0.05);
julia> cl = CellList(coordinates, box);
julia> ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=true), coordinates)
3×5000 Matrix{Float64}:
-0.0344465 0.0 -0.00210207 -0.0551787 -0.0152113 0.0472379 … 0.0680135 -0.0428575 0.0 0.0715126 -2.16583
0.0527802 0.0 -0.0280015 -0.0101965 -0.0427886 -0.0660361 0.10064 0.00204247 0.0 -0.0447712 0.0216365
0.0416164 0.0 0.103765 -0.0231282 -0.0654189 -0.00256454 -0.0578568 -3.82477 0.0 0.10979 0.0366878
julia> ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=false), coordinates)
3×5000 Matrix{Float64}:
-0.0344465 0.0 -0.00210207 -0.0551787 -0.0152113 0.0472379 … 0.0680135 -0.0428575 0.0 0.0715126 -2.16583
0.0527802 0.0 -0.0280015 -0.0101965 -0.0427886 -0.0660361 0.10064 0.00204247 0.0 -0.0447712 0.0216365
0.0416164 0.0 0.103765 -0.0231282 -0.0654189 -0.00256454 -0.0578568 -3.82477 0.0 0.10979 0.0366878
julia> ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=false), coordinates) ≈
ReverseDiff.gradient(x -> sum_sqr(x, box, cl; parallel=true), coordinates)
true |
In fact, in a simple histogram-like function, ReverseDiff fails, even serially:
julia> function hist(coordinates, box, cl; parallel=true)
h = map_pairwise!(
(_, _, i, j, d2, h) -> begin
if sqrt(d2) < box.cutoff / 2
h[1] += sum(abs2, @views(coordinates[:,i] - coordinates[:,j]))
else
h[2] += sum(abs2, @views(coordinates[:,i] - coordinates[:,j]))
end
return h
end,
zeros(eltype(coordinates), 2), box, cl; parallel=parallel
)
return h
end
hist (generic function with 1 method)
julia> hist(coordinates, box, cl)
2-element Vector{Float64}:
406.8648388535888
3223.091284798098
julia> ReverseDiff.gradient(x -> hist(x, box, cl; parallel=false), coordinates)
ERROR: DimensionMismatch: new dimensions (2, 10000) must be consistent with array size 10000
If you change that to compute the histogram by passing the [...]
Maybe one alternative is to compute each bin of the histogram independently, as that would provide scalar returns to the function. I'm not a specialist in autodiff, so I can't be more precise about what to suggest there. |
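To illustrate the "one scalar per bin" idea, here is a minimal sketch in plain Julia. The helper `bin_value` and the `edges` vector are hypothetical names, not part of CellListMap, and a naive O(N²) double loop stands in for `map_pairwise!` only to keep the example self-contained (no periodic boundaries are considered).

```julia
# Hypothetical helper (not part of CellListMap): accumulates, as a plain scalar,
# the squared distances of all pairs whose distance falls in the bin [lo, hi).
# A naive double loop replaces map_pairwise! to keep the sketch self-contained.
function bin_value(coordinates, lo, hi)
    s = zero(eltype(coordinates))
    n = size(coordinates, 2)
    for i in 1:n-1, j in i+1:n
        d2 = sum(abs2, @views(coordinates[:, i] .- coordinates[:, j]))
        if lo <= sqrt(d2) < hi
            s += d2
        end
    end
    return s
end

coordinates = rand(3, 100)
edges = [0.0, 0.025, 0.05]   # assumed bin edges, chosen only for illustration

# One scalar-valued call per bin, instead of one call returning a vector.
h = [bin_value(coordinates, edges[k], edges[k+1]) for k in 1:length(edges)-1]
```

Since each call returns a scalar, each bin could in principle be handed to a gradient routine separately; whether that is practical for many bins has not been tested here.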
I see; however, it seems you get a different kind of error than I do (you get [...]). |
Can you share some of your code, so that we can at least localize the issue? |
Sure! Here are my "core" functions. I believe the issue can be reproduced with randomly distributed particles in a box, since my dataset is a bit large (though the "before" dataset I shared in a past issue may work too):
bin_edges = 10 .^range(-2, stop = log10(50), length=11)
positions = 2000. .* rand(3, 600000)
box_size = [2e3 for _ = 1:3]
box = Box(box_size, 5.)
cl = CellList(positions, box)
function coordinate_separation(a, b, box_size)
delta = abs(a - b)
return (delta > 0.5*box_size ? delta - box_size : delta)*sign(a-b)
end
function diff_build_histogram!(i, j ,hist, coordinates, bin_edges, box_size)
d2 = sum(abs2, coordinate_separation.(view(coordinates, : , i), view(coordinates, :, j), box_size))
ibin = searchsortedlast(bin_edges, sqrt(d2))
if (ibin > 0) && ibin <= length(bin_edges)
hist[ibin] += 1
end #if
return hist
end
function loss(positions, box, cl, bin_edges, box_size)
hist = zeros(Int,size(bin_edges,1)-1);
println("Counting pairs...")
# Run calculation
map_pairwise!(
(_, _, i, j, _, hist) -> diff_build_histogram!(i, j, hist, positions, bin_edges, box_size),
hist, box, cl; show_progress = true
)
println("Done")
N = size(positions,2)
hist = hist / (N * (N - 1))
norm = @. (4/3) * π * (bin_edges[2:end]^3 -bin_edges[1:end-1]^3) / (box_size[1] * box_size[2] * box_size[3])
hist ./= norm
mean(abs.(hist - xi_ref)) # I think for testing purposes xi_ref can be 0.
end #func
ReverseDiff.gradient((x) -> loss(x, box, cl, bin_edges, box_size), positions) |
Sorry, actually in my last test it seems to work (I did deactivate parallelization). All gradients seem to be 0, but that may be because the histogram is just not differentiable. |
I would try to compute a single count (of one bin) in a regular scalar variable to see how that works. Then, if that works, maybe it is possible to create the histogram with an immutable structure (an SVector, for example). |
Just to add: if I compute a single bin of the histogram, the differentiation apparently works, but returns all zeros, as you observed. I'm not sure if this is correct:
julia> using CellListMap, LinearAlgebra
julia> function hist(coordinates, box, cl; parallel=true)
h = map_pairwise!(
(_, _, i, j, _, h) -> begin
d = norm(@views(coordinates[:,i] - coordinates[:,j]))
if d < box.cutoff / 2
h += 1
#else
# h[2] += 1
end
return h
end,
0, box, cl; parallel=parallel
)
return h
end
hist (generic function with 1 method)
julia> coordinates = rand(3,1000);
julia> box = Box([1,1,1], 0.05);
julia> cl = CellList(coordinates, box);
julia> hist(coordinates, box, cl)
24
julia> all(==(0), ReverseDiff.gradient(x -> hist(x, box, cl; parallel=false), coordinates))
true
julia> all(==(0), ForwardDiff.gradient(x -> hist(x, box, cl; parallel=false), coordinates))
true
(example fixed, @dforero0896) |
Indeed it seems to work. The issue of the zeros is just that this "exact" way of histogramming is not differentiable. An approximate histogram could be built with something like:
function diff_build_histogram!(i, j ,hist, coordinates, bin_widths, box_size, bin_centers)
d2 = sum(abs2, coordinate_separation.(view(coordinates, : , i), view(coordinates, :, j), box_size))
hist .+= exp.(-((sqrt(d2) .- bin_centers) ./ bin_widths).^2) # exp must be broadcast over the bins
return hist
end
So it is clear how the end product depends on the coordinates. |
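As a small check of the difference between the two ways of binning, the sketch below compares a hard bin indicator with the Gaussian weight for a single pair distance. The names `hard_bin`, `soft_bin` and `fd` are made up for this example, and a central finite difference is used instead of an AD package only so that it runs with base Julia.

```julia
# Hard indicator: 1 if the distance d falls in the bin [lo, hi), else 0.
hard_bin(d, lo, hi) = lo <= d < hi ? 1.0 : 0.0

# Gaussian-smoothed weight, same functional form as the kernel above.
soft_bin(d, center, width) = exp(-((d - center) / width)^2)

# Central finite difference (illustrative stand-in for automatic differentiation).
fd(f, d; h=1e-6) = (f(d + h) - f(d - h)) / (2h)

d = 0.8                                         # a pair distance inside the bin [0.5, 1.5)
dhard = fd(x -> hard_bin(x, 0.5, 1.5), d)       # 0.0: a tiny move does not change the count
dsoft = fd(x -> soft_bin(x, 1.0, 0.5), d)       # nonzero: the weight varies smoothly with d
```

The hard count only changes when a pair crosses a bin edge, so its derivative vanishes everywhere else, while the smoothed weight gives a usable signal for every pair within a few bin widths of the center.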
Yes, cool, I was thinking about that problem. Exactly, the histogram has a zero gradient because no infinitesimal move of the particles will cause a pair to change from one bin to another. Not only is the derivative discontinuous, but it is mostly zero. The problem of obtaining a differentiable distribution is indeed interesting. Thanks for posting. I will update the docs with some examples that came out of this discussion, and will close the issue when I do that. Thank you very much for the feedback. It will be useful for others to know how to apply ReverseDiff here. |
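For reference, the argument can be written out as follows. The notation is introduced here only for illustration: $d_{ij} = \lVert \mathbf{x}_i - \mathbf{x}_j \rVert$ is the pair distance (ignoring periodic images), $e_k$ are the bin edges, and $c_k$, $w_k$ correspond to `bin_centers` and `bin_widths` in the smoothed sketch above.

```latex
% Exact histogram: a sum of indicator functions, flat almost everywhere.
h_k = \sum_{i<j} \mathbf{1}\!\left[\, e_k \le d_{ij} < e_{k+1} \right],
\qquad
\frac{\partial h_k}{\partial \mathbf{x}_i} = 0 \quad \text{almost everywhere}

% Gaussian-smoothed histogram: every pair contributes a smooth weight.
\tilde{h}_k = \sum_{i<j} \exp\!\left[ -\left( \frac{d_{ij} - c_k}{w_k} \right)^{2} \right]

% Its gradient with respect to the position of particle i (chain rule through d_{ij}).
\frac{\partial \tilde{h}_k}{\partial \mathbf{x}_i}
  = -\sum_{j \ne i} \frac{2\,(d_{ij} - c_k)}{w_k^{2}}
    \exp\!\left[ -\left( \frac{d_{ij} - c_k}{w_k} \right)^{2} \right]
    \frac{\mathbf{x}_i - \mathbf{x}_j}{d_{ij}}
```

The last expression is nonzero for any pair whose distance is within a few $w_k$ of $c_k$, which is what makes the smoothed histogram usable inside a loss function.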
Glad my question was helpful. There are some other packages that have implemented differentiable histogramming in other ways. May be useful for someone looking into this too. |
Hi, I wanted to try getting some gradients from a function involving map_pairwise, as I saw in the docs that automatic differentiation was available. The issue is that my input consists of hundreds of thousands of variables (a 3x~1e5 Matrix) and my output is a loss score, so ForwardDiff is quite inefficient. I tried just replacing it with ReverseDiff, but I got an error, and Zygote failed as well.
My first question is whether it is even possible to use reverse-mode differentiation with CellListMap. If so, is it possible to add some examples to the docs on how to do so? The type-conversion trick used for ForwardDiff does not work.
Thanks in advance for your help.