-
Notifications
You must be signed in to change notification settings - Fork 99
Implement a ONNX to ONNX Script code generator based on libcst #873
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,60 @@ | ||
| # ------------------------------------------------------------------------- | ||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||
| # Licensed under the MIT License. | ||
| # -------------------------------------------------------------------------- | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import argparse | ||
| from pathlib import Path | ||
| from typing import BinaryIO, Protocol | ||
|
|
||
| from onnxscript.codeanalysis import onnx_to_onnxscript | ||
|
|
||
|
|
||
| class ConvertCommandArgs(Protocol): | ||
| onnx_model_reader: BinaryIO | ||
| onnxscript_writer: BinaryIO | ||
|
|
||
|
|
||
| def convert_command(args: ConvertCommandArgs): | ||
| args.onnxscript_writer.write( | ||
| onnx_to_onnxscript.Driver(args.onnx_model_reader).to_python_code( | ||
| None | ||
| if args.onnxscript_writer.name == "<stdout>" | ||
| else Path(args.onnxscript_writer.name) | ||
| ) | ||
| ) | ||
|
|
||
|
|
||
| def main(): | ||
| parser = argparse.ArgumentParser(prog="onnxscript") | ||
| subparsers = parser.add_subparsers(required=True) | ||
|
|
||
| parser_convert = subparsers.add_parser( | ||
| "convert", | ||
| help="Convert an ONNX model to ONNX Script Python code", | ||
| description="Convert an ONNX model to ONNX Script Python code", | ||
| ) | ||
| parser_convert.set_defaults(func=convert_command) | ||
| parser_convert.add_argument( | ||
| "onnx_model_reader", | ||
| metavar="ONNX_MODEL_FILE", | ||
| type=argparse.FileType("rb"), | ||
| ) | ||
| parser_convert.add_argument( | ||
| "--output", | ||
| dest="onnxscript_writer", | ||
| metavar="OUTPUT_FILE", | ||
| type=argparse.FileType("wb"), | ||
| help="file path for writing generated ONNX Script code", | ||
| default="-", | ||
| required=False, | ||
| ) | ||
|
|
||
| args = parser.parse_args() | ||
| args.func(args) | ||
|
|
||
|
|
||
| if __name__ == "__main__": | ||
| main() | ||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,215 @@ | ||
| # ------------------------------------------------------------------------- | ||
| # Copyright (c) Microsoft Corporation. All rights reserved. | ||
| # Licensed under the MIT License. | ||
| # -------------------------------------------------------------------------- | ||
| # pylint: disable=import-outside-toplevel | ||
| # pylint: disable=too-many-ancestors | ||
| # -------------------------------------------------------------------------- | ||
|
|
||
| from __future__ import annotations | ||
|
|
||
| import os | ||
| from collections import defaultdict | ||
| from dataclasses import dataclass | ||
| from pathlib import Path | ||
Check warningCode scanning / lintrunner RUFF/TID251 Warning
pathlib is banned: Using pathlib can impact performance. Use os.path instead.
See https://docs.astral.sh/ruff/rules/banned-api |
||
| from typing import Final, Protocol, Sequence, runtime_checkable | ||
|
|
||
| import libcst as cst | ||
Check failureCode scanning / lintrunner MYPY/import-not-found Error
Cannot find implementation or library stub for module named "libcst"
To disable, use # type: ignore[import-not-found]
|
||
| import libcst.matchers as cstm | ||
Check failureCode scanning / lintrunner MYPY/import-not-found Error
Cannot find implementation or library stub for module named "libcst.matchers"
To disable, use # type: ignore[import-not-found]
|
||
| import libcst.metadata as cstmeta | ||
Check failureCode scanning / lintrunner MYPY/import-not-found Error
Cannot find implementation or library stub for module named "libcst.metadata"
To disable, use # type: ignore[import-not-found]
|
||
|
|
||
| __all__ = [ | ||
| "format_code", | ||
| "make_name", | ||
| "make_import_alias", | ||
| "make_const_expr", | ||
| "RemoveUnusedImportsTransformer", | ||
| "CstCodeGenerator", | ||
| ] | ||
|
|
||
|
|
||
| def format_code(path: Path | None, code: bytes) -> bytes: | ||
| try: | ||
| import ufmt | ||
Check failureCode scanning / lintrunner MYPY/import-not-found Error
Cannot find implementation or library stub for module named "ufmt"
To disable, use # type: ignore[import-not-found]
|
||
|
|
||
| if path is None: | ||
| path = Path(os.curdir) | ||
|
|
||
| return ufmt.ufmt_bytes( | ||
| path, | ||
| code, | ||
| black_config=ufmt.util.make_black_config(path), | ||
| usort_config=ufmt.UsortConfig.find(path), | ||
| ) | ||
| except ImportError: | ||
| return code | ||
|
|
||
|
|
||
| def make_name(name: str) -> cst.Attribute | cst.Name: | ||
| tokens = name.split(".") | ||
| expr: cst.Name | cst.Attribute = cst.Name(tokens[0]) | ||
| for attr in tokens[1:]: | ||
| expr = cst.Attribute(expr, cst.Name(attr)) | ||
| return expr | ||
|
|
||
|
|
||
| def make_import_alias(name: str, asname: str | None = None) -> cst.ImportAlias: | ||
| return cst.ImportAlias( | ||
| name=make_name(name), | ||
| asname=cst.AsName(cst.Name(asname)) if asname else None, | ||
| ) | ||
|
|
||
|
|
||
| def make_const_expr(const: str | int | float) -> cst.BaseExpression: | ||
| negate = False | ||
| val: cst.Float | cst.Integer | ||
|
|
||
| if isinstance(const, str): | ||
| return cst.SimpleString('"' + const.replace('"', '\\"') + '"') | ||
| elif isinstance(const, int): | ||
| val = cst.Integer(str(abs(const))) | ||
| negate = const < 0 | ||
| elif isinstance(const, float): | ||
| val = cst.Float(str(abs(const))) | ||
| negate = const < 0 | ||
| else: | ||
| raise NotImplementedError(repr(const)) | ||
|
|
||
| if negate: | ||
| return cst.UnaryOperation( | ||
| operator=cst.Minus(), | ||
| expression=val, | ||
| ) | ||
|
|
||
| return val | ||
|
|
||
|
|
||
| @dataclass | ||
| class ImportAlias: | ||
| name: str | ||
| alias: str | None = None | ||
|
|
||
| def to_cst(self) -> cst.ImportAlias: | ||
| return cst.ImportAlias( | ||
| make_name(self.name), cst.AsName(cst.Name(self.alias)) if self.alias else None | ||
| ) | ||
|
|
||
|
|
||
| @dataclass | ||
| class Import: | ||
| module: ImportAlias | ||
|
|
||
| def to_cst(self) -> cst.Import: | ||
| return cst.Import(names=[self.module.to_cst()]) | ||
|
|
||
|
|
||
| @dataclass | ||
| class ImportFrom: | ||
| module: str | ||
| names: list[ImportAlias] | ||
|
|
||
| def to_cst(self) -> cst.ImportFrom: | ||
| return cst.ImportFrom( | ||
| module=make_name(self.module), | ||
| names=[name.to_cst() for name in self.names], | ||
| ) | ||
|
|
||
|
|
||
| @runtime_checkable | ||
| class ScopeAnalyzer(Protocol): | ||
| def analyze_scopes(self, scopes: set[cstmeta.Scope]): | ||
| pass | ||
|
|
||
|
|
||
| class RemoveUnusedImportsTransformer(cst.CSTTransformer, ScopeAnalyzer): | ||
| def __init__(self): | ||
| self.__unused_imports: dict[cst.Import | cst.ImportFrom, set[str]] = defaultdict(set) | ||
|
|
||
| def is_unused_allowed(self, node: cst.Import | cst.ImportFrom, name: str): | ||
| return name == "annotations" and cstm.matches( | ||
| node, cstm.ImportFrom(module=cstm.Name("__future__")) | ||
| ) | ||
|
|
||
| def analyze_scopes(self, scopes: set[cstmeta.Scope]): | ||
| for scope in scopes: | ||
| for assignment in scope.assignments: | ||
| if ( | ||
| isinstance(assignment, cstmeta.Assignment) | ||
| and isinstance(node := assignment.node, (cst.Import, cst.ImportFrom)) | ||
| and len(assignment.references) == 0 | ||
| and not self.is_unused_allowed(node, assignment.name) | ||
| ): | ||
| self.__unused_imports[node].add(assignment.name) | ||
|
|
||
| def __leave_import_alike( | ||
| self, | ||
| original_node: cst.Import | cst.ImportFrom, | ||
| updated_node: cst.Import | cst.ImportFrom, | ||
| ) -> cst.Import | cst.ImportFrom | cst.RemovalSentinel: | ||
| if original_node not in self.__unused_imports or isinstance( | ||
| updated_node.names, cst.ImportStar | ||
| ): | ||
| return updated_node | ||
|
|
||
| names_to_keep: list[cst.ImportAlias] = [] | ||
|
|
||
| for name in updated_node.names: | ||
| if name.asname is not None: | ||
| if not isinstance(name.asname, cst.Name): | ||
| continue | ||
| name_value = name.asname.name.value | ||
| else: | ||
| name_value = name.name.value | ||
| if name_value not in self.__unused_imports[original_node]: | ||
| names_to_keep.append(name.with_changes(comma=cst.MaybeSentinel.DEFAULT)) | ||
|
|
||
| if len(names_to_keep) == 0: | ||
| return cst.RemoveFromParent() | ||
|
|
||
| return updated_node.with_changes(names=names_to_keep) | ||
|
|
||
| def leave_Import(self, original_node: cst.Import, updated_node: cst.Import): | ||
| return self.__leave_import_alike(original_node, updated_node) | ||
|
|
||
| def leave_ImportFrom(self, original_node: cst.ImportFrom, updated_node: cst.ImportFrom): | ||
| return self.__leave_import_alike(original_node, updated_node) | ||
|
|
||
|
|
||
| class CstCodeGenerator: | ||
| def __init__(self): | ||
| self.__imports: Final[list[Import | ImportFrom]] = [] | ||
|
|
||
| def add_import(self, module: str, alias: str | None = None): | ||
| if not any( | ||
| isinstance(imp, Import) and imp.module.name == module and imp.module.alias == alias | ||
| for imp in self.__imports | ||
| ): | ||
| self.__imports.append(Import(ImportAlias(module, alias))) | ||
|
|
||
| def add_import_from(self, module: str, name: str, alias: str | None = None): | ||
| for imp in self.__imports: | ||
| if isinstance(imp, ImportFrom) and imp.module == module: | ||
| for existing in imp.names: | ||
| if existing.name == name and existing.alias == alias: | ||
| return | ||
| imp.names.append(ImportAlias(name, alias)) | ||
| return | ||
| self.__imports.append(ImportFrom(module, [ImportAlias(name, alias)])) | ||
|
|
||
| def make_import_statements(self) -> Sequence[cst.SimpleStatementLine]: | ||
| return [cst.SimpleStatementLine(body=[imp.to_cst()]) for imp in self.__imports] | ||
|
|
||
| def apply_transformers( | ||
| self, module: cst.Module, transformers: Sequence[cst.CSTTransformer] | ||
| ) -> cst.Module: | ||
| for transformer in transformers: | ||
| wrapper = cstmeta.MetadataWrapper(module) | ||
| if isinstance(transformer, ScopeAnalyzer): | ||
| scopes = { | ||
| scope | ||
| for scope in wrapper.resolve(cstmeta.ScopeProvider).values() | ||
| if scope is not None | ||
| } | ||
| transformer.analyze_scopes(scopes) | ||
| module = wrapper.visit(transformer) | ||
| return module | ||
Check warning
Code scanning / lintrunner
RUFF/TID251 Warning