-
Notifications
You must be signed in to change notification settings - Fork 89
/
Copy pathtest_0355-mixins.py
104 lines (94 loc) · 3.33 KB
/
test_0355-mixins.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
# BSD 3-Clause License; see https://github.com/scikit-hep/awkward-1.0/blob/main/LICENSE
import pytest # noqa: F401
import numpy as np # noqa: F401
import awkward as ak # noqa: F401
def test_make_mixins():
    """Exercise ``ak.mixin_class`` / ``ak.mixin_class_method`` end to end.

    Defines two mixins:

    * ``Point`` has a plain ``distance`` method plus overloads for
      ``np.equal``, ``np.abs`` and ``np.add``.
    * ``WeightedPoint(Point)`` adds its own ``np.equal`` and ``np.add``
      overloads.

    The test then checks ufunc dispatch and method access on jagged arrays
    of records that carry those names.

    The mixins are registered into a local ``behavior`` dict rather than
    the process-global ``ak.behavior``. That way the "Point" and
    "WeightedPoint" overloads do not leak into other tests in the same
    session. Because of this, every array built here receives
    ``behavior=behavior`` explicitly.
    """
    behavior = {}

    @ak.mixin_class(behavior)
    class Point:
        def distance(self, other):
            # Euclidean distance between corresponding points.
            return np.sqrt((self.x - other.x) ** 2 + (self.y - other.y) ** 2)

        @ak.mixin_class_method(np.equal, {"Point"})
        def point_equal(self, other):
            return np.logical_and(self.x == other.x, self.y == other.y)

        @ak.mixin_class_method(np.abs)
        def point_abs(self):
            # Distance from the origin.
            return np.sqrt(self.x**2 + self.y**2)

        @ak.mixin_class_method(np.add, {"Point"})
        def point_add(self, other):
            return ak.zip(
                {"x": self.x + other.x, "y": self.y + other.y},
                with_name="Point",
                behavior=behavior,
            )

    @ak.mixin_class(behavior)
    class WeightedPoint(Point):
        @ak.mixin_class_method(np.equal, {"WeightedPoint"})
        def weighted_equal(self, other):
            return np.logical_and(self.point_equal(other), self.weight == other.weight)

        @ak.mixin_class_method(np.add, {"WeightedPoint"})
        def weighted_add(self, other):
            # Weight-averaged position; the weights themselves are summed.
            sumw = self.weight + other.weight
            return ak.zip(
                {
                    "x": (self.x * self.weight + other.x * other.weight) / sumw,
                    "y": (self.y * self.weight + other.y * other.weight) / sumw,
                    "weight": sumw,
                },
                with_name="WeightedPoint",
                behavior=behavior,
            )

    one = ak.Array(
        [
            [{"x": 1, "y": 1.1}, {"x": 2, "y": 2.2}, {"x": 3, "y": 3.3}],
            [],
            [{"x": 4, "y": 4.4}, {"x": 5, "y": 5.5}],
        ],
        with_name="Point",
        behavior=behavior,
    )
    two = ak.Array(
        [
            [{"x": 0.9, "y": 1}, {"x": 2, "y": 2.2}, {"x": 2.9, "y": 3}],
            [],
            [{"x": 3.9, "y": 4}, {"x": 5, "y": 5.5}],
        ],
        with_name="Point",
        behavior=behavior,
    )
    # abs() dispatches to point_abs; its result becomes the "weight" field.
    wone = ak.Array(
        ak.with_field(one, abs(one), "weight", behavior=behavior),
        with_name="WeightedPoint",
        behavior=behavior,
    )
    wtwo = ak.Array(
        ak.with_field(two, abs(two), "weight", behavior=behavior),
        with_name="WeightedPoint",
        behavior=behavior,
    )

    # Point + WeightedPoint: the result has no "weight" field.
    assert ak.to_list(one + wone) == [
        [{"x": 2, "y": 2.2}, {"x": 4, "y": 4.4}, {"x": 6, "y": 6.6}],
        [],
        [{"x": 8, "y": 8.8}, {"x": 10, "y": 11.0}],
    ]
    # WeightedPoint + WeightedPoint: weight-averaged positions, summed weights.
    assert ak.to_list(wone + wtwo) == [
        [
            {
                "x": 0.9524937500390619,
                "y": 1.052493750039062,
                "weight": 2.831969279439222,
            },
            {"x": 2.0, "y": 2.2, "weight": 5.946427498927402},
            {
                "x": 2.9516640394605282,
                "y": 3.1549921183815837,
                "weight": 8.632349833200564,
            },
        ],
        [],
        [
            {
                "x": 3.9515600270076154,
                "y": 4.206240108030463,
                "weight": 11.533018588312771,
            },
            {"x": 5.0, "y": 5.5, "weight": 14.866068747318506},
        ],
    ]
    assert ak.to_list(abs(one)) == [
        [1.4866068747318506, 2.973213749463701, 4.459820624195552],
        [],
        [5.946427498927402, 7.433034373659253],
    ]
    # Exercise the np.equal overload, which is otherwise never dispatched here.
    assert ak.to_list(one == two) == [
        [False, True, False],
        [],
        [False, True],
    ]
    # Plain (non-ufunc) mixin method, called with a WeightedPoint argument.
    assert ak.to_list(one.distance(wtwo)) == [
        [0.14142135623730953, 0.0, 0.31622776601683783],
        [],
        [0.4123105625617664, 0.0],
    ]