Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

workaround a bug with the ast unparse package and Python 3.8 #738

Merged
merged 1 commit into from
Jun 18, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 23 additions & 1 deletion lale/expressions.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@
from typing import Any, Dict, Optional, Union

import astunparse
from six.moves import cStringIO

AstLits = (ast.Num, ast.Str, ast.List, ast.Tuple, ast.Set, ast.Dict)
AstLit = Union[ast.Num, ast.Str, ast.List, ast.Tuple, ast.Set, ast.Dict]
Expand Down Expand Up @@ -48,6 +49,27 @@
]


# !! WORKAROUND !!
# There is a bug with astunparse and Python 3.8.
# https://github.com/simonpercivall/astunparse/issues/43
# Until it is fixed (which may be never), here is a workaround,
# based on the workaround found in https://github.com/juanlao7/codeclose
class FixUnparser(astunparse.Unparser):
def _Constant(self, t):
if not hasattr(t, "kind"):
setattr(t, "kind", None)

super()._Constant(t)


# !! WORKAROUND !!
# This method should be called instead of astunparse.unparse
def fixedUnparse(tree):
v = cStringIO()
FixUnparser(tree, file=v)
return v.getvalue()


class Expr:
_expr: AstExpr

Expand Down Expand Up @@ -104,7 +126,7 @@ def __getitem__(self, key: Union[int, str, slice]) -> "Expr":
return Expr(subscript)

def __str__(self) -> str:
result = astunparse.unparse(self._expr).strip()
result = fixedUnparse(self._expr).strip()
if isinstance(self._expr, (ast.UnaryOp, ast.BinOp, ast.Compare, ast.BoolOp)):
if result.startswith("(") and result.endswith(")"):
result = result[1:-1]
Expand Down
13 changes: 13 additions & 0 deletions test/test_core_misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,19 @@ def test_transformers(self):
self.assertNotIn("MLPClassifier", ops_names)


class TestUnparseExpr(unittest.TestCase):
def test_unparse_const38(self):
import lale.expressions
from lale.expressions import it

test_expr = it.hello["hi"]
# This fails on 3.8 with some versions of the library
# which is why we use the fixed version
# import astunparse
# astunparse.unparse(he._expr)
str(lale.expressions.fixedUnparse(test_expr._expr))


class TestOperatorWithoutSchema(unittest.TestCase):
def test_trainable_pipe_left(self):
from sklearn.decomposition import PCA
Expand Down