Skip to content

Commit

Permalink
feat: new convenience methods on BeamInterface
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-murray committed Nov 8, 2024
1 parent 7f633a2 commit d124ef1
Show file tree
Hide file tree
Showing 5 changed files with 140 additions and 4 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,7 @@ select = [
]
ignore = [
"N806", # non-lowercase variable (we use N* for axes lengths)
"N802", # non-lowercase function name (we use N* for axes lengths)
"B028", # no-explicit-stacklevel for warnings
"SIM108", # prefer ternary opperators. I find them difficult to read.
"D203", # one-blank-line-before-class. we use two.
Expand Down
6 changes: 5 additions & 1 deletion src/pyuvdata/analytic_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
import dataclasses
import importlib
import warnings
from dataclasses import InitVar, astuple, dataclass, field
from dataclasses import InitVar, astuple, dataclass, field, replace
from typing import ClassVar, Literal

import numpy as np
Expand Down Expand Up @@ -189,6 +189,10 @@ def Npols(self): # noqa N802
"""The number of polarizations."""
return self.polarization_array.size

def clone(self, **kw):
"""Create a new instance of the object with updated parameters."""
return replace(self, **kw)

@property
def east_ind(self):
"""The index of the east feed in the feed array."""
Expand Down
65 changes: 62 additions & 3 deletions src/pyuvdata/beam_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

import copy
import warnings
from dataclasses import InitVar, dataclass
from dataclasses import InitVar, asdict, dataclass, replace
from typing import Literal

import numpy as np
Expand All @@ -32,8 +32,9 @@ class BeamInterface:
Attributes
----------
beam : pyuvdata.UVBeam or pyuvdata.AnalyticBeam
Beam object to use for computations
beam : pyuvdata.UVBeam or pyuvdata.AnalyticBeam or BeamInterface
Beam object to use for computations. If a BeamInterface is passed, a new
view of the same object is created.
beam_type : str
The beam type, either "efield" or "power".
include_cross_pols : bool
Expand All @@ -59,6 +60,11 @@ def __post_init__(self, include_cross_pols: bool):
for the power beam.
"""
if isinstance(self.beam, BeamInterface):
self.beam = self.beam.beam
self.__post_init__(include_cross_pols=include_cross_pols)
return

if not isinstance(self.beam, UVBeam) and not issubclass(
type(self.beam), AnalyticBeam
):
Expand All @@ -83,6 +89,59 @@ def __post_init__(self, include_cross_pols: bool):
"specify `beam_type`."
)

@property
def Npols(self):
"""The number of polarizations in the beam."""
return self.beam.Npols

@property
def polarization_array(self):
"""The polarizations defined on the beam."""
return self.beam.polarization_array

@property
def feed_array(self):
"""The feeds for which the beam is defined."""
return self.beam.feed_array

@property
def Nfeeds(self):
"""The number of feeds defined on the beam."""
return self.beam.Nfeeds

def clone(self, **kw):
"""Return a new instance with updated parameters."""
return replace(self, **kw)

def as_power_beam(
self, include_cross_pols: bool = True, allow_beam_mutation: bool = False
):
"""Return a new interface instance that is in the power-beam mode.
Parameters
----------
include_cross_pols : bool, optional
Whether to include cross-pols in the power beam.
allow_beam_mutation : bool, optional
Whether to allow the underlying beam to be updated in-place.
"""
beam = self.beam if allow_beam_mutation else copy.deepcopy(self.beam)

# We cannot simply use .clone() here, because we need to be able to pass
# include_cross_pols, which can only be passed to the constructor proper.
this = asdict(self)
this["beam"] = beam
this["beam_type"] = "power"
this["include_cross_pols"] = include_cross_pols
return BeamInterface(**this)

def with_feeds(self, feeds):
"""Return a new interface instance with updated feed_array."""
if not self._isuvbeam:
return self.clone(beam=self.beam.clone(feed_array=feeds))
new_beam = self.beam.select(feeds=feeds, inplace=False)
return self.clone(beam=new_beam)

@property
def _isuvbeam(self):
return isinstance(self.beam, UVBeam)
Expand Down
6 changes: 6 additions & 0 deletions tests/test_analytic_beam.py
Original file line number Diff line number Diff line change
Expand Up @@ -568,3 +568,9 @@ def test_single_feed():
beam = GaussianBeam(diameter=14.0, feed_array=["x"], include_cross_pols=True)
assert beam.feed_array == ["x"]
assert beam.polarization_array == [-5]


def test_clone():
beam = GaussianBeam(diameter=14.0, feed_array=["x", "y"])
new_beam = beam.clone(feed_array=["x"])
assert new_beam.feed_array == ["x"]
66 changes: 66 additions & 0 deletions tests/test_beam_interface.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,3 +243,69 @@ def test_compute_response_errors(param, value):
else:
# this shouldn't error
bi_uvb.compute_response(**compute_kwargs)


@pytest.mark.parametrize(
"beam_obj",
[
AiryBeam(diameter=14.0),
GaussianBeam(diameter=14.0),
AiryBeam(diameter=14.0).to_uvbeam(
freq_array=np.array([1e8]), pixel_coordinate_system="healpix", nside=32
),
],
)
def test_idempotent_instantiation(beam_obj):
beam = BeamInterface(beam_obj)
beam2 = BeamInterface(beam)
assert beam == beam2


def test_properties():
beam = AiryBeam(diameter=14.0)
intf = BeamInterface(beam)
assert beam.Npols == intf.Npols
assert beam.Nfeeds == intf.Nfeeds
assert np.all(beam.polarization_array == intf.polarization_array)
assert np.all(beam.feed_array == intf.feed_array)


def test_clone():
beam = AiryBeam(diameter=14.0)
intf = BeamInterface(beam)
intf_clone = intf.clone(beam_type="power")
assert intf != intf_clone


@pytest.mark.parametrize("uvbeam", [True, False], ids=["uvbeam", "analytic"])
@pytest.mark.parametrize("allow_mutation", [True, False], ids=["mutate", "nomutate"])
@pytest.mark.parametrize("include_cross_pols", [True, False], ids=["incx", "nox"])
def test_astype(uvbeam: bool, allow_mutation: bool, include_cross_pols: bool):
beam = AiryBeam(diameter=14.0)
if uvbeam:
beam = beam.to_uvbeam(freq_array=np.array([1e8]), nside=32)

intf = BeamInterface(beam)
intf_power = intf.as_power_beam(
allow_beam_mutation=allow_mutation, include_cross_pols=include_cross_pols
)
assert intf_power.beam_type == "power"
assert intf_power.Npols == 4 if include_cross_pols else 2

# Ensure that the original beam is not mutated unless we say it can be.
if uvbeam:
if allow_mutation:
assert intf.beam.beam_type == "power"
else:
assert intf.beam.beam_type == "efield"


@pytest.mark.parametrize("uvbeam", [True, False])
def test_with_feeds(uvbeam: bool):
beam = AiryBeam(diameter=14.0)
if uvbeam:
beam = beam.to_uvbeam(freq_array=np.array([1e8]), nside=32)

intf = BeamInterface(beam)
intf_feedx = intf.with_feeds(["x"])
assert intf_feedx.feed_array == ["x"]

0 comments on commit d124ef1

Please sign in to comment.