
Commit 365682b: "update"
rusty1s committed Jul 13, 2022
1 parent 045ecc7

Showing 1 changed file with 37 additions and 59 deletions: pyg_lib/csrc/ops/matmul.cpp
@@ -10,9 +10,10 @@ using torch::autograd::AutogradContext;
using torch::autograd::Variable;
using torch::autograd::variable_list;

-// Performs matrix multiplication across list of elements.
-std::vector<at::Tensor> grouped_matmul(const std::vector<at::Tensor>& input,
-                                       const std::vector<at::Tensor>& other) {
+namespace {
+
+std::vector<at::Tensor> _grouped_matmul(const std::vector<at::Tensor>& input,
+                                        const std::vector<at::Tensor>& other) {
   static auto op = c10::Dispatcher::singleton()
                        .findSchemaOrThrow("pyg::grouped_matmul", "")
                        .typed<decltype(grouped_matmul)>();
@@ -23,84 +24,61 @@ std::vector<at::Tensor> grouped_matmul(const std::vector<at::Tensor>& input,
   return op.call(input, other);
 }

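For context: both ops resolve through the PyTorch dispatcher rather than calling a kernel directly, which lets CPU and CUDA backends register independently against the `pyg::grouped_matmul` schema. A minimal sketch of what such a registration typically looks like; the schema string and the kernel shown here are assumptions for illustration, not pyg-lib's actual registration code:

#include <ATen/ATen.h>
#include <torch/library.h>
#include <vector>

// Toy CPU kernel, for illustration only: one matmul per (input, other) pair.
std::vector<at::Tensor> grouped_matmul_kernel(const std::vector<at::Tensor>& input,
                                              const std::vector<at::Tensor>& other) {
  std::vector<at::Tensor> out;
  out.reserve(input.size());
  for (size_t i = 0; i < input.size(); ++i)
    out.push_back(at::matmul(input[i], other[i]));
  return out;
}

TORCH_LIBRARY_FRAGMENT(pyg, m) {
  // Schema string inferred from the C++ signature above; treat it as an assumption.
  m.def("grouped_matmul(Tensor[] input, Tensor[] other) -> Tensor[]");
}

TORCH_LIBRARY_IMPL(pyg, CPU, m) {
  m.impl("grouped_matmul", grouped_matmul_kernel);
}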
-// static auto group_op = c10::Dispatcher::singleton()
-//                            .findSchemaOrThrow("pyg::grouped_matmul", "")
-//                            .typed<decltype(grouped_matmul)>();
-// class GroupedMatmul : public torch::autograd::Function<GroupedMatmul> {
-//   // TODO (matthias) Add TensorArg definitions.
-//  public:
-//   static std::vector<variable_list> forward(AutogradContext* ctx,
-//                                             std::vector<Variable> input,
-//                                             std::vector<Variable> other) {
-//     auto out = group_op.call(input, other);
-//     ctx->save_for_backward({input, other});
-//     return {out};
-//   }
-//
-//   static std::vector<variable_list> backward(AutogradContext* ctx,
-//                                              variable_list grad_outs) {
-//     auto saved = ctx->get_saved_variables();
-//     variable_list input = saved[0];
-//     variable_list other = saved[1];
-//     for (size_t i = 0; i < input.size(); ++i)
-//       other[i] = other[i].transpose(-2, -1).contiguous();
-//     auto other_grad = group_op.call(grad_outs, other);
-//     if (torch::autograd::any_variable_requires_grad(input)) {
-//       for (size_t i = 0; i < input.size(); ++i)
-//         input[i] = input[i].transpose(-2, -1).contiguous();
-//       auto input_grad = group_op.call(input, grad_outs);
-//       return {input_grad, other_grad};
-//     } else {
-//       return {variable_list(), other_grad};
-//     }
-//   }
-// };
-
-// std::vector<at::Tensor> grouped_matmul(const std::vector<at::Tensor>& input,
-//                                        const std::vector<at::Tensor>& other)
-// {
-//   return GroupedMatmul::apply(input, other)[0];
-// }
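The commented-out block removed above was an unfinished autograd wrapper for the grouped (list) variant. For reference, a grouped matmul out[i] = input[i] @ other[i] differentiates pair-wise: input_grad[i] = grad_out[i] @ other[i]^T and other_grad[i] = input[i]^T @ grad_out[i]. A plain ATen sketch of those reference gradients (illustration only, not pyg-lib API):

#include <ATen/ATen.h>
#include <utility>
#include <vector>

// Reference gradients for out[i] = input[i] @ other[i] (illustration only).
std::pair<std::vector<at::Tensor>, std::vector<at::Tensor>>
grouped_matmul_backward_reference(const std::vector<at::Tensor>& input,
                                  const std::vector<at::Tensor>& other,
                                  const std::vector<at::Tensor>& grad_out) {
  std::vector<at::Tensor> input_grad, other_grad;
  for (size_t i = 0; i < input.size(); ++i) {
    input_grad.push_back(at::matmul(grad_out[i], other[i].transpose(-2, -1)));
    other_grad.push_back(at::matmul(input[i].transpose(-2, -1), grad_out[i]));
  }
  return {input_grad, other_grad};
}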
-at::Tensor segment_op(const at::Tensor& input,
-                      const at::Tensor& ptr,
-                      const at::Tensor& other) {
+at::Tensor _segment_matmul(const at::Tensor& input,
+                           const at::Tensor& ptr,
+                           const at::Tensor& other) {
   // TODO (matthias) Add TensorArg definitions.
-  // TODO (matthias) Add autograd support.
   static auto op = c10::Dispatcher::singleton()
                        .findSchemaOrThrow("pyg::segment_matmul", "")
-                       .typed<decltype(segment_op)>();
+                       .typed<decltype(segment_matmul)>();
   return op.call(input, ptr, other);
 }

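`_segment_matmul` forwards to the dispatched `pyg::segment_matmul` op. Its semantics, as a reference loop: `ptr` holds CSR-style row offsets into `input`, and each row segment is multiplied by its own matrix from `other`. A minimal ATen sketch under assumed shapes input [N, K], ptr [B + 1], other [B, K, M] (illustration, not the optimized kernel):

#include <ATen/ATen.h>

// out[ptr[i]:ptr[i+1]] = input[ptr[i]:ptr[i+1]] @ other[i]  (reference semantics)
at::Tensor segment_matmul_reference(const at::Tensor& input,
                                    const at::Tensor& ptr,
                                    const at::Tensor& other) {
  auto out = at::empty({input.size(0), other.size(-1)}, input.options());
  for (int64_t i = 0; i < other.size(0); ++i) {
    const auto start = ptr[i].item<int64_t>();
    const auto end = ptr[i + 1].item<int64_t>();
    out.slice(/*dim=*/0, start, end)
        .copy_(at::matmul(input.slice(/*dim=*/0, start, end), other[i]));
  }
  return out;
}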
-// Performs matrix multiplication according to segments.
 class SegmentMatmul : public torch::autograd::Function<SegmentMatmul> {
   // TODO (matthias) Add TensorArg definitions.
  public:
   static variable_list forward(AutogradContext* ctx,
                                Variable input,
                                const at::Tensor& ptr,
                                Variable other) {
-    Variable out = segment_op(input, ptr, other);
+    Variable out = _segment_matmul(input, ptr, other);
     ctx->save_for_backward({input, ptr, other});
     return {out};
   }

   static variable_list backward(AutogradContext* ctx, variable_list grad_outs) {
-    variable_list saved = ctx->get_saved_variables();
-    Variable input = saved[0];
-    Variable ptr = saved[1];
-    Variable other = saved[2].transpose(-2, -1).contiguous();
-    Variable grad_out = grad_outs[0];
-    Variable other_grad = segment_op(grad_out, ptr, other);
+    auto grad_out = grad_outs[0];
+    auto saved = ctx->get_saved_variables();
+    auto input = saved[0], ptr = saved[1], other = saved[2];
+
+    auto input_grad = Variable(), other_grad = Variable();
+
     if (torch::autograd::any_variable_requires_grad({input})) {
-      input = input.transpose(-2, -1).contiguous();
-      Variable input_grad = segment_op(input, ptr, grad_out);
-      return {input_grad, other_grad};
-    } else {
-      return {Variable(), other_grad};
+      // TODO (matthias) get rid of unnecessary `contiguous` here.
+      auto input_t = input.transpose(-2, -1).contiguous();
+      input_grad = _segment_matmul(input_t, ptr, grad_out);
     }
+
+    if (torch::autograd::any_variable_requires_grad({other})) {
+      // TODO (matthias) get rid of unnecessary `contiguous` here.
+      auto other_t = other.transpose(-2, -1).contiguous();
+      other_grad = _segment_matmul(grad_out, ptr, other_t);
+    }
+
+    return {input_grad, other_grad};
   }
};
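For reference, with out_s = input_s @ other_s per segment, the chain rule gives input_grad_s = grad_out_s @ other_s^T and other_grad_s = input_s^T @ grad_out_s; the backward above routes both products through the same dispatched kernel by transposing one operand. A hedged end-to-end sketch of the autograd path, assuming the public segment_matmul wrapper defined below (shapes chosen arbitrarily):

#include <torch/torch.h>

// Illustrative autograd smoke test; assumes `segment_matmul` from below is linked in.
void segment_matmul_autograd_example() {
  auto input = torch::randn({6, 8}, torch::requires_grad());     // rows 0..5, feature dim 8
  auto ptr = torch::tensor({0, 3, 6}, torch::kLong);             // segments [0,3) and [3,6)
  auto other = torch::randn({2, 8, 16}, torch::requires_grad()); // one 8x16 matrix per segment
  auto out = segment_matmul(input, ptr, other);                  // shape [6, 16]
  out.sum().backward();  // fills input.grad() / other.grad() via SegmentMatmul::backward
}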

+} // namespace

+// Performs matrix multiplication across list of elements.
+std::vector<at::Tensor> grouped_matmul(const std::vector<at::Tensor>& input,
+                                       const std::vector<at::Tensor>& other) {
+  return _grouped_matmul(input, other);
+}
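The public grouped_matmul is now a thin wrapper: keeping _grouped_matmul and the autograd machinery in the anonymous namespace gives them internal linkage, so only the two public entry points are exported. A hedged usage sketch (shapes arbitrary; one output per input/other pair):

#include <torch/torch.h>
#include <vector>

void grouped_matmul_example() {
  // Two independent matmuls of different sizes, computed in one grouped call.
  std::vector<at::Tensor> input = {torch::randn({4, 8}), torch::randn({5, 16})};
  std::vector<at::Tensor> other = {torch::randn({8, 32}), torch::randn({16, 32})};
  auto out = grouped_matmul(input, other);  // out[0]: [4, 32], out[1]: [5, 32]
}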

+// Performs matrix multiplication according to segments.
 at::Tensor segment_matmul(const at::Tensor& input,
                           const at::Tensor& ptr,
                           const at::Tensor& other) {
(remaining unchanged lines not shown)
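In practice, callers need the `ptr` offsets; assuming the CSR-style layout above, they can be built from per-segment row counts with a cumulative sum. A small sketch (the helper name is hypothetical):

#include <torch/torch.h>

// Builds a [B + 1] offset vector from B per-segment row counts (hypothetical helper).
at::Tensor make_ptr(const at::Tensor& counts) {
  auto ptr = torch::zeros({counts.size(0) + 1}, counts.options());
  ptr.slice(/*dim=*/0, 1, counts.size(0) + 1).copy_(torch::cumsum(counts, /*dim=*/0));
  return ptr;
}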
