From bf1c8746812d55cb9cec23e954188044b9a7fcfc Mon Sep 17 00:00:00 2001 From: Saransh Chopra Date: Sat, 20 Jul 2024 01:51:10 +0530 Subject: [PATCH] fix: refine the implementation of copy_behaviors (#3177) --- src/awkward/_util.py | 15 +++++++-------- tests/test_2433_copy_behaviors.py | 22 ++++++++-------------- 2 files changed, 15 insertions(+), 22 deletions(-) diff --git a/src/awkward/_util.py b/src/awkward/_util.py index 89ea936ae8..110da2e22d 100644 --- a/src/awkward/_util.py +++ b/src/awkward/_util.py @@ -6,7 +6,6 @@ import os import struct import sys -import typing from collections.abc import Collection import numpy as np # noqa: TID251 @@ -105,16 +104,16 @@ def unique_list(items: Collection[T]) -> list[T]: return result -def copy_behaviors(existing_class: typing.Any, new_class: typing.Any, behavior: dict): +def copy_behaviors(from_name: str, to_name: str, behavior: dict): output = {} - oldname = existing_class.__name__ - newname = new_class.__name__ - for key, value in behavior.items(): - if oldname in key: - if not isinstance(key, str) and "*" not in key: - new_tuple = tuple(newname if k == oldname else k for k in key) + if isinstance(key, str): + if key == from_name: + output[to_name] = value + else: + if from_name in key: + new_tuple = tuple(to_name if k == from_name else k for k in key) output[new_tuple] = value return output diff --git a/tests/test_2433_copy_behaviors.py b/tests/test_2433_copy_behaviors.py index c52decc09e..30efb7f1d3 100644 --- a/tests/test_2433_copy_behaviors.py +++ b/tests/test_2433_copy_behaviors.py @@ -62,6 +62,12 @@ def __eq__(self, other): ak.behavior[numpy.add, "VectorTwoD", "VectorTwoD"] = lambda v1, v2: v1.add(v2) assert v + v == v_added + # instead of registering every operator again, just copy the behaviors of + # another class to this class + ak.behavior.update( + ak._util.copy_behaviors("VectorTwoD", "VectorTwoDAgain", ak.behavior) + ) + # second sub-class @ak.mixin_class(ak.behavior) class VectorTwoDAgain(VectorTwoD): @@ -81,17 +87,14 @@ class VectorTwoDAgain(VectorTwoD): with_name="VectorTwoDAgain", behavior=ak.behavior, ) - # add method works but the binary operator does not assert v.add(v) == v_added - with pytest.raises(TypeError): - v + v + assert v + v == v_added # instead of registering every operator again, just copy the behaviors of # another class to this class ak.behavior.update( - ak._util.copy_behaviors(VectorTwoD, VectorTwoDAgain, ak.behavior) + ak._util.copy_behaviors("VectorTwoDAgain", "VectorTwoDAgainAgain", ak.behavior) ) - assert v + v == v_added # third sub-class @ak.mixin_class(ak.behavior) @@ -112,14 +115,5 @@ class VectorTwoDAgainAgain(VectorTwoDAgain): with_name="VectorTwoDAgainAgain", behavior=ak.behavior, ) - # add method works but the binary operator does not assert v.add(v) == v_added - with pytest.raises(TypeError): - v + v - - # instead of registering every operator again, just copy the behaviors of - # another class to this class - ak.behavior.update( - ak._util.copy_behaviors(VectorTwoDAgain, VectorTwoDAgainAgain, ak.behavior) - ) assert v + v == v_added