Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

LazyProxy and LazyProxyMultiton patterns #269

Merged
merged 12 commits into from
Feb 3, 2022
5 changes: 5 additions & 0 deletions HISTORY.rst
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,11 @@
History
=======

X.Y.Z (YYYY-MM-DD)
------------------
* Add Multiton, LazyProxy and LazyProxyMultiton patterns (:pr:`269`)


0.3.2 (2022-13-01)
------------------
* Support numba >= 0.54 (:pr:`264`)
Expand Down
43 changes: 40 additions & 3 deletions africanus/experimental/rime/fused/specification.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import ast
from importlib import import_module
import inspect
from itertools import groupby # noqa
import multiprocessing
from pathlib import Path
import re

Expand All @@ -11,6 +11,7 @@
from africanus.experimental.rime.fused import terms as term_mod
from africanus.experimental.rime.fused.transformers.core import Transformer
from africanus.experimental.rime.fused import transformers as transformer_mod
from africanus.util.patterns import LazyProxy


TERM_STRING_REGEX = re.compile("([A-Z])(pq|p|q)")
Expand Down Expand Up @@ -232,9 +233,17 @@ def __init__(self, specification, terms=None, transformers=None):
except KeyError as e:
raise RimeSpecificationError(f"Can't find a type for {str(e)}")

Pool = multiprocessing.get_context("spawn").Pool
pool = LazyProxy((Pool, RimeSpecification._finalise_pool), 4)

# Create the terms
terms = []
global_kw = {"corrs": corrs, "stokes": stokes, "feed_type": feed_type}
global_kw = {
"corrs": corrs,
"stokes": stokes,
"feed_type": feed_type,
"process_pool": pool
}

for cls, cfg in zip(term_types, term_cfgs):
if cfg == "pq":
Expand Down Expand Up @@ -284,8 +293,36 @@ def __init__(self, specification, terms=None, transformers=None):
raise RimeSpecificationError(
"RIME must at least contain a Brightness term")

transformers = []

for cls in transformer_types.values():
init_sig = inspect.signature(cls.__init__)
cls_kw = {}

for a, p in list(init_sig.parameters.items())[1:]:
if p.kind not in {p.POSITIONAL_ONLY,
p.POSITIONAL_OR_KEYWORD}:
raise RimeSpecification(
f"{cls}.__init__{init_sig} may not contain "
f"*args or **kwargs")

try:
cls_kw[a] = available_kw[a]
except KeyError:
raise RimeSpecificationError(
f"{cls}.__init__{init_sig} wants argument {a} "
f"but it is not available. "
f"Available args: {available_kw}")

transformer = cls(**cls_kw)
transformers.append(transformer)

self.terms = terms
self.transformers = [cls() for cls in transformer_types.values()]
self.transformers = transformers

@staticmethod
def _finalise_pool(pool):
pool.terminate()

@staticmethod
def _feed_type(corrs):
Expand Down
3 changes: 3 additions & 0 deletions africanus/experimental/rime/fused/transformers/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,9 @@ def __new__(mcls, name, bases, namespace):


class Transformer(metaclass=TransformerMetaClass):
def __init__(self):
pass

def __repr__(self):
return self.__class__.__name__

Expand Down
6 changes: 5 additions & 1 deletion africanus/experimental/rime/fused/transformers/parangle.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,9 @@
class ParallacticTransformer(Transformer):
OUTPUTS = ["feed_parangle", "beam_parangle"]

def __init__(self, process_pool):
self.pool = process_pool

def init_fields(self, typingctx,
utime, ufeed, uantenna,
antenna_position, phase_dir,
Expand All @@ -29,7 +32,8 @@ def init_fields(self, typingctx,
@njit(inline="never")
def parangle_stub(time, antenna, phase_dir):
with objmode(out=parangle_dt):
out = casa_parallactic_angles(time, antenna, phase_dir)
out = self.pool.apply(casa_parallactic_angles,
(time, antenna, phase_dir))

return out

Expand Down
9 changes: 3 additions & 6 deletions africanus/rime/jax/phase.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@


try:
import jax.numpy as np
import jax.numpy as jnp
except ImportError as e:
opt_import_error = e
else:
Expand All @@ -14,11 +14,8 @@

@requires_optional('jax', opt_import_error)
def phase_delay(lm, uvw, frequency):
out_dtype = np.result_type(lm, uvw, frequency, np.complex64)

one = lm.dtype.type(1.0)
neg_two_pi_over_c = lm.dtype.type(minus_two_pi_over_c)
complex_one = out_dtype.type(1j)

l = lm[:, 0, None, None] # noqa
m = lm[:, 1, None, None]
Expand All @@ -27,10 +24,10 @@ def phase_delay(lm, uvw, frequency):
v = uvw[None, :, 1, None]
w = uvw[None, :, 2, None]

n = np.sqrt(one - l**2 - m**2) - one
n = jnp.sqrt(one - l**2 - m**2) - one

real_phase = (neg_two_pi_over_c *
(l * u + m * v + n * w) *
frequency[None, None, :])

return np.exp(complex_one*real_phase)
return jnp.exp(jnp.complex64(1j)*real_phase)
15 changes: 7 additions & 8 deletions africanus/rime/jax/tests/test_jax_phase_delay.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,28 @@
# -*- coding: utf-8 -*-


import numpy as onp
import numpy as np
import pytest

from africanus.rime.phase import phase_delay as np_phase_delay
from africanus.rime.jax.phase import phase_delay


@pytest.mark.parametrize("dtype", [onp.float32, onp.float64])
@pytest.mark.parametrize("dtype", [np.float32, np.float64])
def test_jax_phase_delay(dtype):
jax = pytest.importorskip('jax')
np = pytest.importorskip('jax.numpy')

onp.random.seed(0)
np.random.seed(0)

uvw = onp.random.random(size=(100, 3)).astype(dtype)
lm = onp.random.random(size=(10, 2)).astype(dtype)*0.001
frequency = onp.linspace(.856e9, .856e9*2, 64).astype(dtype)
uvw = np.random.random(size=(100, 3)).astype(dtype)
lm = np.random.random(size=(10, 2)).astype(dtype)*0.001
frequency = np.linspace(.856e9, .856e9*2, 64).astype(dtype)

# Compute complex phase
np_complex_phase = np_phase_delay(lm, uvw, frequency)
complex_phase = jax.jit(phase_delay)(lm, uvw, frequency)

onp.testing.assert_array_almost_equal(complex_phase, np_complex_phase)
np.testing.assert_array_almost_equal(complex_phase, np_complex_phase)
expected_ctype = np.result_type(dtype, np.complex64)
assert np_complex_phase.dtype == expected_ctype
assert complex_phase.dtype == expected_ctype
Loading