-
Notifications
You must be signed in to change notification settings - Fork 5.6k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
add new api paddle.Tensor.fill_diagonal_ #34460
add new api paddle.Tensor.fill_diagonal_ #34460
Conversation
Thanks for your contribution! |
"offset of diagonal, zero means no offset, positive means " | ||
"offset to up-right corner; negtive means offset to " | ||
"bottom-left corner") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
说明改一下?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
)DOC"); | ||
AddInput("X", "(Tensor) The input tensor."); | ||
AddOutput("Out", | ||
"Tensor, the clipped tensor, with the same shape and data type " |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
“the clipped tensor”的说明在这里有什么特殊意义不,是否去掉?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done,已修改
framework::TensorCopy(*xin, ctx.GetPlace(), out); | ||
|
||
T *out_data = out->mutable_data<T>(ctx.GetPlace()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
常规做法是先 out->mutable_data 分配内存后再使用 TensorCopy,这里反过来是否有特殊考虑
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, TensorCopy中也有out->mutable_data,所以其实都可以。不过也已经修改。
framework::TensorCopy(*xin, ctx.GetPlace(), out); | ||
|
||
T* out_data = out->mutable_data<T>(ctx.GetPlace()); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
常规做法是先 out->mutable_data 分配内存后再使用 TensorCopy,这里反过来是否有特殊考虑
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done,同上已修改。
if ((i - offset) % strides == 0 && i < wrapsize) { | ||
data[i] = T(0); | ||
} else { | ||
data[i] = static_cast<T>(doutdata[i]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
反向的逻辑中是否也可以使用TensorCopy 批量实现数据cp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done,已修改。
if ((idx - offset) % strides == 0 && idx < wrapsize) { | ||
dst_data[idx] = T(0); | ||
} else { | ||
dst_data[idx] = const_cast<T&>(src_data[idx]); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
反向的逻辑中是否也可以使用TensorCopy 批量实现数据cp
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done,已修改。
AddOutput("Out", | ||
"Tensor, the output tensor, with the same shape and data type " | ||
"as input(x)"); | ||
AddAttr<float>( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
value只能是float吗
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
以float类型读取python数值,然后填入时转化为实际类型填入。
protected: | ||
void Apply(GradOpPtr<T> retv) const override { | ||
retv->SetType("fill_diagonal_grad"); | ||
// retv->SetInput("Out", this->Output("Out")); |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
注释的代码如无需要就直接删除吧
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
expected_grad = np.array( | ||
[[0, 1, 1], [1, 0, 1], [1, 1, 0]]).astype('float32') | ||
|
||
typelist = ['float32', 'float64', 'int32', 'int64'] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Kernel注册了Bool类型,也加上bool类型的测试
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG API
x(Tensor): ``x`` is the original Tensor | ||
value(Scale): ``value`` is the value to filled in x | ||
offset(int,optional): the offset to the main diagonal. Default: 0 (main diagonal). | ||
name(str,optional): Name for the operation (optional, default is None) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
wrap参数缺少说明
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
done, 已添加
a7376d0
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LG API
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
PR types
Others
PR changes
APIs
Describe
补充缺失API,新增 api paddle.Tensor.fill_diagonal_
功能:实现对Tensor对角线向量值的修改。
使用示例:
x = paddle.ones((3,3))
x.fill_diagonal_(5)
tensor([[5., 1., 1.],
[1., 5., 1.],
[1., 1., 5.]])