Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

supplement the function of slice. #34172

Merged
merged 12 commits into from
Aug 4, 2021

Conversation

hbwx24
Copy link
Contributor

@hbwx24 hbwx24 commented Jul 15, 2021

PR types

Function optimization

PR changes

APIs

Describe

  • 补全slice功能:在静态图中当start=None, end=None时,支持step>0。例如 a[::2]
  • strided_slice_op:支持TensorArray类型的输入。

@paddle-bot-old
Copy link

Thanks for your contribution!
Please wait for the result of CI firstly. See Paddle CI Manual for details.

chenwhql
chenwhql previously approved these changes Jul 20, 2021
Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@@ -154,6 +160,13 @@ class StridedSliceOp : public framework::OperatorWithKernel {
protected:
framework::OpKernelType GetExpectedKernelType(
const framework::ExecutionContext &ctx) const override {
auto *in_var = ctx.InputVar("Input");
auto is_in_var_array = in_var->IsType<framework::LoDTensorArray>();
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

结合下面的code,会不会有这种情况,lodtensorarray里面tensor的place是cuda_pinned

Copy link
Contributor Author

@hbwx24 hbwx24 Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx.

TensorCopy(in_tensor, context.GetPlace(), out_tensor);
}

return;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

是不是用else分支管理代码好一些

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx.

// StridedSliceGrad
// cannot be calculated by `framework::GradVarName("Output")`,
// the dim of "Input" is used to calculate the output shape.
// when set it to inplace OP, there may be some problems.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里注释说的可能存在的问题是已解决的还是TODO的

Copy link
Contributor Author

@hbwx24 hbwx24 Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这个注释说的是可能存在的问题:因为这个反向op使用Input(前向op的输入)计算输出shape,所以这个op不能是inplace op。

改成了NOTE(xx):

set_zero(dev_ctx, d_out_tensor, static_cast<T>(0));
}
}
return;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

同上

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx.

@@ -176,6 +177,45 @@ def test_set_value_with_save(self):
output_spec=None)


class TestSliceSupplementCase(unittest.TestCase):
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

命名建议准确一些,带一些功能的特征?后面可能还会追加case吧

Copy link
Contributor Author

@hbwx24 hbwx24 Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

修改了名字:TestSliceSupplementCase -> TestSliceSupplementSpecialCase
添加了一行注释:# unittest for slice index which abs(step)>0. eg: x[::2]


self.create_case(Net(input_size=112, array_size=13))

# TODO(weixin):Currently, the case that the start index is
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这种case现在的报错提示或者说warning是怎样的

Copy link
Contributor Author

@hbwx24 hbwx24 Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

strided_slice_op.h的130行处理index,现在可支持这种情况了,但是用到这个op的其他api例如,varbase.getitem、paddle.strided_slice等op也做了类似的简单处理,这些处理是不冲突的。

Comment on lines +130 to +132
if (ends[axis_index] < 0) {
ends[axis_index] = 0;
}
Copy link
Contributor Author

@hbwx24 hbwx24 Jul 29, 2021

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

[start : end : step]:处理end<-axis_size的情况。例如:len(a)=10, a[:-100:-1]

platform::is_same_place(tensor.place(),
ctx.device_context().GetPlace()),
true, platform::errors::InvalidArgument(
"Place of context is %s. Place of context is %s. They "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有一个place是tensor的?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Done, thx.

Copy link
Contributor

@chenwhql chenwhql left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@hbwx24 hbwx24 merged commit 1f0f5d3 into PaddlePaddle:develop Aug 4, 2021
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants