Add autograd support for sampled_op (#159)
rusty1s authored Dec 6, 2022
1 parent 08d02e9 commit 8f5b0b9
Showing 3 changed files with 156 additions and 10 deletions.
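
In user code, this change means gradients now flow through `pyg::ops::sampled_op`. A minimal sketch, mirroring the updated test below; the include path is assumed from this diff's file layout and may differ in a real build:

#include <torch/torch.h>
#include "pyg_lib/csrc/ops/sampled.h"  // declares pyg::ops::sampled_op (path assumed)

int main() {
  // out[k] = a[a_index[k]] * b[k], without materializing a.index_select(0, a_index)
  auto a = at::randn({6, 8}).requires_grad_();
  auto b = at::randn({3, 8}).requires_grad_();
  auto a_index = at::tensor({0, 1, 3}, at::kLong);
  auto out = pyg::ops::sampled_op(a, b, a_index, c10::nullopt, "mul");
  out.backward(at::ones_like(out));  // with this commit, a.grad() and b.grad() are populated
  return 0;
}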
2 changes: 1 addition & 1 deletion CHANGELOG.md
@@ -6,7 +6,7 @@ The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/).

## [0.2.0] - 2023-MM-DD
### Added
- Added `sampled_op` implementation ([#156](https://github.com/pyg-team/pyg-lib/pull/156))
- Added `sampled_op` implementation ([#156](https://github.com/pyg-team/pyg-lib/pull/156), [#159](https://github.com/pyg-team/pyg-lib/pull/159))
### Changed
- Improved `[segment|grouped]_matmul` CPU implementation via `at::matmul_out` and MKL BLAS `gemm_batch` ([#146](https://github.com/pyg-team/pyg-lib/pull/146))
### Removed
114 changes: 114 additions & 0 deletions pyg_lib/csrc/ops/autograd/sampled_kernel.cpp
@@ -0,0 +1,114 @@
#include "../sampled.h"

#include <torch/autograd.h>

namespace pyg {
namespace ops {

namespace {

using torch::autograd::Variable;
using torch::autograd::variable_list;

class SampledOp : public torch::autograd::Function<SampledOp> {
public:
static variable_list forward(torch::autograd::AutogradContext* ctx,
const Variable& left,
const Variable& right,
const at::optional<at::Tensor> left_index,
const at::optional<at::Tensor> right_index,
const std::string fn) {
at::AutoDispatchBelowADInplaceOrView g;
Variable out = sampled_op(left, right, left_index, right_index, fn);
ctx->saved_data["has_left_index"] = left_index.has_value();
ctx->saved_data["has_right_index"] = right_index.has_value();
ctx->saved_data["fn"] = fn;
ctx->save_for_backward({
left, right,
left_index.has_value() ? left_index.value() : left, // dummy
right_index.has_value() ? right_index.value() : right, // dummy
});
return {out};
}

static variable_list backward(torch::autograd::AutogradContext* ctx,
variable_list grad_outs) {
auto grad_out = grad_outs[0];
auto saved = ctx->get_saved_variables();

auto left = saved[0];
auto right = saved[1];
at::optional<at::Tensor> left_index = at::nullopt;
if (ctx->saved_data["has_left_index"].toBool()) {
left_index = saved[2];
}
at::optional<at::Tensor> right_index = at::nullopt;
if (ctx->saved_data["has_right_index"].toBool()) {
right_index = saved[3];
}
auto fn = ctx->saved_data["fn"].toStringRef();

auto grad_left = Variable();
if (torch::autograd::any_variable_requires_grad({left})) {
grad_left = grad_out;

if (fn == "mul") {
grad_left =
sampled_op(grad_left, right, at::nullopt, right_index, "mul");
} else if (fn == "div") {
grad_left =
sampled_op(grad_left, right, at::nullopt, right_index, "div");
}

if (left_index.has_value()) {
grad_left = at::index_select_backward(grad_left, left.sizes(), 0,
left_index.value());
}
}

auto grad_right = Variable();
if (torch::autograd::any_variable_requires_grad({right})) {
grad_right = grad_out;

if (fn == "sub" && grad_out.size(0) <= right.size(0)) {
grad_right = -grad_right;
} else if (fn == "mul") {
grad_right =
sampled_op(grad_right, left, at::nullopt, left_index, "mul");
} else if (fn == "div") {
auto tmp = sampled_op(left, right, left_index, right_index, "div");
tmp = sampled_op(tmp, right, at::nullopt, right_index, "div");
grad_right = -grad_right * tmp;
}

if (right_index.has_value()) {
grad_right = at::index_select_backward(grad_right, right.sizes(), 0,
right_index.value());
}

if (fn == "sub" && grad_out.size(0) > right.size(0)) {
grad_right = -grad_right;
}
}

return {grad_left, grad_right, Variable(), Variable(), Variable()};
}
};

at::Tensor sampled_op_autograd(const at::Tensor& left,
const at::Tensor& right,
const at::optional<at::Tensor> left_index,
const at::optional<at::Tensor> right_index,
const std::string fn) {
return SampledOp::apply(left, right, left_index, right_index, fn)[0];
}

} // namespace

TORCH_LIBRARY_IMPL(pyg, Autograd, m) {
m.impl(TORCH_SELECTIVE_NAME("pyg::sampled_op"),
TORCH_FN(sampled_op_autograd));
}

} // namespace ops
} // namespace pyg
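
For reference, the backward above implements the standard chain rule for out = fn(left[i], right[j]), written here with i = left_index, j = right_index, and g = grad_out (a summary of the code, not part of the commit):

\text{add:}\quad \partial L/\partial\,\mathrm{left}[i] = g, \qquad \partial L/\partial\,\mathrm{right}[j] = g
\text{sub:}\quad \partial L/\partial\,\mathrm{left}[i] = g, \qquad \partial L/\partial\,\mathrm{right}[j] = -g
\text{mul:}\quad \partial L/\partial\,\mathrm{left}[i] = g \odot \mathrm{right}[j], \qquad \partial L/\partial\,\mathrm{right}[j] = g \odot \mathrm{left}[i]
\text{div:}\quad \partial L/\partial\,\mathrm{left}[i] = g / \mathrm{right}[j], \qquad \partial L/\partial\,\mathrm{right}[j] = -g \odot \mathrm{left}[i] / \mathrm{right}[j]^{2}

The div case is why the backward composes two sampled_op "div" calls: left[i] / right[j] divided again by right[j] gives left[i] / right[j]^2. The final at::index_select_backward call scatter-adds the per-sample gradients back into tensors of the original, unsampled sizes.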
50 changes: 41 additions & 9 deletions test/csrc/ops/test_sampled.cpp
@@ -10,30 +10,62 @@ TEST_P(MultipleDeviceTest, SampledOpTest) {
const auto param = ::testing::TestWithParam<c10::DeviceType>::GetParam();
auto options = at::TensorOptions().device(param);

at::Tensor a, b, out;
at::Tensor a, b, out, grad_a, grad_b, grad_out;
auto a_index = at::tensor({0, 1, 3}, options.dtype(at::kLong));
auto b_index = at::tensor({3, 4, 5}, options.dtype(at::kLong));

a = at::randn({3, 8}, options);
b = at::randn({3, 8}, options);
a = at::randn({3, 8}, options).requires_grad_();
b = at::randn({3, 8}, options).requires_grad_();
grad_out = at::randn({3, 8}, options);
out = pyg::ops::sampled_op(a, b, c10::nullopt, c10::nullopt, "add");
out.backward(grad_out);
grad_a = a.grad().clone(), grad_b = b.grad().clone();
EXPECT_TRUE(at::allclose(out, a + b));
a.grad().fill_(0);
b.grad().fill_(0);
(a + b).backward(grad_out);
EXPECT_TRUE(at::allclose(grad_a, a.grad()));
EXPECT_TRUE(at::allclose(grad_b, b.grad()));

a = at::randn({6, 8}, options);
b = at::randn({3, 8}, options);
a = at::randn({6, 8}, options).requires_grad_();
b = at::randn({3, 8}, options).requires_grad_();
grad_out = at::randn({3, 8}, options);
out = pyg::ops::sampled_op(a, b, a_index, c10::nullopt, "sub");
out.backward(grad_out);
grad_a = a.grad().clone(), grad_b = b.grad().clone();
EXPECT_TRUE(at::allclose(out, a.index_select(0, a_index) - b));
a.grad().fill_(0);
b.grad().fill_(0);
(a.index_select(0, a_index) - b).backward(grad_out);
EXPECT_TRUE(at::allclose(grad_a, a.grad()));
EXPECT_TRUE(at::allclose(grad_b, b.grad()));

a = at::randn({3, 8}, options);
b = at::randn({6, 8}, options);
a = at::randn({3, 8}, options).requires_grad_();
b = at::randn({6, 8}, options).requires_grad_();
grad_out = at::randn({3, 8}, options);
out = pyg::ops::sampled_op(a, b, c10::nullopt, b_index, "mul");
out.backward(grad_out);
grad_a = a.grad().clone(), grad_b = b.grad().clone();
EXPECT_TRUE(at::allclose(out, a * b.index_select(0, b_index)));
a.grad().fill_(0);
b.grad().fill_(0);
(a * b.index_select(0, b_index)).backward(grad_out);
EXPECT_TRUE(at::allclose(grad_a, a.grad()));
EXPECT_TRUE(at::allclose(grad_b, b.grad()));

a = at::randn({6, 8}, options);
b = at::randn({6, 8}, options);
a = at::randn({6, 8}, options).requires_grad_();
b = at::randn({8, 8}, options).requires_grad_();
grad_out = at::randn({3, 8}, options);
out = pyg::ops::sampled_op(a, b, a_index, b_index, "div");
out.backward(grad_out);
grad_a = a.grad().clone(), grad_b = b.grad().clone();
EXPECT_TRUE(at::allclose(
out, a.index_select(0, a_index) / b.index_select(0, b_index)));
a.grad().fill_(0);
b.grad().fill_(0);
(a.index_select(0, a_index) / b.index_select(0, b_index)).backward(grad_out);
EXPECT_TRUE(at::allclose(grad_a, a.grad()));
EXPECT_TRUE(at::allclose(grad_b, b.grad()));
}
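
These gradient checks ultimately rely on the scatter performed by at::index_select_backward in the new autograd kernel. A minimal sketch of that equivalence; the helper name and shapes here are illustrative only:

#include <torch/torch.h>

// Gradient of x.index_select(0, idx) w.r.t. x is a scatter-add back into x's shape.
void scatter_equivalence_sketch() {
  auto g   = at::randn({3, 8});                 // gradient of the selected rows
  auto idx = at::tensor({0, 1, 3}, at::kLong);  // rows selected from a {6, 8} tensor
  auto via_backward  = at::index_select_backward(g, {6, 8}, /*dim=*/0, idx);
  auto via_index_add = at::zeros({6, 8}).index_add_(0, idx, g);
  TORCH_CHECK(at::allclose(via_backward, via_index_add));
}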

INSTANTIATE_TEST_SUITE_P(OpsTest,
