Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: new convenience methods on BeamInterface #1496

Merged
merged 8 commits into from
Nov 17, 2024
Merged
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ All notable changes to this project will be documented in this file.
## [Unreleased]

### Added
- New convenience methods on `BeamInterface` to simplify the handling of analytic vs UVBeam objects.
- Added support for partial read for MWA correlator FITS files.
- Added `antenna_names`, `time_range`, `lsts` and `lst_range` parameters to
`UVFlag.select` to match UVData and UVCal select methods.
Expand Down
6 changes: 5 additions & 1 deletion src/pyuvdata/analytic_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import dataclasses
import importlib
import warnings
from dataclasses import InitVar, astuple, dataclass, field
from dataclasses import InitVar, astuple, dataclass, field, replace
from typing import ClassVar, Literal

import numpy as np
Expand Down Expand Up @@ -189,6 +189,10 @@ def Npols(self): # noqa N802
"""The number of polarizations."""
return self.polarization_array.size

def clone(self, **kw):
"""Create a new instance of the object with updated parameters."""
return replace(self, **kw)

@property
def east_ind(self):
"""The index of the east feed in the feed array."""
Expand Down
132 changes: 129 additions & 3 deletions src/pyuvdata/beam_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

import copy
import warnings
from dataclasses import InitVar, dataclass
from dataclasses import InitVar, asdict, dataclass, replace
from itertools import product
from typing import Literal

import numpy as np
import numpy.typing as npt

from .analytic_beam import AnalyticBeam
from .utils import pol as upol
from .uvbeam import UVBeam

# Other methods we may want to include:
Expand All @@ -32,8 +34,9 @@ class BeamInterface:

Attributes
----------
beam : pyuvdata.UVBeam or pyuvdata.AnalyticBeam
Beam object to use for computations
beam : pyuvdata.UVBeam or pyuvdata.AnalyticBeam or BeamInterface
Beam object to use for computations. If a BeamInterface is passed, a new
view of the same object is created.
beam_type : str
The beam type, either "efield" or "power".
include_cross_pols : bool
Expand All @@ -59,6 +62,11 @@ def __post_init__(self, include_cross_pols: bool):
for the power beam.

"""
if isinstance(self.beam, BeamInterface):
self.beam = self.beam.beam
self.__post_init__(include_cross_pols=include_cross_pols)
return

if not isinstance(self.beam, UVBeam) and not issubclass(
type(self.beam), AnalyticBeam
):
Expand All @@ -82,6 +90,124 @@ def __post_init__(self, include_cross_pols: bool):
"efield beam, either provide an efield UVBeam or do not "
"specify `beam_type`."
)
elif self.beam_type is None:
self.beam_type = "efield"

@property
def Npols(self): # noqa N802
"""The number of polarizations in the beam."""
return self.beam.Npols or len(self.polarization_array)

@property
def polarization_array(self):
"""The polarizations defined on the beam."""
return self.beam.polarization_array

@property
def feed_array(self):
"""The feeds for which the beam is defined."""
return self.beam.feed_array

@property
def Nfeeds(self): # noqa N802
"""The number of feeds defined on the beam."""
return self.beam.Nfeeds or len(self.feed_array)

def clone(self, **kw):
"""Return a new instance with updated parameters."""
return replace(self, **kw)

def as_power_beam(
self, include_cross_pols: bool | None = None, allow_beam_mutation: bool = False
):
"""Return a new interface instance that is in the power-beam mode.

