Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: new convenience methods on BeamInterface #1496

Merged
merged 8 commits into from
Nov 17, 2024
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ All notable changes to this project will be documented in this file.
- Added support for partial read for MWA correlator FITS files.
- Added `antenna_names`, `time_range`, `lsts` and `lst_range` parameters to
`UVFlag.select` to match UVData and UVCal select methods.
- New convenience methods on `BeamInterface` to simplify the handling of analytic vs
UVBeam objects.
bhazelton marked this conversation as resolved.
Show resolved Hide resolved

### Changed
- Made it possible to *not* return the `interp_basis_vector` array from beam
Expand Down
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ select = [
]
ignore = [
"N806", # non-lowercase variable (we use N* for axes lengths)
"N802", # non-lowercase function name (we use N* for axes lengths)
bhazelton marked this conversation as resolved.
Show resolved Hide resolved
"B028", # no-explicit-stacklevel for warnings
"SIM108", # prefer ternary opperators. I find them difficult to read.
"D203", # one-blank-line-before-class. we use two.
Expand Down
6 changes: 5 additions & 1 deletion src/pyuvdata/analytic_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import dataclasses
import importlib
import warnings
from dataclasses import InitVar, astuple, dataclass, field
from dataclasses import InitVar, astuple, dataclass, field, replace
from typing import ClassVar, Literal

import numpy as np
Expand Down Expand Up @@ -189,6 +189,10 @@ def Npols(self): # noqa N802
"""The number of polarizations."""
return self.polarization_array.size

def clone(self, **kw):
"""Create a new instance of the object with updated parameters."""
return replace(self, **kw)

@property
def east_ind(self):
"""The index of the east feed in the feed array."""
Expand Down
115 changes: 111 additions & 4 deletions src/pyuvdata/beam_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,13 +7,15 @@

import copy
import warnings
from dataclasses import InitVar, dataclass
from dataclasses import InitVar, asdict, dataclass, replace
from itertools import product
from typing import Literal

import numpy as np
import numpy.typing as npt

from .analytic_beam import AnalyticBeam
from .utils import pol as upol
from .uvbeam import UVBeam

# Other methods we may want to include:
Expand All @@ -32,8 +34,9 @@ class BeamInterface:

Attributes
----------
beam : pyuvdata.UVBeam or pyuvdata.AnalyticBeam
Beam object to use for computations
beam : pyuvdata.UVBeam or pyuvdata.AnalyticBeam or BeamInterface
Beam object to use for computations. If a BeamInterface is passed, a new
view of the same object is created.
beam_type : str
The beam type, either "efield" or "power".
include_cross_pols : bool
Expand All @@ -59,6 +62,11 @@ def __post_init__(self, include_cross_pols: bool):
for the power beam.

"""
if isinstance(self.beam, BeamInterface):
self.beam = self.beam.beam
self.__post_init__(include_cross_pols=include_cross_pols)
return

if not isinstance(self.beam, UVBeam) and not issubclass(
type(self.beam), AnalyticBeam
):
Expand All @@ -69,7 +77,7 @@ def __post_init__(self, include_cross_pols: bool):
if isinstance(self.beam, UVBeam):
if self.beam_type is None or self.beam_type == self.beam.beam_type:
self.beam_type = self.beam.beam_type
elif self.beam_type == "power":
elif self.beam_type == "power" and self.beam.beam_type != "power":
bhazelton marked this conversation as resolved.
Show resolved Hide resolved
warnings.warn(
"Input beam is an efield UVBeam but beam_type is specified as "
"'power'. Converting efield beam to power."
Expand All @@ -82,6 +90,105 @@ def __post_init__(self, include_cross_pols: bool):
"efield beam, either provide an efield UVBeam or do not "
"specify `beam_type`."
)
elif self.beam_type is None:
self.beam_type = "efield"

@property
def Npols(self):
"""The number of polarizations in the beam."""
return self.beam.Npols or len(self.polarization_array)

@property
def polarization_array(self):
"""The polarizations defined on the beam."""
return self.beam.polarization_array

@property
def feed_array(self):
"""The feeds for which the beam is defined."""
return self.beam.feed_array

@property
def Nfeeds(self):
"""The number of feeds defined on the beam."""
return self.beam.Nfeeds or len(self.feed_array)

def clone(self, **kw):
"""Return a new instance with updated parameters."""
return replace(self, **kw)

def as_power_beam(
self, include_cross_pols: bool = True, allow_beam_mutation: bool = False
):
"""Return a new interface instance that is in the power-beam mode.

If already in the power-beam mode, this is a no-op. Note that this might be
slighty unexpected, because the effect of `include_cross_pols` is not accounted
for in this case.
bhazelton marked this conversation as resolved.
Show resolved Hide resolved

Parameters
----------
include_cross_pols : bool, optional
Whether to include cross-pols in the power beam.
allow_beam_mutation : bool, optional
Whether to allow the underlying beam to be updated in-place.
"""
if self.beam_type == "power":
return self

beam = self.beam if allow_beam_mutation else copy.deepcopy(self.beam)

# We cannot simply use .clone() here, because we need to be able to pass
# include_cross_pols, which can only be passed to the constructor proper.
this = asdict(self)
this["beam"] = beam
this["beam_type"] = "power"
this["include_cross_pols"] = include_cross_pols
with warnings.catch_warnings():
# Don't emit the warning that we're converting to power, because that is
# explicitly desired.
warnings.simplefilter("ignore", UserWarning)
return BeamInterface(**this)

def with_feeds(self, feeds, *, maintain_ordering: bool = True):
"""Return a new interface instance with updated feed_array.

Parameters
----------
feeds : array_like of str
The feeds to keep in the beam. Each value should be a string, e.g. 'n', 'x'.
maintain_ordering : bool, optional
If True, maintain the same polarization ordering as in the beam currently.
If False, change ordering to match the input feeds, which are turned into
pols (if a power beam) by using product(feeds, feeds).
"""
if not self._isuvbeam:
if maintain_ordering:
feeds = [fd for fd in self.feed_array if fd in feeds]
return self.clone(beam=self.beam.clone(feed_array=feeds))
if self.beam_type == "power":
# Down-select polarizations based on the feeds input.
possible_pols = [f1 + f2 for f1, f2 in product(feeds, feeds)]
possible_pol_ints = upol.polstr2num(
possible_pols, x_orientation=self.beam.x_orientation
)
if maintain_ordering:
use_pols = [
p for p in self.beam.polarization_array if p in possible_pol_ints
]
print("use_pols: ", use_pols)
bhazelton marked this conversation as resolved.
Show resolved Hide resolved
else:
use_pols = [
p for p in possible_pol_ints if p in self.beam.polarization_array
]

new_beam = self.beam.select(polarizations=use_pols, inplace=False)
else:
if maintain_ordering:
feeds = [fd for fd in self.feed_array if fd in feeds]

new_beam = self.beam.select(feeds=feeds, inplace=False)
return self.clone(beam=new_beam)

@property
def _isuvbeam(self):
Expand Down
6 changes: 6 additions & 0 deletions tests/test_analytic_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,3 +568,9 @@ def test_single_feed():
beam = GaussianBeam(diameter=14.0, feed_array=["x"], include_cross_pols=True)
assert beam.feed_array == ["x"]
assert beam.polarization_array == [-5]


def test_clone():
beam = GaussianBeam(diameter=14.0, feed_array=["x", "y"])
new_beam = beam.clone(feed_array=["x"])
assert new_beam.feed_array == ["x"]
100 changes: 100 additions & 0 deletions tests/test_beam_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,11 +11,22 @@
GaussianBeam,
ShortDipoleBeam,
UniformBeam,
UVBeam,
utils,
)
from pyuvdata.testing import check_warnings


@pytest.fixture(scope="function")
def airy() -> AiryBeam:
return AiryBeam(diameter=14.0)


@pytest.fixture()
def gaussian() -> GaussianBeam:
return GaussianBeam(diameter=14.0)


@pytest.fixture()
def xy_grid_coarse():
nfreqs = 5
Expand All @@ -34,6 +45,14 @@ def xy_grid_coarse():
return az_array, za_array, freqs


@pytest.fixture()
def gaussian_uv(gaussian, az_za_coords) -> UVBeam:
az, za = az_za_coords
return gaussian.to_uvbeam(
axis1_array=az, axis2_array=za, freq_array=np.array([1e8])
)


@pytest.mark.parametrize(
["beam_obj", "kwargs"],
[
Expand Down Expand Up @@ -243,3 +262,84 @@ def test_compute_response_errors(param, value):
else:
# this shouldn't error
bi_uvb.compute_response(**compute_kwargs)


@pytest.mark.parametrize("beam_obj", ["airy", "gaussian", "gaussian_uv"])
def test_idempotent_instantiation(beam_obj, request):
beam = BeamInterface(request.getfixturevalue(beam_obj))
beam2 = BeamInterface(beam)
assert beam == beam2


def test_properties(airy: AiryBeam):
intf = BeamInterface(airy)
assert airy.Npols == intf.Npols
assert airy.Nfeeds == intf.Nfeeds
assert np.all(airy.polarization_array == intf.polarization_array)
assert np.all(airy.feed_array == intf.feed_array)


def test_clone(airy):
intf = BeamInterface(airy)
intf_clone = intf.clone(beam_type="power")
assert intf != intf_clone


@pytest.mark.parametrize("uvbeam", [True, False], ids=["uvbeam", "analytic"])
@pytest.mark.parametrize("allow_mutation", [True, False], ids=["mutate", "nomutate"])
@pytest.mark.parametrize("include_cross_pols", [True, False], ids=["incx", "nox"])
def test_as_power(
uvbeam: bool, allow_mutation: bool, include_cross_pols: bool, gaussian, gaussian_uv
):
beam = gaussian_uv if uvbeam else gaussian
intf = BeamInterface(beam)
intf_power = intf.as_power_beam(
allow_beam_mutation=allow_mutation, include_cross_pols=include_cross_pols
)
assert intf_power.beam_type == "power"
assert intf_power.Npols == 4 if include_cross_pols else 2

if uvbeam:
if allow_mutation:
assert intf.beam.beam_type == "power"
else:
assert intf.beam.beam_type == "efield"


def test_as_power_noop(airy):
"""Ensure that calling as_power_beam on a power beam is a no-op."""
intf = BeamInterface(airy, beam_type="power")
intf2 = intf.as_power_beam()
assert intf is intf2


@pytest.mark.parametrize("uvbeam", [True, False])
def test_with_feeds(uvbeam: bool, gaussian, gaussian_uv):
beam = gaussian_uv if uvbeam else gaussian

intf = BeamInterface(beam)

intf_feedx = intf.with_feeds(["x"])
assert intf_feedx.feed_array == ["x"]


def test_with_feeds_ordering(airy):
intf = BeamInterface(airy)

intf_feedx = intf.with_feeds(["y", "x"], maintain_ordering=True)
assert np.all(intf_feedx.feed_array == ["x", "y"])

intf_feedyx = intf.with_feeds(["y", "x"], maintain_ordering=False)
assert np.all(intf_feedyx.feed_array == ["y", "x"])


@pytest.mark.filterwarnings("ignore:Input beam is an efield UVBeam")
@pytest.mark.filterwarnings("ignore:Selected polarizations are not evenly spaced")
def test_with_feeds_ordering_power(gaussian_uv):
# beam = AiryBeam(diameter=14.0).to_uvbeam(freq_array=np.array([1e8]), nside=16)
intf = BeamInterface(gaussian_uv, beam_type="power")
intf_feedx = intf.with_feeds(["y", "x"], maintain_ordering=True)
assert np.all(intf_feedx.polarization_array == [-5, -6, -7, -8])

intf_feedyx = intf.with_feeds(["y", "x"], maintain_ordering=False)
assert np.all(intf_feedyx.polarization_array == [-6, -8, -7, -5])
Loading