From d3e74a56648615d0090017910a71dff7c39e498e Mon Sep 17 00:00:00 2001 From: pwwang Date: Fri, 7 Oct 2022 23:22:13 -0700 Subject: [PATCH 1/5] =?UTF-8?q?=F0=9F=90=9B=20Patch=20classes=20if=20they?= =?UTF-8?q?=20have=20piping=20operator=20method?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pipda/__init__.py | 2 +- pipda/piping.py | 176 +++++++++++++++++++++++++++++++++++++++++++ pipda/utils.py | 3 +- pipda/verb.py | 32 -------- poetry.lock | 97 ++++++++++++++++++++++-- pyproject.toml | 7 +- tests/test_piping.py | 153 +++++++++++++++++++++++++++++++++++++ tests/test_verb.py | 23 ------ 8 files changed, 430 insertions(+), 63 deletions(-) create mode 100644 pipda/piping.py create mode 100644 tests/test_piping.py diff --git a/pipda/__init__.py b/pipda/__init__.py index 20acbe5..e4909dd 100644 --- a/pipda/__init__.py +++ b/pipda/__init__.py @@ -8,8 +8,8 @@ from .verb import ( Verb, VerbCall, - register_piping, register_verb, ) +from .piping import register_piping __version__ = "0.7.6" diff --git a/pipda/piping.py b/pipda/piping.py new file mode 100644 index 0000000..6b597c8 --- /dev/null +++ b/pipda/piping.py @@ -0,0 +1,176 @@ +import ast +import functools +from typing import Type + +from .verb import VerbCall + +PIPING_OPS = { + # op: (method, ast node) + ">>": ("__rrshift__", ast.RShift), + "|": ("__ror__", ast.BitOr), + "//": ("__rfloordiv__", ast.FloorDiv), + "@": ("__rmatmul__", ast.MatMult), + "%": ("__rmod__", ast.Mod), + "&": ("__rand__", ast.BitAnd), + "^": ("__rxor__", ast.BitXor), +} + +PATCHED_CLASSES = { + # kls: + # {} # registered but not patched + # {"method": , "imethod": } # patched +} + + +def _patch_cls_method(kls: Type, method: str) -> None: + """Borrowed from https://github.com/sspipe/sspipe""" + try: + original = getattr(kls, method) + except AttributeError: + return + + PATCHED_CLASSES[kls][method] = original + + @functools.wraps(original) + def wrapper(self, x, *args, **kwargs): + if isinstance(x, VerbCall): + return NotImplemented + return original(self, x, *args, **kwargs) + + setattr(kls, method, wrapper) + + +def _unpatch_cls_method(kls: Type, method: str) -> None: + if method in PATCHED_CLASSES[kls]: + setattr(kls, method, PATCHED_CLASSES[kls].pop(method)) + + +def _patch_cls_operator(kls: Type, op: str) -> None: + method = PIPING_OPS[op][0].replace("__r", "__") + imethod = PIPING_OPS[op][0].replace("__r", "__i") + _patch_cls_method(kls, method) + _patch_cls_method(kls, imethod) + + +def _unpatch_cls_operator(kls: Type, op: str) -> None: + method = PIPING_OPS[op][0].replace("__r", "__") + imethod = PIPING_OPS[op][0].replace("__r", "__i") + _unpatch_cls_method(kls, method) + _unpatch_cls_method(kls, imethod) + + +def patch_classes(*classes: Type) -> None: + """Patch the classes in case it has piping operator defined + + For example, DataFrame.__or__ has already been defined, so we need to + patch it to force it to use __ror__ of VerbCall if `|` is registered + for piping. + + Args: + classes: The classes to patch + """ + for kls in classes: + if kls not in PATCHED_CLASSES: + PATCHED_CLASSES[kls] = {} + + if not PATCHED_CLASSES[kls]: + _patch_cls_operator(kls, VerbCall.PIPING) + + +def unpatch_classes(*classes: Type) -> None: + """Unpatch the classes + + Args: + classes: The classes to unpatch + """ + for kls in classes: + if PATCHED_CLASSES[kls]: + _unpatch_cls_operator(kls, VerbCall.PIPING) + # Don't patch it in the future + del PATCHED_CLASSES[kls] + + +def _patch_all(op: str) -> None: + """Patch all registered classes that has the operator defined + + Args: + op: The operator used for piping + Avaiable: ">>", "|", "//", "@", "%", "&" and "^" + un: Unpatch the classes + """ + for kls in PATCHED_CLASSES: + _patch_cls_operator(kls, op) + + +def _unpatch_all(op: str) -> None: + """Unpatch all registered classes + + Args: + op: The operator used for piping + Avaiable: ">>", "|", "//", "@", "%", "&" and "^" + """ + for kls in PATCHED_CLASSES: + _unpatch_cls_operator(kls, op) + + +def _patch_default_classes() -> None: + """Patch the default/commonly used classes""" + try: + import pandas + patch_classes( + pandas.DataFrame, + pandas.Series, + pandas.Index, + pandas.Categorical, + ) + except ImportError: + pass + + try: # pragma: no cover + from modin import pandas + patch_classes( + pandas.DataFrame, + pandas.Series, + pandas.Index, + pandas.Categorical, + ) + except ImportError: + pass + + try: # pragma: no cover + import torch + patch_classes(torch.Tensor) + except ImportError: + pass + + try: # pragma: no cover + from django.db.models import query + patch_classes(query.QuerySet) + except ImportError: + pass + + +def register_piping(op: str) -> None: + """Register the piping operator for verbs + + Args: + op: The operator used for piping + Avaiable: ">>", "|", "//", "@", "%", "&" and "^" + """ + if op not in PIPING_OPS: + raise ValueError(f"Unsupported piping operator: {op}") + + if VerbCall.PIPING: + orig_method = VerbCall.__orig_opmethod__ + curr_method = PIPING_OPS[VerbCall.PIPING][0] + setattr(VerbCall, curr_method, orig_method) + _unpatch_all(VerbCall.PIPING) + + VerbCall.PIPING = op + VerbCall.__orig_opmethod__ = getattr(VerbCall, PIPING_OPS[op][0]) + setattr(VerbCall, PIPING_OPS[op][0], VerbCall._pipda_eval) + _patch_all(op) + + +register_piping(">>") +_patch_default_classes() diff --git a/pipda/utils.py b/pipda/utils.py index 13cb794..aad471c 100644 --- a/pipda/utils.py +++ b/pipda/utils.py @@ -53,7 +53,8 @@ def is_piping_verbcall(verb: str, fallback: str) -> bool: True if it is a piping verb call, otherwise False """ from executing import Source - from .verb import PIPING_OPS, VerbCall + from .verb import VerbCall + from .piping import PIPING_OPS frame = sys._getframe(2) node = Source.executing(frame).node diff --git a/pipda/verb.py b/pipda/verb.py index f9719e4..516d5b5 100644 --- a/pipda/verb.py +++ b/pipda/verb.py @@ -1,7 +1,6 @@ """Provide verb definition""" from __future__ import annotations -import ast from enum import Enum from typing import ( TYPE_CHECKING, @@ -29,16 +28,6 @@ from inspect import Signature from .context import ContextType -PIPING_OPS = { - ">>": ("__rrshift__", ast.RShift), - "|": ("__ror__", ast.BitOr), - "//": ("__rfloordiv__", ast.FloorDiv), - "@": ("__rmatmul__", ast.MatMult), - "%": ("__rmod__", ast.Mod), - "&": ("__rand__", ast.BitAnd), - "^": ("__rxor__", ast.BitXor), -} - class VerbCall(Expression): """A verb call @@ -298,24 +287,3 @@ def register_verb( dep=dep, ast_fallback=ast_fallback, ) - - -def register_piping(op: str) -> None: - """Register the piping operator for verbs - - Args: - op: The operator used for piping - Avaiable: ">>", "|", "//", "@", "%", "&" and "^" - """ - if op not in PIPING_OPS: - raise ValueError(f"Unsupported piping operator: {op}") - - if VerbCall.PIPING: - curr_method = PIPING_OPS[VerbCall.PIPING][0] - delattr(VerbCall, curr_method) - - VerbCall.PIPING = op - setattr(VerbCall, PIPING_OPS[op][0], VerbCall._pipda_eval) - - -register_piping(">>") diff --git a/poetry.lock b/poetry.lock index cfa9ee0..2b59a1c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -89,6 +89,27 @@ python-versions = ">=3.6" [package.dependencies] pyparsing = ">=2.0.2,<3.0.5 || >3.0.5" +[[package]] +name = "pandas" +version = "1.3.5" +description = "Powerful data structures for data analysis, time series, and statistics" +category = "dev" +optional = false +python-versions = ">=3.7.1" + +[package.dependencies] +numpy = [ + {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, + {version = ">=1.17.3", markers = "platform_machine != \"aarch64\" and platform_machine != \"arm64\" and python_version < \"3.10\""}, + {version = ">=1.19.2", markers = "platform_machine == \"aarch64\" and python_version < \"3.10\""}, + {version = ">=1.20.0", markers = "platform_machine == \"arm64\" and python_version < \"3.10\""}, +] +python-dateutil = ">=2.7.3" +pytz = ">=2017.3" + +[package.extras] +test = ["hypothesis (>=3.58)", "pytest (>=6.0)", "pytest-xdist"] + [[package]] name = "pluggy" version = "1.0.0" @@ -159,6 +180,33 @@ pytest = ">=4.6" [package.extras] testing = ["fields", "hunter", "process-tests", "pytest-xdist", "six", "virtualenv"] +[[package]] +name = "python-dateutil" +version = "2.8.2" +description = "Extensions to the standard Python datetime module" +category = "dev" +optional = false +python-versions = "!=3.0.*,!=3.1.*,!=3.2.*,>=2.7" + +[package.dependencies] +six = ">=1.5" + +[[package]] +name = "pytz" +version = "2022.4" +description = "World timezone definitions, modern and historical" +category = "dev" +optional = false +python-versions = "*" + +[[package]] +name = "six" +version = "1.16.0" +description = "Python 2 and 3 compatibility utilities" +category = "dev" +optional = false +python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*" + [[package]] name = "tomli" version = "2.0.1" @@ -169,7 +217,7 @@ python-versions = ">=3.7" [[package]] name = "typing-extensions" -version = "4.3.0" +version = "4.4.0" description = "Backported and Experimental Type Hints for Python 3.7+" category = "dev" optional = false @@ -189,8 +237,8 @@ testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>= [metadata] lock-version = "1.1" -python-versions = "^3.7" -content-hash = "514724306a432a2ff1fd593dd24aa5cf239e933a7f383d090a870a607fc8bccf" +python-versions = "^3.7.1" +content-hash = "2c249d7329310a2795b598cbb2e8dced0a71bd66bd561ca72228c966e4cf0bbf" [metadata.files] attrs = [ @@ -299,6 +347,33 @@ packaging = [ {file = "packaging-21.3-py3-none-any.whl", hash = "sha256:ef103e05f519cdc783ae24ea4e2e0f508a9c99b2d4969652eed6a2e1ea5bd522"}, {file = "packaging-21.3.tar.gz", hash = "sha256:dd47c42927d89ab911e606518907cc2d3a1f38bbd026385970643f9c5b8ecfeb"}, ] +pandas = [ + {file = "pandas-1.3.5-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:62d5b5ce965bae78f12c1c0df0d387899dd4211ec0bdc52822373f13a3a022b9"}, + {file = "pandas-1.3.5-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:adfeb11be2d54f275142c8ba9bf67acee771b7186a5745249c7d5a06c670136b"}, + {file = "pandas-1.3.5-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:60a8c055d58873ad81cae290d974d13dd479b82cbb975c3e1fa2cf1920715296"}, + {file = "pandas-1.3.5-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:fd541ab09e1f80a2a1760032d665f6e032d8e44055d602d65eeea6e6e85498cb"}, + {file = "pandas-1.3.5-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2651d75b9a167cc8cc572cf787ab512d16e316ae00ba81874b560586fa1325e0"}, + {file = "pandas-1.3.5-cp310-cp310-win_amd64.whl", hash = "sha256:aaf183a615ad790801fa3cf2fa450e5b6d23a54684fe386f7e3208f8b9bfbef6"}, + {file = "pandas-1.3.5-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:344295811e67f8200de2390093aeb3c8309f5648951b684d8db7eee7d1c81fb7"}, + {file = "pandas-1.3.5-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:552020bf83b7f9033b57cbae65589c01e7ef1544416122da0c79140c93288f56"}, + {file = "pandas-1.3.5-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:5cce0c6bbeb266b0e39e35176ee615ce3585233092f685b6a82362523e59e5b4"}, + {file = "pandas-1.3.5-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:7d28a3c65463fd0d0ba8bbb7696b23073efee0510783340a44b08f5e96ffce0c"}, + {file = "pandas-1.3.5-cp37-cp37m-win32.whl", hash = "sha256:a62949c626dd0ef7de11de34b44c6475db76995c2064e2d99c6498c3dba7fe58"}, + {file = "pandas-1.3.5-cp37-cp37m-win_amd64.whl", hash = "sha256:8025750767e138320b15ca16d70d5cdc1886e8f9cc56652d89735c016cd8aea6"}, + {file = "pandas-1.3.5-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:fe95bae4e2d579812865db2212bb733144e34d0c6785c0685329e5b60fcb85dd"}, + {file = "pandas-1.3.5-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5f261553a1e9c65b7a310302b9dbac31cf0049a51695c14ebe04e4bfd4a96f02"}, + {file = "pandas-1.3.5-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8b6dbec5f3e6d5dc80dcfee250e0a2a652b3f28663492f7dab9a24416a48ac39"}, + {file = "pandas-1.3.5-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d3bc49af96cd6285030a64779de5b3688633a07eb75c124b0747134a63f4c05f"}, + {file = "pandas-1.3.5-cp38-cp38-win32.whl", hash = "sha256:b6b87b2fb39e6383ca28e2829cddef1d9fc9e27e55ad91ca9c435572cdba51bf"}, + {file = "pandas-1.3.5-cp38-cp38-win_amd64.whl", hash = "sha256:a395692046fd8ce1edb4c6295c35184ae0c2bbe787ecbe384251da609e27edcb"}, + {file = "pandas-1.3.5-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:bd971a3f08b745a75a86c00b97f3007c2ea175951286cdda6abe543e687e5f2f"}, + {file = "pandas-1.3.5-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:37f06b59e5bc05711a518aa10beaec10942188dccb48918bb5ae602ccbc9f1a0"}, + {file = "pandas-1.3.5-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2c21778a688d3712d35710501f8001cdbf96eb70a7c587a3d5613573299fdca6"}, + {file = "pandas-1.3.5-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:3345343206546545bc26a05b4602b6a24385b5ec7c75cb6059599e3d56831da2"}, + {file = "pandas-1.3.5-cp39-cp39-win32.whl", hash = "sha256:c69406a2808ba6cf580c2255bcf260b3f214d2664a3a4197d0e640f573b46fd3"}, + {file = "pandas-1.3.5-cp39-cp39-win_amd64.whl", hash = "sha256:32e1a26d5ade11b547721a72f9bfc4bd113396947606e00d5b4a5b79b3dcb006"}, + {file = "pandas-1.3.5.tar.gz", hash = "sha256:1e4285f5de1012de20ca46b188ccf33521bff61ba5c5ebd78b4fb28e5416a9f1"}, +] pluggy = [ {file = "pluggy-1.0.0-py2.py3-none-any.whl", hash = "sha256:74134bbf457f031a36d68416e1509f34bd5ccc019f0bcc952c7b909d06b37bd3"}, {file = "pluggy-1.0.0.tar.gz", hash = "sha256:4224373bacce55f955a878bf9cfa763c1e360858e330072059e10bad68531159"}, @@ -319,13 +394,25 @@ pytest-cov = [ {file = "pytest-cov-3.0.0.tar.gz", hash = "sha256:e7f0f5b1617d2210a2cabc266dfe2f4c75a8d32fb89eafb7ad9d06f6d076d470"}, {file = "pytest_cov-3.0.0-py3-none-any.whl", hash = "sha256:578d5d15ac4a25e5f961c938b85a05b09fdaae9deef3bb6de9a6e766622ca7a6"}, ] +python-dateutil = [ + {file = "python-dateutil-2.8.2.tar.gz", hash = "sha256:0123cacc1627ae19ddf3c27a5de5bd67ee4586fbdd6440d9748f8abb483d3e86"}, + {file = "python_dateutil-2.8.2-py2.py3-none-any.whl", hash = "sha256:961d03dc3453ebbc59dbdea9e4e11c5651520a876d0f4db161e8674aae935da9"}, +] +pytz = [ + {file = "pytz-2022.4-py2.py3-none-any.whl", hash = "sha256:2c0784747071402c6e99f0bafdb7da0fa22645f06554c7ae06bf6358897e9c91"}, + {file = "pytz-2022.4.tar.gz", hash = "sha256:48ce799d83b6f8aab2020e369b627446696619e79645419610b9facd909b3174"}, +] +six = [ + {file = "six-1.16.0-py2.py3-none-any.whl", hash = "sha256:8abb2f1d86890a2dfb989f9a77cfcfd3e47c2a354b01111771326f8aa26e0254"}, + {file = "six-1.16.0.tar.gz", hash = "sha256:1e61c37477a1626458e36f7b1d82aa5c9b094fa4802892072e49de9c60c4c926"}, +] tomli = [ {file = "tomli-2.0.1-py3-none-any.whl", hash = "sha256:939de3e7a6161af0c887ef91b7d41a53e7c5a1ca976325f429cb46ea9bc30ecc"}, {file = "tomli-2.0.1.tar.gz", hash = "sha256:de526c12914f0c550d15924c62d72abc48d6fe7364aa87328337a31007fe8a4f"}, ] typing-extensions = [ - {file = "typing_extensions-4.3.0-py3-none-any.whl", hash = "sha256:25642c956049920a5aa49edcdd6ab1e06d7e5d467fc00e0506c44ac86fbfca02"}, - {file = "typing_extensions-4.3.0.tar.gz", hash = "sha256:e6d2677a32f47fc7eb2795db1dd15c1f34eff616bcaf2cfb5e997f854fa1c4a6"}, + {file = "typing_extensions-4.4.0-py3-none-any.whl", hash = "sha256:16fa4864408f655d35ec496218b85f79b3437c829e93320c7c9215ccfd92489e"}, + {file = "typing_extensions-4.4.0.tar.gz", hash = "sha256:1511434bb92bf8dd198c12b1cc812e800d4181cfcb867674e0f8279cc93087aa"}, ] zipp = [ {file = "zipp-3.8.1-py3-none-any.whl", hash = "sha256:47c40d7fe183a6f21403a199b3e4192cca5774656965b0a4988ad2f8feb5f009"}, diff --git a/pyproject.toml b/pyproject.toml index 924b8d3..d7639d3 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -7,13 +7,14 @@ authors = ["pwwang "] license = "MIT" [tool.poetry.dependencies] -python = "^3.7" +python = "^3.7.1" executing = "^1.0" [tool.poetry.dev-dependencies] pytest = "^7" pytest-cov = "^3" numpy = "^1.20" +pandas = "^1.3" [build-system] requires = ["poetry>=0.12"] @@ -30,6 +31,10 @@ strict_optional = false addopts = "-vv -W error::UserWarning -p no:asyncio --cov-config=.coveragerc --cov=pipda --cov-report xml:.coverage.xml --cov-report term-missing" console_output_style = "progress" junit_family = "xunit1" +filterwarnings = [ + # "error", + "ignore::DeprecationWarning", +] [tool.black] line-length = 80 diff --git a/tests/test_piping.py b/tests/test_piping.py new file mode 100644 index 0000000..6f11d6d --- /dev/null +++ b/tests/test_piping.py @@ -0,0 +1,153 @@ +import pytest +from pipda.operator import OperatorCall +from pipda.verb import register_verb +from pipda.piping import ( + register_piping, + patch_classes, + unpatch_classes, + PATCHED_CLASSES, +) + + +def test_register_piping(): + + @register_verb(int) + def incre(x): + return x + 1 + + out = 1 >> incre() + assert out == 2 and isinstance(out, int) + + register_piping("|") + with pytest.raises(TypeError): + 1 >> incre() + out = 1 | incre() + assert out == 2 and isinstance(out, int) + + register_piping(">>") + out = 1 >> incre() + assert out == 2 and isinstance(out, int) + + with pytest.raises(ValueError): + register_piping("123") + + +def test_patching(): + + class Data: + def __init__(self, x): + self.x = x + + def __rshift__(self, other): + return self.x + other + + @register_verb(Data) + def incre(x): + return x.x + 1 + + assert Data(1) >> 2 == 3 + out = Data(1) >> incre() + # __rshift__ is not patched + assert isinstance(out, OperatorCall) + + rshift = Data.__rshift__ + patch_classes(Data) + assert PATCHED_CLASSES[Data]["__rshift__"] is rshift + + out = Data(1) >> incre() + assert out == 2 and isinstance(out, int) + + assert Data.__rshift__ is not rshift + # But original __rshift__ still works + assert Data(1) >> 2 == 3 + + unpatch_classes(Data) + assert Data.__rshift__ is rshift + + # And the original __rshift__ still works + assert Data(1) >> 2 == 3 + # back to original + out = Data(1) >> incre() + assert isinstance(out, OperatorCall) + + register_piping("|") + # works without patching class as Data has no __or__ + out = Data(1) | incre() + assert out == 2 and isinstance(out, int) + + register_piping(">>") + # Since Data is unregistered + out = Data(1) >> incre() + assert isinstance(out, OperatorCall) + + +def test_patching_pandas(): + + import pandas as pd + + @register_verb(pd.DataFrame) + def incre(x): + return x + 1 + + df = pd.DataFrame({"a": [1, 2, 3]}) + out = df >> incre() + assert out.equals(pd.DataFrame({"a": [2, 3, 4]})) + + out = df | 1 + assert out.equals(pd.DataFrame({"a": [1, 3, 3]})) + + with pytest.raises(TypeError): + df | incre() + + register_piping("|") + out = df | incre() + assert out.equals(pd.DataFrame({"a": [2, 3, 4]})) + # Original still works + out = df | 1 + assert out.equals(pd.DataFrame({"a": [1, 3, 3]})) + + # Restore it for other tests + register_piping(">>") + + +def test_imethod(): + + @register_verb(int) + def incre(x): + return x + 1 + + a = 1 + a >>= incre() + assert a == 2 + + register_piping("|") + a = 1 + a |= incre() + assert a == 2 + + register_piping(">>") + + +def test_patch_imethod(): + + class Data: + def __init__(self, x): + self.x = x + + def __irshift__(self, other): + return self.x * other + + @register_verb(Data) + def incre(x): + return x.x + 1 + + a = Data(1) + a >>= incre() + assert a == 2 + + register_piping("|") + a = Data(1) + a |= incre() + assert a == 2 + + register_piping(">>") diff --git a/tests/test_verb.py b/tests/test_verb.py index 6e6ec70..ec50f5b 100644 --- a/tests/test_verb.py +++ b/tests/test_verb.py @@ -202,29 +202,6 @@ def length(data): length(1, 2) -def test_register_piping(): - - @register_verb(int) - def incre(x): - return x + 1 - - out = 1 >> incre() - assert out == 2 and isinstance(out, int) - - register_piping("|") - with pytest.raises(TypeError): - 1 >> incre() - out = 1 | incre() - assert out == 2 and isinstance(out, int) - - register_piping(">>") - out = 1 >> incre() - assert out == 2 and isinstance(out, int) - - with pytest.raises(ValueError): - register_piping("123") - - def test_registered(): @register_verb(int) From 6894a604265c093045a024edfa862f42a2cafc40 Mon Sep 17 00:00:00 2001 From: pwwang Date: Sat, 8 Oct 2022 00:58:11 -0700 Subject: [PATCH 2/5] =?UTF-8?q?=E2=9C=A8=20Auto-register=20numpy=20ufuncs?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pipda/expression.py | 33 +++++++++++++++++++++++++++++++-- pipda/function.py | 26 +++++++++++++++++++------- pipda/piping.py | 17 +++++++++-------- pipda/verb.py | 29 ++++++++++++++++++++--------- tests/test_expression.py | 19 +++++++++++++++++++ tests/test_function.py | 14 ++++++++++++++ tests/test_verb.py | 19 +++++++++++++++++-- 7 files changed, 129 insertions(+), 28 deletions(-) diff --git a/pipda/expression.py b/pipda/expression.py index db1ef9f..74478e3 100644 --- a/pipda/expression.py +++ b/pipda/expression.py @@ -3,12 +3,13 @@ from abc import ABC, abstractmethod from functools import partialmethod -from typing import Any, TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Callable from .context import ContextBase if TYPE_CHECKING: from .operator import OperatorCall + from .function import FunctionCall from .reference import ReferenceAttr, ReferenceItem OPERATORS = { @@ -55,7 +56,35 @@ class Expression(ABC): """The abstract Expression class""" _pipda_operator = None - __array_ufunc__ = None + + def __array_ufunc__( + self, + ufunc: Callable, + method: str, + *inputs: Any, + **kwargs: Any, + ) -> FunctionCall: + """Allow numpy ufunc to work on Expression objects""" + + from .piping import PIPING_OPS + from .verb import VerbCall + + if ( + ufunc.__name__ == PIPING_OPS[VerbCall.PIPING][2] + and isinstance(inputs[1], VerbCall) + and len(inputs) == 2 + and method == "__call__" + ): + # We can't patch numpy.ndarray + return inputs[1]._pipda_eval(inputs[0]) + + from .function import Function, FunctionCall + + if method == "reduce": + ufunc = ufunc.reduce + + fun = Function(ufunc, None, {}) + return FunctionCall(fun, *inputs, **kwargs) def __hash__(self) -> int: """Make it hashable""" diff --git a/pipda/function.py b/pipda/function.py index c6ffceb..0885105 100644 --- a/pipda/function.py +++ b/pipda/function.py @@ -65,15 +65,27 @@ def _pipda_eval(self, data: Any, context: ContextType = None) -> Any: }, ) - bound = func.bind_arguments(*self._pipda_args, **self._pipda_kwargs) context = func.contexts["_"] or context - extra_contexts = func.extra_contexts - for key, val in bound.arguments.items(): - ctx = extra_contexts["_"].get(key, context) - val = evaluate_expr(val, data, ctx) - bound.arguments[key] = val + extra_contexts = func.extra_contexts["_"] - return func.func(*bound.args, **bound.kwargs) + if extra_contexts: + bound = func.bind_arguments(*self._pipda_args, **self._pipda_kwargs) + + for key, val in bound.arguments.items(): + ctx = extra_contexts.get(key, context) + val = evaluate_expr(val, data, ctx) + bound.arguments[key] = val + + return func.func(*bound.args, **bound.kwargs) + + # we don't need signature if there is no extra context + return func.func( + *(evaluate_expr(arg, data, context) for arg in self._pipda_args), + **{ + key: evaluate_expr(val, data, context) + for key, val in self._pipda_kwargs.items() + }, + ) class Registered(ABC): diff --git a/pipda/piping.py b/pipda/piping.py index 6b597c8..04cf27c 100644 --- a/pipda/piping.py +++ b/pipda/piping.py @@ -5,14 +5,14 @@ from .verb import VerbCall PIPING_OPS = { - # op: (method, ast node) - ">>": ("__rrshift__", ast.RShift), - "|": ("__ror__", ast.BitOr), - "//": ("__rfloordiv__", ast.FloorDiv), - "@": ("__rmatmul__", ast.MatMult), - "%": ("__rmod__", ast.Mod), - "&": ("__rand__", ast.BitAnd), - "^": ("__rxor__", ast.BitXor), + # op: (method, ast node, numpy ufunc name) + ">>": ("__rrshift__", ast.RShift, "right_shift"), + "|": ("__ror__", ast.BitOr, "bitwise_or"), + "//": ("__rfloordiv__", ast.FloorDiv, "floor_divide"), + "@": ("__rmatmul__", ast.MatMult, "matmul"), + "%": ("__rmod__", ast.Mod, "remainder"), + "&": ("__rand__", ast.BitAnd, "bitwise_and"), + "^": ("__rxor__", ast.BitXor, "bitwise_xor"), } PATCHED_CLASSES = { @@ -115,6 +115,7 @@ def _unpatch_all(op: str) -> None: def _patch_default_classes() -> None: """Patch the default/commonly used classes""" + try: import pandas patch_classes( diff --git a/pipda/verb.py b/pipda/verb.py index 516d5b5..87a1562 100644 --- a/pipda/verb.py +++ b/pipda/verb.py @@ -80,17 +80,28 @@ def _pipda_eval(self, data: Any, context: ContextType = None) -> Any: self._pipda_func.extra_contexts.get(func, None) or self._pipda_func.extra_contexts["_"] ) - bound = self._pipda_func.bind_arguments( + if extra_contexts: + bound = self._pipda_func.bind_arguments( + data, + *self._pipda_args, + **self._pipda_kwargs, + ) + for key, val in bound.arguments.items(): + ctx = extra_contexts.get(key, context) + val = evaluate_expr(val, data, ctx) + bound.arguments[key] = val + + return func(*bound.args, **bound.kwargs) + + # we don't need signature if there is no extra context + return func( data, - *self._pipda_args, - **self._pipda_kwargs, + *(evaluate_expr(arg, data, context) for arg in self._pipda_args), + **{ + key: evaluate_expr(val, data, context) + for key, val in self._pipda_kwargs.items() + }, ) - for key, val in bound.arguments.items(): - ctx = extra_contexts.get(key, context) - val = evaluate_expr(val, data, ctx) - bound.arguments[key] = val - - return func(*bound.args, **bound.kwargs) class Verb(Registered): diff --git a/tests/test_expression.py b/tests/test_expression.py index 4525ca3..f15ab1b 100644 --- a/tests/test_expression.py +++ b/tests/test_expression.py @@ -90,3 +90,22 @@ def test_test_pipda_attr(): f = Expr() assert not hasattr(f, "_pipda_xyz") + + +def test_ufunc(): + f = Symbolic() + x = np.sqrt(f) + assert isinstance(x, FunctionCall) + + out = x._pipda_eval(4, Context.EVAL) + assert out == 2 + + out = x._pipda_eval([1, 4], Context.EVAL) + assert out[0] == 1 + assert out[1] == 2 + + x = np.multiply.reduce(f) + assert isinstance(x, FunctionCall) + + out = x._pipda_eval([1, 2, 3], Context.EVAL) + assert out == 6 diff --git a/tests/test_function.py b/tests/test_function.py index f82ad13..f89a226 100644 --- a/tests/test_function.py +++ b/tests/test_function.py @@ -63,6 +63,20 @@ def add(x, y): assert out == 3 and isinstance(out, int) +def test_extra_contexts(): + + @register_func( + context=Context.EVAL, + extra_contexts={"plus": Context.SELECT}, + ) + def add(x, plus): + return f"{x} + {plus}" + + f = Symbolic() + expr = add(f["a"], f["b"]) + assert expr._pipda_eval({"a": 1, "b": 2}) == "1 + b" + + def test_meta(): name = "myfun" qualname = "mypackage.myfun" diff --git a/tests/test_verb.py b/tests/test_verb.py index ec50f5b..af73c0f 100644 --- a/tests/test_verb.py +++ b/tests/test_verb.py @@ -2,6 +2,7 @@ import pytest import numpy as np +from pipda.piping import register_piping from pipda.symbolic import Symbolic from pipda.context import Context from pipda.verb import * @@ -199,7 +200,7 @@ def length(data): return len(data) with pytest.raises(TypeError): - length(1, 2) + length([1], 2) def test_registered(): @@ -275,4 +276,18 @@ def sum_(data): return data.sum() s = np.array([1, 2, 3]) >> sum_() - assert s == 6 + assert s == 6 and isinstance(s, np.integer) + + register_piping("|") + + s = np.array([1, 2, 3]) | sum_() + assert s == 6 and isinstance(s, np.integer) + + register_piping(">>") + + @register_verb(np.ndarray) + def sum2(data, n): + return data.sum() + n + + s = np.array([1, 2, 3]) >> sum2(1) + assert s == 7 and isinstance(s, np.integer) From a37b2f0124ddc55e96b0efc0f2c53b1cf13d82e1 Mon Sep 17 00:00:00 2001 From: pwwang <1188067+pwwang@users.noreply.github.com> Date: Sat, 8 Oct 2022 09:34:27 -0700 Subject: [PATCH 3/5] Fix linting --- pipda/piping.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pipda/piping.py b/pipda/piping.py index 04cf27c..3a9b9de 100644 --- a/pipda/piping.py +++ b/pipda/piping.py @@ -1,6 +1,6 @@ import ast import functools -from typing import Type +from typing import Type, Dict, Callable from .verb import VerbCall @@ -15,7 +15,7 @@ "^": ("__rxor__", ast.BitXor, "bitwise_xor"), } -PATCHED_CLASSES = { +PATCHED_CLASSES: Dict[Type, Dict[str, Callable]] = { # kls: # {} # registered but not patched # {"method": , "imethod": } # patched From 4a523ca71c5d846d2893cb29e9eca61cc69f82e6 Mon Sep 17 00:00:00 2001 From: pwwang <1188067+pwwang@users.noreply.github.com> Date: Sat, 8 Oct 2022 16:55:12 +0000 Subject: [PATCH 4/5] pump executing to 1.1.1 to fix pwwang/datar#149 --- poetry.lock | 8 ++++---- pyproject.toml | 2 +- 2 files changed, 5 insertions(+), 5 deletions(-) diff --git a/poetry.lock b/poetry.lock index 2b59a1c..a755b22 100644 --- a/poetry.lock +++ b/poetry.lock @@ -36,7 +36,7 @@ toml = ["tomli"] [[package]] name = "executing" -version = "1.1.0" +version = "1.1.1" description = "Get the currently executing AST node of a frame, and other information" category = "main" optional = false @@ -238,7 +238,7 @@ testing = ["func-timeout", "jaraco.itertools", "pytest (>=6)", "pytest-black (>= [metadata] lock-version = "1.1" python-versions = "^3.7.1" -content-hash = "2c249d7329310a2795b598cbb2e8dced0a71bd66bd561ca72228c966e4cf0bbf" +content-hash = "ec66e5e020f4d213b88b9a1845878bdbc0a2a34f43c89050fbf7cfe2c4b9d328" [metadata.files] attrs = [ @@ -302,8 +302,8 @@ coverage = [ {file = "coverage-6.5.0.tar.gz", hash = "sha256:f642e90754ee3e06b0e7e51bce3379590e76b7f76b708e1a71ff043f87025c84"}, ] executing = [ - {file = "executing-1.1.0-py2.py3-none-any.whl", hash = "sha256:4a6d96ba89eb3dcc11483471061b42b9006d8c9f81c584dd04246944cd022530"}, - {file = "executing-1.1.0.tar.gz", hash = "sha256:2c2c07d1ec4b2d8f9676b25170f1d8445c0ee2eb78901afb075a4b8d83608c6a"}, + {file = "executing-1.1.1-py2.py3-none-any.whl", hash = "sha256:236ea5f059a38781714a8bfba46a70fad3479c2f552abee3bbafadc57ed111b8"}, + {file = "executing-1.1.1.tar.gz", hash = "sha256:b0d7f8dcc2bac47ce6e39374397e7acecea6fdc380a6d5323e26185d70f38ea8"}, ] importlib-metadata = [ {file = "importlib_metadata-5.0.0-py3-none-any.whl", hash = "sha256:ddb0e35065e8938f867ed4928d0ae5bf2a53b7773871bfe6bcc7e4fcdc7dea43"}, diff --git a/pyproject.toml b/pyproject.toml index d7639d3..75cb8ae 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -8,7 +8,7 @@ license = "MIT" [tool.poetry.dependencies] python = "^3.7.1" -executing = "^1.0" +executing = "^1.1.1" [tool.poetry.dev-dependencies] pytest = "^7" From d12414f580f58dfdb8c4d192a199b8b3997e3494 Mon Sep 17 00:00:00 2001 From: pwwang <1188067+pwwang@users.noreply.github.com> Date: Sat, 8 Oct 2022 17:58:05 +0000 Subject: [PATCH 5/5] 0.8.0 --- docs/CHANGELOG.md | 6 ++++++ pipda/__init__.py | 2 +- pyproject.toml | 2 +- 3 files changed, 8 insertions(+), 2 deletions(-) diff --git a/docs/CHANGELOG.md b/docs/CHANGELOG.md index 8613f13..daa21f7 100644 --- a/docs/CHANGELOG.md +++ b/docs/CHANGELOG.md @@ -1,5 +1,11 @@ # Change Log +## 0.8.0 + +- patch classes if they have piping operator method +- auto register numpy ufuncs +- pump executing to 1.1.1 to fix pwwang/datar#149 + ## 0.7.6 - 🐛 Fix `numpy.ndarray` as data argument for verbs diff --git a/pipda/__init__.py b/pipda/__init__.py index e4909dd..6d75d80 100644 --- a/pipda/__init__.py +++ b/pipda/__init__.py @@ -12,4 +12,4 @@ ) from .piping import register_piping -__version__ = "0.7.6" +__version__ = "0.8.0" diff --git a/pyproject.toml b/pyproject.toml index 75cb8ae..ba38a54 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "pipda" -version = "0.7.6" +version = "0.8.0" readme = "README.md" description = "A framework for data piping in python" authors = ["pwwang "]