If already in the power-beam mode, this is a no-op. Note that this might be
slighty unexpected, because the effect of `include_cross_pols` is not accounted
for in this case.
bhazelton marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
include_cross_pols : bool, optional
Whether to include cross-pols in the power beam. By default, this is True
for E-field beams, and takes the same value as the existing beam if the
existing beam is already a power beam.
allow_beam_mutation : bool, optional
Whether to allow the underlying beam to be updated in-place.
"""
if self.beam_type == "power":
if include_cross_pols is None:
# By default, keep the value of include_cross_pols the same.
include_cross_pols = self.Npols > 2

if self.Npols > 1 and (
(include_cross_pols and self.Npols != 4)
or (not include_cross_pols and self.Npols != 2)
):
warnings.warn(
"as_power_beam does not modify cross pols when the beam is"
f"already in power mode! You have polarizations: "
f"{self.polarization_array} but asked to "
f"*{'include' if include_cross_pols else 'not include'}* "
"cross-pols."
)
return self

if include_cross_pols is None:
include_cross_pols = True

beam = self.beam if allow_beam_mutation else copy.deepcopy(self.beam)

# We cannot simply use .clone() here, because we need to be able to pass
# include_cross_pols, which can only be passed to the constructor proper.
this = asdict(self)
this["beam"] = beam
this["beam_type"] = "power"
this["include_cross_pols"] = include_cross_pols
with warnings.catch_warnings():
# Don't emit the warning that we're converting to power, because that is
# explicitly desired.
warnings.simplefilter("ignore", UserWarning)
return BeamInterface(**this)

def with_feeds(self, feeds, *, maintain_ordering: bool = True):
"""Return a new interface instance with updated feed_array.

