-
Notifications
You must be signed in to change notification settings - Fork 53
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
【Hackathon 7th No.42】NO.42 为 Paddle 代码转换工具新增 API 转换规则(第 9 组) #490
Changes from 9 commits
5d46f40
5fed93f
587e426
6389a93
c2ddba4
4041675
bb52abd
286b66b
b780348
e402151
cfb2a60
d1bfb6a
1f096fb
09be172
47bd025
e6de29b
f6d64fd
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -436,6 +436,72 @@ def generate_code(self, kwargs): | |
return super().generate_code(kwargs) | ||
|
||
|
||
class IsSparseCsrMatcher(BaseMatcher): | ||
def get_paddle_class_attribute_nodes(self, node): | ||
self.parse_func(node) | ||
code = "{}.is_sparse_csr()".format(self.paddleClass) | ||
return ast.parse(code).body | ||
|
||
|
||
class TensorStrideMatcher(BaseMatcher): | ||
def get_paddle_nodes(self, args, kwargs): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这个应该不用重写更底层的get_paddle_nodes,不涉及到可变参数这些,重写generate_code就行 |
||
kwargs = self.parse_kwargs(kwargs) | ||
args = self.parse_args(args) | ||
if "dim" in kwargs.keys() and kwargs["dim"] != "None": | ||
code = "{}.get_strides()[{}]".format(self.paddleClass, kwargs["dim"]) | ||
elif len(args) > 0 and args[0] != "None": | ||
code = "{}.get_strides()[{}]".format(self.paddleClass, args[0]) | ||
else: | ||
code = "{}.get_strides()".format(self.paddleClass) | ||
|
||
return ast.parse(code).body | ||
|
||
|
||
class TensorToSparseCooMatcher(BaseMatcher): | ||
def generate_code(self, kwargs): | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
{}.to_sparse_coo(len({}.shape)) | ||
""" | ||
) | ||
code = API_TEMPLATE.format(self.paddleClass, self.paddleClass) | ||
return code | ||
|
||
|
||
class TensorNbytesMatcher(BaseMatcher): | ||
def get_paddle_class_attribute_nodes(self, node): | ||
self.parse_func(node) | ||
code = "int(paddle.numel({}) * {}.element_size())".format( | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 按文档改组合方式 |
||
self.paddleClass, self.paddleClass | ||
) | ||
return ast.parse(code).body | ||
|
||
|
||
class DimOrderMatcher(BaseMatcher): | ||
def generate_code(self, kwargs): | ||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
tuple([i for i in range(len({}.shape))]) | ||
""" | ||
) | ||
code = API_TEMPLATE.format(self.paddleClass) | ||
return code | ||
|
||
|
||
class TensorItemsizeMatcher(BaseMatcher): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. 这么一个简单的示例,无需重写Matcher吧,用GenericMatcher就行 |
||
def get_paddle_class_attribute_nodes(self, node): | ||
self.parse_func(node) | ||
paddle_class = self.paddleClass | ||
|
||
API_TEMPLATE = textwrap.dedent( | ||
""" | ||
{}.element_size() | ||
""" | ||
) | ||
code = API_TEMPLATE.format(paddle_class) | ||
return ast.parse(code).body | ||
|
||
|
||
class TRFMPreTrainedTokenizerMatcher(BaseMatcher): | ||
def generate_aux_code(self): | ||
CODE_TEMPLATE = textwrap.dedent( | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
用 Attribute2Func 这个Matcher就行