Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

【Hackathon 7th No.41】NO.41 为 Paddle 代码转换工具新增 API 转换规则(第 8 组) #495

Merged
merged 39 commits into from
Oct 22, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
39 commits
Select commit Hold shift + click to select a range
16ca6c7
add is_inference
decade-afk Oct 14, 2024
91de87b
update
decade-afk Oct 14, 2024
e035d2d
update
decade-afk Oct 14, 2024
29b37e1
update
decade-afk Oct 15, 2024
492a2e9
add geometric_
decade-afk Oct 15, 2024
8063b46
add cauchy_
decade-afk Oct 15, 2024
702f3b3
add random_
decade-afk Oct 15, 2024
97acd4b
update
decade-afk Oct 15, 2024
cc89b26
update
decade-afk Oct 15, 2024
05cf0be
add chi2
decade-afk Oct 15, 2024
a972b1c
add Constraint
decade-afk Oct 15, 2024
f4aa1d1
add Gamma
decade-afk Oct 15, 2024
f2e7ddb
update
decade-afk Oct 15, 2024
329b4b2
add Poisson LKJCholesky
decade-afk Oct 15, 2024
870d76d
update
decade-afk Oct 15, 2024
411edfd
add StudentT PositiveDefiniteTransform
decade-afk Oct 15, 2024
54a14ef
update
decade-afk Oct 16, 2024
9d1787d
update
decade-afk Oct 16, 2024
92da59f
add remote
decade-afk Oct 16, 2024
271be5d
add remote
decade-afk Oct 16, 2024
60b6724
add remote
decade-afk Oct 16, 2024
3077803
update
decade-afk Oct 16, 2024
c2e0a3c
update
decade-afk Oct 16, 2024
6a60ce0
update
decade-afk Oct 16, 2024
6db5c8e
update
decade-afk Oct 16, 2024
a743e69
update
decade-afk Oct 16, 2024
082eca9
update
decade-afk Oct 16, 2024
6212cd9
update
decade-afk Oct 16, 2024
890793e
update
decade-afk Oct 16, 2024
e486474
update
decade-afk Oct 17, 2024
0c741fe
update
decade-afk Oct 17, 2024
eca644c
update
decade-afk Oct 17, 2024
1ebbedc
add DistributedOptimizer
decade-afk Oct 17, 2024
d60ebc6
update
decade-afk Oct 18, 2024
4af9f36
update
decade-afk Oct 21, 2024
3e2e385
update
decade-afk Oct 21, 2024
ecb389d
update
decade-afk Oct 21, 2024
cc0062e
upadte api_matcher
decade-afk Oct 22, 2024
55139a7
update
decade-afk Oct 22, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
149 changes: 145 additions & 4 deletions paconvert/api_mapping.json
Original file line number Diff line number Diff line change
Expand Up @@ -844,7 +844,22 @@
"memory_format"
]
},
"torch.Tensor.cauchy_": {},
"torch.Tensor.cauchy_": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.cauchy_",
"min_input_args": 0,
"args_list": [
"median",
"sigma",
"*",
"generator"
],
"kwargs_change": {
"median": "loc",
"sigma":"scale",
"generator":""
}
},
"torch.Tensor.cdouble": {
"Matcher": "TensorCdoubleMatcher",
"paddle_api": "paddle.Tensor.astype",
Expand Down Expand Up @@ -1711,7 +1726,20 @@
"other": "y"
}
},
"torch.Tensor.geometric_": {},
"torch.Tensor.geometric_": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.geometric_",
"min_input_args": 1,
"args_list": [
"p",
"*",
"generator"
],
"kwargs_change": {
"p": "probs",
"generator":""
}
},
"torch.Tensor.geqrf": {},
"torch.Tensor.ger": {
"Matcher": "GenericMatcher",
Expand Down Expand Up @@ -2088,7 +2116,17 @@
"paddle_api": "paddle.Tensor.is_floating_point",
"min_input_args": 0
},
"torch.Tensor.is_inference": {},
"torch.Tensor.is_inference": {
"Matcher": "Is_InferenceMatcher",
"min_input_args": 0
},
"torch.is_inference": {
"Matcher": "Is_InferenceMatcher",
"min_input_args": 1,
"args_list":[
"input"
]
},
"torch.Tensor.is_pinned": {
"Matcher": "Is_PinnedMatcher",
"min_input_args": 0
Expand Down Expand Up @@ -3166,7 +3204,22 @@
"Matcher": "UnchangeMatcher",
"min_input_args": 0
},
"torch.Tensor.random_": {},
"torch.Tensor.random_": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.uniform_",
"min_input_args": 0,
"args_list": [
"from",
"to",
"*",
"generator"
],
"kwargs_change": {
"from": "min",
"to": "max",
"generator": ""
}
},
"torch.Tensor.ravel": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.Tensor.flatten",
Expand Down Expand Up @@ -6298,6 +6351,18 @@
],
"min_input_args": 1
},
"torch.distributed.rpc.remote":{
"Matcher": "RpcRemoteMatcher",
"paddle_api": "paddle.distributed.rpc.rpc_async",
"min_input_args": 2,
"args_list": [
"to",
"func",
"args",
"kwargs",
"timeout"
]
},
"torch.distributed.rpc.shutdown": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distributed.rpc.shutdown",
Expand Down Expand Up @@ -6381,6 +6446,52 @@
},
"min_input_args": 2
},
"torch.distributions.lkj_cholesky.LKJCholesky":{
"Matcher": "LKJCholeskyMatcher",
"paddle_api": "paddle.distribution.LKJCholesky",
"min_input_args": 1,
"args_list": [
"dim",
"concentration",
"validate_args"
]
},
"torch.distributions.studentT.StudentT":{
"Matcher": "StudentTMatcher",
"paddle_api": "paddle.distribution.StudentT",
"min_input_args": 1,
"args_list": [
"df",
"loc",
"scale",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.transforms.PositiveDefiniteTransform":{
"Matcher": "TransformsPositiveDefiniteTransformMatcher",
"min_input_args": 0,
"args_list": [
"cache_size"
],
"kwargs_change": {
"cache_size": ""
}
},
"torch.distributions.poisson.Poisson":{
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Poisson",
"min_input_args": 1,
"args_list": [
"rate",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.Bernoulli": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Bernoulli",
Expand Down Expand Up @@ -6431,6 +6542,18 @@
"total_count": "1"
}
},
"torch.distributions.chi2.Chi2":{
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Chi2",
"min_input_args": 1,
"args_list": [
"df",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.Categorical": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Categorical",
Expand Down Expand Up @@ -6473,6 +6596,24 @@
"cache_size": ""
}
},
"torch.distributions.constraints.Constraint" : {
"Matcher": "DistributionsConstrainMatcher",
"paddle_api": "paddle.distribution.constraint.Constraint",
"abstract": true
},
"torch.distributions.gamma.Gamma":{
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.Gamma",
"min_input_args": 2,
"args_list": [
"concentration",
"rate",
"validate_args"
],
"kwargs_change": {
"validate_args": ""
}
},
"torch.distributions.ContinuousBernoulli": {
"Matcher": "GenericMatcher",
"paddle_api": "paddle.distribution.ContinuousBernoulli",
Expand Down
150 changes: 150 additions & 0 deletions paconvert/api_matcher.py
Original file line number Diff line number Diff line change
Expand Up @@ -751,6 +751,127 @@ def get_paddle_nodes(self, args, kwargs):
for i in range(1, len(new_args)):
code = "{}({}, {})".format(self.get_paddle_api(), code, new_args[i])
return ast.parse(code).body


class StudentTMatcher(BaseMatcher):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是什么差异,需要重写Matcher。我也没有在映射文档中看到任何描述这两个API的差异,不能直接对上用GenericMatcher吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image
就是paddle输入的参数需要都是tensor或者float

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

image 就是paddle输入的参数需要都是tensor或者float

paddle是什么样的不重要,你要先看torch是什么样,再看paddle的处理方案。

这里我还是看不出两者的差异,你的映射文档写得不清楚,用户一眼看不出应如何转写

def generate_aux_code(self):
API_TEMPLATE = textwrap.dedent(
"""
import paddle
class StudentT_Aux_Class:
def __init__(self, df, loc, scale):
self.df = paddle.to_tensor(df)
self.loc = paddle.to_tensor(loc)
self.scale = paddle.to_tensor(scale)
self.sT = paddle.distribution.StudentT(self.df, self.loc, self.scale)
def sample(self):
return paddle.reshape(self.sT.sample(), self.df.shape)
"""
)

return API_TEMPLATE
def generate_code(self, kwargs):
self.write_aux_code()
if "validate_args" in kwargs:
del kwargs["validate_args"]
if "loc" not in kwargs:
kwargs["loc"] = 0.1
if "scale" not in kwargs:
kwargs["scale"] = 1.0
kwargs = self.kwargs_to_str(kwargs)
API_TEMPLATE = textwrap.dedent(
"""
paddle_aux.StudentT_Aux_Class({})
"""
)
code = API_TEMPLATE.format(kwargs)
return code


class TransformsPositiveDefiniteTransformMatcher(BaseMatcher):
def generate_aux_code(self):
API_TEMPLATE = textwrap.dedent(
"""
import paddle
class TransformsPositiveDefiniteTransform:
def __call__(self, x):
x = x.tril(-1) + x.diagonal(axis1=-2, axis2=-1).exp().diag_embed()
return x @ x.T

def inv(self, y):
y = paddle.linalg.cholesky(y)
return y.tril(-1) + y.diagonal(axis1=-2, axis2=-1).log().diag_embed()
"""
)

return API_TEMPLATE
def generate_code(self, kwargs):
self.write_aux_code()
API_TEMPLATE = textwrap.dedent(
"""
paddle_aux.TransformsPositiveDefiniteTransform()
"""
)
return API_TEMPLATE


class LKJCholeskyMatcher(BaseMatcher):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是什么差异,需要重写Matcher。我也没有在映射文档中看到任何描述这两个API的差异,不能直接对上用GenericMatcher吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

有,输出的维度不一样,需要补上一个

def generate_aux_code(self):
API_TEMPLATE = textwrap.dedent(
"""
import paddle
Copy link
Collaborator

@zhwesky2010 zhwesky2010 Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

你用func包class,包这两层的意义呢,直接定义一个class不更简洁吗

class LKJCholesky_Aux_Class:
def __init__(self, dim, concentration, sample_method='onion'):
self.lkj = paddle.distribution.LKJCholesky(dim, concentration, sample_method)
def sample(self):
return paddle.unsqueeze(self.lkj.sample(), axis=0)
"""
)

return API_TEMPLATE
def generate_code(self, kwargs):
self.write_aux_code()
if "validate_args" in kwargs:
del kwargs["validate_args"]
kwargs = self.kwargs_to_str(kwargs)
API_TEMPLATE = textwrap.dedent(
"""
paddle_aux.LKJCholesky_Aux_Class({})
"""
)
code = API_TEMPLATE.format(kwargs)
return code



class Is_InferenceMatcher(BaseMatcher):
def generate_code(self, kwargs):
if "input" not in kwargs:
kwargs["input"] = self.paddleClass
code = "{}.stop_gradient".format(kwargs["input"])
return code


class DistributionsConstrainMatcher(BaseMatcher):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

这里是什么差异,需要重写Matcher。我也没有在映射文档中看到任何描述这两个API的差异,不能直接对上用GenericMatcher吗?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

差异是一个是__call__,一个用的是call,是为了封装

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

差异是一个是__call__,一个用的是call,是为了封装

你在文档里需要写清楚差异

Copy link
Collaborator

@zhwesky2010 zhwesky2010 Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

infoflow 2024-10-22 15-34-33

我没有看出来为什么需要重写Matcher?

即使重写,为何需要先定义一个func,在func里再包一个class,这不是多此一举吗?你包这两层的意义呢,直接定义一个class不更简洁吗

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

我改改

Copy link
Contributor Author

@decade-afk decade-afk Oct 22, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

不包装一下,直接调用check会报错

def generate_aux_code(self):
API_TEMPLATE = textwrap.dedent(
"""
import paddle
class DistributionsConstrain:
def check(self, value):
return paddle.distribution.constraint.Constraint()(value)
"""
)

return API_TEMPLATE
def generate_code(self, kwargs):
self.write_aux_code()
API_TEMPLATE = textwrap.dedent(
"""
paddle_aux.DistributionsConstrain()
"""
)
return API_TEMPLATE


class IInfoMatcher(BaseMatcher):
Expand Down Expand Up @@ -5067,6 +5188,35 @@ def generate_code(self, kwargs):
return code


class RpcRemoteMatcher(BaseMatcher):
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

infoflow 2024-10-22 15-33-27

我没有看出来为什么需要重写Matcher。

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

因为rpc_async使用to_wait获取值,这个我改改文档

def generate_aux_code(self):
CODE_TEMPLATE = textwrap.dedent(
"""
class rpc_remote:
def __init__(self, remote_obj):
self.remote = remote_obj

def to_here(self):
return self.remote.wait()
"""
)
return CODE_TEMPLATE

def generate_code(self, kwargs):
self.write_aux_code()
kwargs['fn'] = kwargs.pop('func')
kwargs = self.kwargs_to_str(kwargs)
API_TEMPLATE = textwrap.dedent(
"""
paddle_aux.rpc_remote(paddle.distributed.rpc.rpc_async({}))
"""
)
code = API_TEMPLATE.format(
kwargs
)
return code


class GetNumThreadsMatcher(BaseMatcher):
def generate_code(self, kwargs):
API_TEMPLATE = textwrap.dedent(
Expand Down
Loading