Skip to content

Commit

Permalink
comparison with Polar3DVector needs more work, save for now
Browse files Browse the repository at this point in the history
  • Loading branch information
ianna committed Dec 21, 2021
1 parent 94a6d5f commit 4894cf8
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 114 deletions.
13 changes: 8 additions & 5 deletions src/vector/_compute/spatial/dot.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@

# specialized
def xy_z_xy_z(lib, x1, y1, z1, x2, y2, z2):
return x1 * x2 + y1 * y2 + z1 * z2
return lib.nan_to_num(x1 * x2 + y1 * y2 + z1 * z2, nan=0.0)


def xy_z_xy_theta(lib, x1, y1, z1, x2, y2, theta2):
Expand Down Expand Up @@ -277,7 +277,7 @@ def rhophi_z_xy_eta(lib, rho1, phi1, z1, x2, y2, eta2):

# specialized
def rhophi_z_rhophi_z(lib, rho1, phi1, z1, rho2, phi2, z2):
return rho1 * rho2 * lib.cos(phi1 - phi2) + z1 * z2
return lib.nan_to_num(rho1 * rho2 * lib.cos(phi1 - phi2) + z1 * z2, nan=0.0)


def rhophi_z_rhophi_theta(lib, rho1, phi1, z1, rho2, phi2, theta2):
Expand Down Expand Up @@ -336,8 +336,9 @@ def rhophi_theta_rhophi_z(lib, rho1, phi1, theta1, rho2, phi2, z2):

# specialized
def rhophi_theta_rhophi_theta(lib, rho1, phi1, theta1, rho2, phi2, theta2):
return (
rho1 * rho2 * (lib.cos(phi1 - phi2) + 1 / (lib.tan(theta1) * lib.tan(theta2)))
return lib.nan_to_num(
rho1 * rho2 * (lib.cos(phi1 - phi2) + 1 / (lib.tan(theta1) * lib.tan(theta2))),
nan=0,
)


Expand Down Expand Up @@ -407,7 +408,9 @@ def rhophi_eta_rhophi_eta(lib, rho1, phi1, eta1, rho2, phi2, eta2):
expmeta2 = lib.exp(-eta2)
invtantheta1 = 0.5 * (1 - expmeta1 ** 2) / expmeta1
invtantheta2 = 0.5 * (1 - expmeta2 ** 2) / expmeta2
return rho1 * rho2 * (lib.cos(phi1 - phi2) + invtantheta1 * invtantheta2)
return lib.nan_to_num(
rho1 * rho2 * (lib.cos(phi1 - phi2) + invtantheta1 * invtantheta2), nan=0.0
)


dispatch_map = {
Expand Down
2 changes: 1 addition & 1 deletion src/vector/_compute/spatial/eta.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,7 @@ def rhophi_z(lib, rho, phi, z):


def rhophi_theta(lib, rho, phi, theta):
return -lib.log(lib.tan(0.5 * theta))
return lib.nan_to_num(-lib.log(lib.tan(0.5 * theta)), nan=0.0)


def rhophi_eta(lib, rho, phi, eta):
Expand Down
77 changes: 6 additions & 71 deletions tests/root/test_Polar2DVector.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,9 @@ def test_Dot(constructor, coordinates):
vector.obj(**dict(zip(["rho", "phi"], constructor))), coordinates
)().dot(
getattr(vector.obj(**dict(zip(["rho", "phi"], constructor))), coordinates)()
)
),
1.0e-6,
1.0e-6,
)


