Empty second-order derivative (= Hessian) for the segment_* reductions #299
As far as I can tell, the problem comes from the fact that the backward passes of the segment_* operators are implemented on the C++ side and do not seem to be differentiable themselves. In any case, here is a quick workaround:

import torch
import torch_scatter

# For the CSR operator:
class SumCSR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, groups):
        ctx.save_for_backward(groups)
        return torch_scatter.segment_csr(values, groups, reduce="sum")

    @staticmethod
    def backward(ctx, grad_output):
        (groups,) = ctx.saved_tensors
        return GatherCSR.apply(grad_output, groups), None


class GatherCSR(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, groups):
        ctx.save_for_backward(groups)
        return torch_scatter.gather_csr(values, groups)

    @staticmethod
    def backward(ctx, grad_output):
        (groups,) = ctx.saved_tensors
        return SumCSR.apply(grad_output, groups), None


# For the COO operator:
class SumCOO(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, groups, dim_size):
        ctx.save_for_backward(groups)
        ctx.dim_size = dim_size
        return torch_scatter.segment_coo(
            values, groups, dim_size=dim_size, reduce="sum"
        )

    @staticmethod
    def backward(ctx, grad_output):
        (groups,) = ctx.saved_tensors
        return GatherCOO.apply(grad_output, groups, ctx.dim_size), None, None


class GatherCOO(torch.autograd.Function):
    @staticmethod
    def forward(ctx, values, groups, dim_size):
        ctx.save_for_backward(groups)
        ctx.dim_size = dim_size
        return torch_scatter.gather_coo(values, groups)

    @staticmethod
    def backward(ctx, grad_output):
        (groups,) = ctx.saved_tensors
        return SumCOO.apply(grad_output, groups, ctx.dim_size), None, None
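As an extra sanity check (a minimal sketch, assuming double-precision inputs so that the finite-difference comparison is reliable; the tensor names val64 and indptr are illustrative), torch.autograd.gradgradcheck can be used to confirm that the SumCSR wrapper above is twice differentiable:

# Illustrative check of the SumCSR wrapper defined above:
val64 = torch.tensor([[0.0, 1.0, 2.0]], dtype=torch.double, requires_grad=True)
indptr = torch.tensor([[0, 2, 3]])

# gradgradcheck compares the analytical double backward against finite
# differences and returns True (or raises an error) if they match:
assert torch.autograd.gradgradcheck(lambda v: SumCSR.apply(v, indptr), (val64,))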
Then, the code below runs just fine:

# Values:
val = torch.FloatTensor([[0, 1, 2]])
# Groups:
gr_coo = torch.LongTensor([[0, 0, 1]])
gr_csr = torch.LongTensor([[0, 2, 3]])
val.requires_grad = True
B, D = val.shape
def group_reduce(*, values, groups, reduction, output_size, backend):
    if backend == "torch":
        # Compatibility switch for PyTorch.scatter_reduce:
        if reduction == "max":
            reduction = "amax"
        return torch.scatter_reduce(
            values, 1, groups, reduction, output_size=output_size
        )
    elif backend == "pyg":
        return torch_scatter.scatter(
            values, groups, dim=1, dim_size=output_size, reduce=reduction
        )
    elif backend == "coo":
        return torch_scatter.segment_coo(
            values, groups, dim_size=output_size, reduce=reduction
        )
    elif backend == "my_coo":
        if reduction == "sum":
            return SumCOO.apply(values, groups, output_size)
        else:
            return torch_scatter.segment_coo(
                values, groups, dim_size=output_size, reduce=reduction
            )
    elif backend == "csr":
        return torch_scatter.segment_csr(values, groups, reduce=reduction)
    elif backend == "my_csr":
        if reduction == "sum":
            return SumCSR.apply(values, groups)
        else:
            return torch_scatter.segment_csr(values, groups, reduce=reduction)
    else:
        raise ValueError(
            f"Invalid value for the scatter backend ({backend}), "
            "should be one of 'torch', 'pyg', 'coo' or 'csr'."
        )


for backend in ["torch", "pyg", "coo", "my_coo", "csr", "my_csr"]:
    red = group_reduce(
        values=val,
        groups=gr_csr if "csr" in backend else gr_coo,
        reduction="sum",
        output_size=2,
        backend=backend,
    )
    # Compute an arbitrary scalar value out of our reduction:
    v = (red ** 2).sum(-1) + 0.0 * (val ** 2).sum(-1)
    # Gradient:
    g = torch.autograd.grad(v.sum(), [val], create_graph=True)[0]
    # Hessian:
    h = torch.zeros(B, D, D).type_as(val)
    for d in range(D):
        h[:, d, :] = torch.autograd.grad(g[:, d].sum(), [val], retain_graph=True)[0]
    print(backend, ":")
    print("Value:", v.detach().numpy())
    print("Grad :", g.detach().numpy())
    print("Hessian:")
    print(h.detach().numpy())
print("--------------") With the following output (notice that torch :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
[2. 2. 0.]
[0. 0. 2.]]]
--------------
pyg :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
[2. 2. 0.]
[0. 0. 2.]]]
--------------
coo :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]]
--------------
my_coo :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
[2. 2. 0.]
[0. 0. 2.]]]
--------------
csr :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[0. 0. 0.]
[0. 0. 0.]
[0. 0. 0.]]]
--------------
my_csr :
Value: [5.]
Grad : [[2. 2. 4.]]
Hessian:
[[[2. 2. 0.]
[2. 2. 0.]
[0. 0. 2.]]]
--------------

Best regards,
Thanks @jeanfeydy for this insightful thread. I will need to take a closer look at this. It looks like the PyTorch C++ backward linkage is indeed different from the Python backward linkage, though I'm not really sure why. Glad that you already found a way to fix this on your end.
This issue had no activity for 6 months. It will be closed in 2 weeks unless there is some new activity. Is this issue already resolved?
Hi @rusty1s,
Thanks again for your great work on this library!
I am currently experimenting with computing second-order derivatives that involve torch_scatter operations, and noticed that the segment_coo and segment_csr operators are not twice differentiable with the sum reduction. To reproduce this behavior, see e.g. the script above. The output shows that torch_scatter.scatter and torch.scatter_reduce coincide on all derivatives, while the two segment_* implementations have a null second-order derivative. Is this expected behavior on your side?
Support for order-two derivatives would be especially useful to perform e.g. Newton optimization.
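For instance, a minimal Newton step can be built from two nested torch.autograd.grad calls; the sketch below is illustrative (the loss f and the flat parameter vector x are placeholders, and the Hessian is assembled row by row, as in the script above):

import torch

def newton_step(f, x):
    # Gradient with create_graph=True, so that it can be differentiated again:
    (g,) = torch.autograd.grad(f(x), [x], create_graph=True)
    # Hessian, assembled row by row from the gradient entries:
    H = torch.stack(
        [torch.autograd.grad(g[i], [x], retain_graph=True)[0] for i in range(len(g))]
    )
    # Newton update: solve H dx = -g and step.
    return x - torch.linalg.solve(H, g)

# Example on a simple quadratic, where a single step lands on the minimum:
x = torch.tensor([1.0, 2.0], requires_grad=True)
x_new = newton_step(lambda x: (x ** 2).sum(), x)  # -> [0., 0.]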
I'm sure that I could hack something with a torch.autograd.Function wrapper for my own use-case, but a proper fix would certainly be useful to other people. Unfortunately, I am not familiar enough with the PyTorch C++ API to fix e.g. segment_csr.cpp myself and write a Pull Request for this :-( What do you think?