-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 7th No.41】NO.41 为 Paddle 代码转换工具新增 API 转换规则(第 8 组) #495
Changes from all commits
16ca6c7
91de87b
e035d2d
29b37e1
492a2e9
8063b46
702f3b3
97acd4b
cc89b26
05cf0be
a972b1c
f4aa1d1
f2e7ddb
329b4b2
870d76d
411edfd
54a14ef
9d1787d
92da59f
271be5d
60b6724
3077803
c2e0a3c
6a60ce0
6db5c8e
a743e69
082eca9
6212cd9
890793e
e486474
0c741fe
eca644c
1ebbedc
d60ebc6
4af9f36
3e2e385
ecb389d
cc0062e
55139a7
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -751,6 +751,127 @@ def get_paddle_nodes(self, args, kwargs): | |
for i in range(1, len(new_args)): | ||
code = "{}({}, {})".format(self.get_paddle_api(), code, new_args[i]) | ||
return ast.parse(code).body | ||
|
||
|
||
class StudentTMatcher(BaseMatcher): | ||
def generate_aux_code(self): | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
import paddle | ||
class StudentT_Aux_Class: | ||
def __init__(self, df, loc, scale): | ||
self.df = paddle.to_tensor(df) | ||
self.loc = paddle.to_tensor(loc) | ||
self.scale = paddle.to_tensor(scale) | ||
self.sT = paddle.distribution.StudentT(self.df, self.loc, self.scale) | ||
def sample(self): | ||
return paddle.reshape(self.sT.sample(), self.df.shape) | ||
""" | ||
) | ||
|
||
return API_TEMPLATE | ||
def generate_code(self, kwargs): | ||
self.write_aux_code() | ||
if "validate_args" in kwargs: | ||
del kwargs["validate_args"] | ||
if "loc" not in kwargs: | ||
kwargs["loc"] = 0.1 | ||
if "scale" not in kwargs: | ||
kwargs["scale"] = 1.0 | ||
kwargs = self.kwargs_to_str(kwargs) | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
paddle_aux.StudentT_Aux_Class({}) | ||
""" | ||
) | ||
code = API_TEMPLATE.format(kwargs) | ||
return code | ||
|
||
|
||
class TransformsPositiveDefiniteTransformMatcher(BaseMatcher): | ||
def generate_aux_code(self): | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
import paddle | ||
class TransformsPositiveDefiniteTransform: | ||
def __call__(self, x): | ||
x = x.tril(-1) + x.diagonal(axis1=-2, axis2=-1).exp().diag_embed() | ||
return x @ x.T | ||
|
||
def inv(self, y): | ||
y = paddle.linalg.cholesky(y) | ||
return y.tril(-1) + y.diagonal(axis1=-2, axis2=-1).log().diag_embed() | ||
""" | ||
) | ||
|
||
return API_TEMPLATE | ||
def generate_code(self, kwargs): | ||
self.write_aux_code() | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
paddle_aux.TransformsPositiveDefiniteTransform() | ||
""" | ||
) | ||
return API_TEMPLATE | ||
|
||
|
||
class LKJCholeskyMatcher(BaseMatcher): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里是什么差异,需要重写Matcher。我也没有在映射文档中看到任何描述这两个API的差异,不能直接对上用GenericMatcher吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 有,输出的维度不一样,需要补上一个 |
||
def generate_aux_code(self): | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
import paddle | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 你用func包class,包这两层的意义呢,直接定义一个class不更简洁吗 |
||
class LKJCholesky_Aux_Class: | ||
def __init__(self, dim, concentration, sample_method='onion'): | ||
self.lkj = paddle.distribution.LKJCholesky(dim, concentration, sample_method) | ||
def sample(self): | ||
return paddle.unsqueeze(self.lkj.sample(), axis=0) | ||
""" | ||
) | ||
|
||
return API_TEMPLATE | ||
def generate_code(self, kwargs): | ||
self.write_aux_code() | ||
if "validate_args" in kwargs: | ||
del kwargs["validate_args"] | ||
kwargs = self.kwargs_to_str(kwargs) | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
paddle_aux.LKJCholesky_Aux_Class({}) | ||
""" | ||
) | ||
code = API_TEMPLATE.format(kwargs) | ||
return code | ||
|
||
|
||
|
||
class Is_InferenceMatcher(BaseMatcher): | ||
def generate_code(self, kwargs): | ||
if "input" not in kwargs: | ||
kwargs["input"] = self.paddleClass | ||
code = "{}.stop_gradient".format(kwargs["input"]) | ||
return code | ||
|
||
|
||
class DistributionsConstrainMatcher(BaseMatcher): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这里是什么差异,需要重写Matcher。我也没有在映射文档中看到任何描述这两个API的差异,不能直接对上用GenericMatcher吗? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 差异是一个是__call__,一个用的是call,是为了封装 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
你在文档里需要写清楚差异 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 我改改 There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 不包装一下,直接调用check会报错 |
||
def generate_aux_code(self): | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
import paddle | ||
class DistributionsConstrain: | ||
def check(self, value): | ||
return paddle.distribution.constraint.Constraint()(value) | ||
""" | ||
) | ||
|
||
return API_TEMPLATE | ||
def generate_code(self, kwargs): | ||
self.write_aux_code() | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
paddle_aux.DistributionsConstrain() | ||
""" | ||
) | ||
return API_TEMPLATE | ||
|
||
|
||
class IInfoMatcher(BaseMatcher): | ||
|
@@ -5067,6 +5188,35 @@ def generate_code(self, kwargs): | |
return code | ||
|
||
|
||
class RpcRemoteMatcher(BaseMatcher): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 因为rpc_async使用to_wait获取值,这个我改改文档 |
||
def generate_aux_code(self): | ||
CODE_TEMPLATE = textwrap.dedent( | ||
""" | ||
class rpc_remote: | ||
def __init__(self, remote_obj): | ||
self.remote = remote_obj | ||
|
||
def to_here(self): | ||
return self.remote.wait() | ||
""" | ||
) | ||
return CODE_TEMPLATE | ||
|
||
def generate_code(self, kwargs): | ||
self.write_aux_code() | ||
kwargs['fn'] = kwargs.pop('func') | ||
kwargs = self.kwargs_to_str(kwargs) | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
paddle_aux.rpc_remote(paddle.distributed.rpc.rpc_async({})) | ||
""" | ||
) | ||
code = API_TEMPLATE.format( | ||
kwargs | ||
) | ||
return code | ||
|
||
|
||
class GetNumThreadsMatcher(BaseMatcher): | ||
def generate_code(self, kwargs): | ||
API_TEMPLATE = textwrap.dedent( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
这里是什么差异,需要重写Matcher。我也没有在映射文档中看到任何描述这两个API的差异,不能直接对上用GenericMatcher吗?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
就是paddle输入的参数需要都是tensor或者float
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
paddle是什么样的不重要,你要先看torch是什么样,再看paddle的处理方案。
这里我还是看不出两者的差异,你的映射文档写得不清楚,用户一眼看不出应如何转写