Parameters
----------
feeds : array_like of str
The feeds to keep in the beam. Each value should be a string, e.g. 'n', 'x'.
maintain_ordering : bool, optional
If True, maintain the same polarization ordering as in the beam currently.
If False, change ordering to match the input feeds, which are turned into
pols (if a power beam) by using product(feeds, feeds).
"""
if not self._isuvbeam:
if maintain_ordering:
feeds = [fd for fd in self.feed_array if fd in feeds]
return self.clone(beam=self.beam.clone(feed_array=feeds))
if self.beam_type == "power":
# Down-select polarizations based on the feeds input.
possible_pols = [f1 + f2 for f1, f2 in product(feeds, feeds)]
possible_pol_ints = upol.polstr2num(
possible_pols, x_orientation=self.beam.x_orientation
)
if maintain_ordering:
use_pols = [
p for p in self.beam.polarization_array if p in possible_pol_ints
]
else:
use_pols = [
p for p in possible_pol_ints if p in self.beam.polarization_array
]

new_beam = self.beam.select(polarizations=use_pols, inplace=False)
else:
if maintain_ordering:
feeds = [fd for fd in self.feed_array if fd in feeds]

new_beam = self.beam.select(feeds=feeds, inplace=False)
return self.clone(beam=new_beam)

@property
def _isuvbeam(self):
Expand Down
6 changes: 6 additions & 0 deletions tests/test_analytic_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,3 +568,9 @@ def test_single_feed():
beam = GaussianBeam(diameter=14.0, feed_array=["x"], include_cross_pols=True)
assert beam.feed_array == ["x"]
assert beam.polarization_array == [-5]


def test_clone():
beam = GaussianBeam(diameter=14.0, feed_array=["x", "y"])
new_beam = beam.clone(feed_array=["x"])
assert new_beam.feed_array == ["x"]
109 changes: 109 additions & 0 deletions tests/test_beam_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,22 @@
GaussianBeam,
ShortDipoleBeam,
UniformBeam,
UVBeam,
utils,
)
from pyuvdata.testing import check_warnings


@pytest.fixture(scope="function")
def airy() -> AiryBeam:
return AiryBeam(diameter=14.0)


@pytest.fixture()
def gaussian() -> GaussianBeam:
return GaussianBeam(diameter=14.0)


@pytest.fixture()
def xy_grid_coarse():
nfreqs = 5
Expand All @@ -34,6 +45,14 @@ def xy_grid_coarse():
return az_array, za_array, freqs


@pytest.fixture()
def gaussian_uv(gaussian, az_za_coords) -> UVBeam:
az, za = az_za_coords
return gaussian.to_uvbeam(
axis1_array=az, axis2_array=za, freq_array=np.array([1e8])
)


@pytest.mark.parametrize(
["beam_obj", "kwargs"],
[
Expand Down Expand Up @@ -243,3 +262,93 @@ def test_compute_response_errors(param, value):
else:
# this shouldn't error
bi_uvb.compute_response(**compute_kwargs)


@pytest.mark.parametrize("beam_obj", ["airy", "gaussian", "gaussian_uv"])
def test_idempotent_instantiation(beam_obj, request):
beam = BeamInterface(request.getfixturevalue(beam_obj))
beam2 = BeamInterface(beam)
assert beam == beam2


def test_properties(airy: AiryBeam):
intf = BeamInterface(airy)
assert airy.Npols == intf.Npols
assert airy.Nfeeds == intf.Nfeeds
assert np.all(airy.polarization_array == intf.polarization_array)
assert np.all(airy.feed_array == intf.feed_array)


def test_clone(airy):
intf = BeamInterface(airy)
intf_clone = intf.clone(beam_type="power")
assert intf != intf_clone


@pytest.mark.parametrize("uvbeam", [True, False], ids=["uvbeam", "analytic"])
@pytest.mark.parametrize("allow_mutation", [True, False], ids=["mutate", "nomutate"])
@pytest.mark.parametrize(
"include_cross_pols", [True, False, None], ids=["incx", "nox", "xpolnone"]
)
def test_as_power(
uvbeam: bool, allow_mutation: bool, include_cross_pols: bool, gaussian, gaussian_uv
):
beam = gaussian_uv if uvbeam else gaussian
intf = BeamInterface(beam)
intf_power = intf.as_power_beam(
allow_beam_mutation=allow_mutation, include_cross_pols=include_cross_pols
)
if include_cross_pols is None:
include_cross_pols = True

assert intf_power.beam_type == "power"
assert intf_power.Npols == 4 if include_cross_pols else 2

if uvbeam:
if allow_mutation:
assert intf.beam.beam_type == "power"
else:
assert intf.beam.beam_type == "efield"


def test_as_power_noop(airy):
"""Ensure that calling as_power_beam on a power beam is a no-op."""
intf = BeamInterface(airy, beam_type="power")
intf2 = intf.as_power_beam()
assert intf is intf2

with pytest.warns(UserWarning, match="as_power_beam does not modify cross pols"):
intf2 = intf.as_power_beam(include_cross_pols=False)
assert intf is intf2


@pytest.mark.parametrize("uvbeam", [True, False])
def test_with_feeds(uvbeam: bool, gaussian, gaussian_uv):
beam = gaussian_uv if uvbeam else gaussian

intf = BeamInterface(beam)

intf_feedx = intf.with_feeds(["x"])
assert intf_feedx.feed_array == ["x"]


def test_with_feeds_ordering(airy):
intf = BeamInterface(airy)

intf_feedx = intf.with_feeds(["y", "x"], maintain_ordering=True)
assert np.all(intf_feedx.feed_array == ["x", "y"])

intf_feedyx = intf.with_feeds(["y", "x"], maintain_ordering=False)
assert np.all(intf_feedyx.feed_array == ["y", "x"])


@pytest.mark.filterwarnings("ignore:Input beam is an efield UVBeam")
@pytest.mark.filterwarnings("ignore:Selected polarizations are not evenly spaced")
def test_with_feeds_ordering_power(gaussian_uv):
# beam = AiryBeam(diameter=14.0).to_uvbeam(freq_array=np.array([1e8]), nside=16)
intf = BeamInterface(gaussian_uv, beam_type="power")
intf_feedx = intf.with_feeds(["y", "x"], maintain_ordering=True)
assert np.all(intf_feedx.polarization_array == [-5, -6, -7, -8])

intf_feedyx = intf.with_feeds(["y", "x"], maintain_ordering=False)
assert np.all(intf_feedyx.polarization_array == [-6, -8, -7, -5])
Loading