Expand All @@ -91,18 +93,10 @@ def test_Dot(constructor, coordinates):
constructor1=st.tuples(
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
),
constructor2=st.tuples(
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
),
)
def test_fuzz_Dot(constructor1, constructor2, coordinates):
Expand Down Expand Up @@ -137,10 +131,6 @@ def test_Mag2(constructor, coordinates):
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
)
)
def test_fuzz_Mag2(constructor, coordinates):
assert ROOT.Math.Polar2DVector(*constructor).Mag2() == pytest.approx(
Expand All @@ -166,10 +156,6 @@ def test_Mag(constructor, coordinates):
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
)
)
def test_fuzz_Mag(constructor, coordinates):
assert ROOT.Math.sqrt(
Expand All @@ -193,10 +179,6 @@ def test_Phi(constructor, coordinates):
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
)
)
def test_fuzz_Phi(constructor, coordinates):
assert ROOT.Math.Polar2DVector(*constructor).Phi() == pytest.approx(
Expand Down Expand Up @@ -230,13 +212,8 @@ def test_Rotate(constructor, angle, coordinates):
constructor=st.tuples(
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
),
angle=st.floats(min_value=-10e7, max_value=10e7)
| st.integers(min_value=-10e7, max_value=10e7),
angle=st.floats(min_value=-10e7, max_value=10e7),
)
def test_fuzz_Rotate(constructor, angle, coordinates):
ref_vec = ROOT.Math.Polar2DVector(*constructor)
Expand Down Expand Up @@ -271,10 +248,6 @@ def test_Unit(constructor, coordinates):
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
)
)
def test_fuzz_Unit(constructor, coordinates):
ref_vec = ROOT.Math.Polar2DVector(*constructor).Unit()
Expand All @@ -298,10 +271,6 @@ def test_X_and_Y(constructor, coordinates):
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
)
)
def test_fuzz_X_and_Y(constructor, coordinates):
ref_vec = ROOT.Math.Polar2DVector(*constructor)
Expand Down Expand Up @@ -337,18 +306,10 @@ def test_add(constructor, coordinates):
constructor1=st.tuples(
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
),
constructor2=st.tuples(
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
),
)
def test_fuzz_add(constructor1, constructor2, coordinates):
Expand Down Expand Up @@ -398,18 +359,10 @@ def test_sub(constructor, coordinates):
constructor1=st.tuples(
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
),
constructor2=st.tuples(
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
),
)
def test_fuzz_sub(constructor1, constructor2, coordinates):
Expand Down Expand Up @@ -448,10 +401,6 @@ def test_neg(constructor, coordinates):
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
)
)
def test_fuzz_neg(constructor, coordinates):
ref_vec = ROOT.Math.Polar2DVector(*constructor).__neg__()
Expand All @@ -478,13 +427,8 @@ def test_mul(constructor, scalar, coordinates):
constructor=st.tuples(
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
),
scalar=st.floats(min_value=-10e7, max_value=10e7)
| st.integers(min_value=-10e7, max_value=10e7),
scalar=st.floats(min_value=-10e7, max_value=10e7),
)
def test_fuzz_mul(constructor, scalar, coordinates):
ref_vec = ROOT.Math.Polar2DVector(*constructor).__mul__(scalar)
Expand Down Expand Up @@ -513,13 +457,8 @@ def test_truediv(constructor, scalar, coordinates):
constructor=st.tuples(
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
),
scalar=st.floats(min_value=-10e7, max_value=10e7)
| st.integers(min_value=-10e7, max_value=10e7),
scalar=st.floats(min_value=-10e7, max_value=10e7),
)
def test_fuzz_truediv(constructor, scalar, coordinates):
# FIXME:
Expand Down Expand Up @@ -552,10 +491,6 @@ def test_eq(constructor, coordinates):
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
)
)
def test_fuzz_eq(constructor, coordinates):
ref_vec = ROOT.Math.Polar2DVector(*constructor).__eq__(
Expand Down
84 changes: 47 additions & 37 deletions tests/root/test_Polar3DVector.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,18 @@
ROOT = pytest.importorskip("ROOT")

# ROOT.Math.Polar3DVector constructor arguments to get all the weird cases.
# "rho", "theta", "phi"
# Phi is restricted to be in the range [-PI,PI)
constructor = [
(0, 0, 0),
(0, 10, 0),
(0, -10, 0),
(1, 0, 0),
(1, 10, 0),
(1, -10, 0),
(1.0, 2.5, 2.0),
(1, 2.5, 2.0),
(1, -2.5, 2.0),
(0.0, 0.0, 0.0),
# (0.0, 10.0, 0.0),
# (0.0, -10.0, 0.0),
(1.0, 0.0, 0.0),
# (1.0, 10.0, 0.0),
# (1.0, -10.0, 0.0),
# (1.0, 2.5, 2.0),
# (1.0, 2.5, 2.0),
# (1.0, -2.5, 2.0),
]

# Coordinate conversion methods to apply to the VectorObject2D.
Expand Down Expand Up @@ -97,22 +99,12 @@ def test_Dot(constructor, coordinates):
constructor1=st.tuples(
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
st.floats(min_value=-ROOT.Math.Pi(), max_value=ROOT.Math.Pi()),
),
constructor2=st.tuples(
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
st.floats(min_value=-10e7, max_value=10e7),
)
| st.tuples(
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
st.integers(min_value=-10e7, max_value=10e7),
st.floats(min_value=-ROOT.Math.Pi(), max_value=ROOT.Math.Pi()),
),
)
def test_fuzz_Dot(constructor1, constructor2, coordinates):
Expand Down Expand Up @@ -144,21 +136,21 @@ def test_Cross(constructor, coordinates):
)()
)
assert (
ref_vec.Rho()
ref_vec.X()
== pytest.approx(
vec.rho,
vec.x,
1.0e-6,
1.0e-6,
)
and ref_vec.Theta()
and ref_vec.Y()
== pytest.approx(
vec.theta,
vec.y,
1.0e-6,
1.0e-6,
)
and ref_vec.Phi()
and ref_vec.Z()
== pytest.approx(
vec.phi,
vec.z,
1.0e-6,
1.0e-6,
)
Expand Down Expand Up @@ -200,21 +192,21 @@ def test_fuzz_Cross(constructor1, constructor2, coordinates):
)()
)
assert (
ref_vec.Rho()
ref_vec.X()
== pytest.approx(
vec.rho,
vec.x,
1.0e-6,
1.0e-6,
)
and ref_vec.Theta()
and ref_vec.Y()
== pytest.approx(
vec.theta,
vec.y,
1.0e-6,
1.0e-6,
)
and ref_vec.Phi()
and ref_vec.Z()
== pytest.approx(
vec.phi,
vec.z,
1.0e-6,
1.0e-6,
)
Expand All @@ -234,9 +226,7 @@ def test_Mag2(constructor, coordinates):
# Run a test that compares ROOT's 'Mag()' with vector's 'mag' for all cases.
@pytest.mark.parametrize("constructor", constructor)
def test_R(constructor, coordinates):
assert ROOT.Math.sqrt(
ROOT.Math.Polar3DVector(*constructor).Mag2()
) == pytest.approx(
assert ROOT.Math.Polar3DVector(*constructor).R() == pytest.approx(
getattr(
vector.obj(**dict(zip(["rho", "theta", "phi"], constructor))),
coordinates,
Expand Down Expand Up @@ -329,6 +319,26 @@ def test_RotateZ(constructor, angle, coordinates):
vector.obj(**dict(zip(["rho", "theta", "phi"], constructor))), coordinates
)()
res_vec = vec.rotateZ(angle)
assert (
ref_vec.R()
== pytest.approx(
vec.rho,
1.0e-6,
1.0e-6,
)
and ref_vec.Theta()
== pytest.approx(
vec.theta,
1.0e-6,
1.0e-6,
)
and ref_vec.Phi()
== pytest.approx(
vec.phi,
1.0e-6,
1.0e-6,
)
)
assert ref_vec.X() == pytest.approx(res_vec.x)
assert ref_vec.Y() == pytest.approx(res_vec.y)
assert ref_vec.Z() == pytest.approx(res_vec.z)
Expand All @@ -342,7 +352,7 @@ def test_RotateAxes(constructor, angle, coordinates):
)()
# FIXME: rotate_axis
assert (
ref_vec.Rho()
ref_vec.R()
== pytest.approx(
vec.rho,
1.0e-6,
Expand Down

0 comments on commit 4894cf8

Please sign in to comment.