Skip to content
This repository has been archived by the owner on Nov 17, 2023. It is now read-only.

Commit

Permalink
Bugfix: kAddTo support for swap_axes.
Browse files Browse the repository at this point in the history
  • Loading branch information
Tobias Domhan committed Jan 24, 2018
1 parent 20253d5 commit 758e4f8
Showing 1 changed file with 11 additions and 5 deletions.
16 changes: 11 additions & 5 deletions src/operator/swapaxis-inl.h
Original file line number Diff line number Diff line change
Expand Up @@ -101,15 +101,17 @@ class SwapAxisOp : public Operator {
}

void SwapAxis(mshadow::Stream<xpu> *s,
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data) {
const std::vector<TBlob> &in_data,
const std::vector<TBlob> &out_data,
const std::vector<OpReqType> &req) {
using namespace mshadow;
using namespace mshadow::expr;
uint32_t dim1 = param_.dim1;
uint32_t dim2 = param_.dim2;

TBlob data_in = in_data[swapaxisenum::kData];
TBlob data_out = out_data[swapaxisenum::kData];
OpReqType out_req = req[swapaxisenum::kData];

TShape shape_in = data_in.shape_;
TShape shape_out = data_out.shape_;
Expand All @@ -125,7 +127,11 @@ class SwapAxisOp : public Operator {

Tensor<xpu, 5, DType> inter_data_out = data_out.get_with_shape<xpu, 5, DType>(inter_shape2, s);

inter_data_out = swapaxis<3, 1>(inter_data_in);
if (out_req == kAddTo) {
inter_data_out += swapaxis<3, 1>(inter_data_in);
} else {
inter_data_out = swapaxis<3, 1>(inter_data_in);
}
}

virtual void Forward(const OpContext &ctx,
Expand All @@ -136,7 +142,7 @@ class SwapAxisOp : public Operator {
using namespace mshadow;
Stream<xpu> *s = ctx.get_stream<xpu>();

SwapAxis(s, in_data, out_data);
SwapAxis(s, in_data, out_data, req);
}

virtual void Backward(const OpContext &ctx,
Expand All @@ -149,7 +155,7 @@ class SwapAxisOp : public Operator {
using namespace mshadow;
Stream<xpu> *s = ctx.get_stream<xpu>();

SwapAxis(s, out_grad, in_grad);
SwapAxis(s, out_grad, in_grad, req);
}

SwapAxisParam param_;
Expand Down

0 comments on commit 758e4f8

Please sign in to comment.