Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Pyx datatypes #1357

Merged
merged 8 commits into from
Dec 7, 2023
10 changes: 5 additions & 5 deletions pyuvdata/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -791,7 +791,7 @@ def baseline_to_antnums(baseline, Nants_telescope):

return_array = isinstance(baseline, (np.ndarray, list, tuple))
ant1, ant2 = _utils.baseline_to_antnums(
np.ascontiguousarray(baseline, dtype=np.int64)
np.ascontiguousarray(baseline, dtype=np.uint64)
)
if return_array:
return ant1, ant2
Expand Down Expand Up @@ -846,8 +846,8 @@ def antnums_to_baseline(ant1, ant2, Nants_telescope, attempt256=False):

return_array = isinstance(ant1, (np.ndarray, list, tuple))
baseline = _utils.antnums_to_baseline(
np.ascontiguousarray(ant1, dtype=np.int64),
np.ascontiguousarray(ant2, dtype=np.int64),
np.ascontiguousarray(ant1, dtype=np.uint64),
np.ascontiguousarray(ant2, dtype=np.uint64),
attempt256=attempt256,
nants_less2048=nants_less2048,
)
Expand Down Expand Up @@ -6200,14 +6200,14 @@ def determine_rectangularity(
baseline_array = baseline_array.reshape((nbls, ntimes))
if np.sum(np.abs(np.diff(time_array, axis=0))) != 0:
return False, False
if np.sum(np.abs(np.diff(baseline_array, axis=1))) != 0:
if (np.diff(baseline_array, axis=1) != 0).any():
return False, False
return True, True
elif bl_first:
time_array = time_array.reshape((ntimes, nbls))
baseline_array = baseline_array.reshape((ntimes, nbls))
if np.sum(np.abs(np.diff(time_array, axis=1))) != 0:
return False, False
if np.sum(np.abs(np.diff(baseline_array, axis=0))) != 0:
if (np.diff(baseline_array, axis=0) != 0).any():
return False, False
return True, False
70 changes: 38 additions & 32 deletions pyuvdata/utils.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -36,8 +36,10 @@ e_squared = _e2
e_prime_squared = _ep2

ctypedef fused int_or_float:
numpy.uint64_t
numpy.int64_t
numpy.int32_t
numpy.uint32_t
numpy.float64_t
numpy.float32_t

Expand Down Expand Up @@ -68,8 +70,8 @@ cdef int_or_float arraymax(int_or_float[::1] array) nogil:
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline void _bl_to_ant_256(
numpy.int64_t[::1] _bl,
numpy.int64_t[:, ::1] _ants,
numpy.uint64_t[::1] _bl,
numpy.uint64_t[:, ::1] _ants,
long nbls,
):
cdef Py_ssize_t i
Expand All @@ -82,8 +84,8 @@ cdef inline void _bl_to_ant_256(
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline void _bl_to_ant_2048(
numpy.int64_t[::1] _bl,
numpy.int64_t[:, ::1] _ants,
numpy.uint64_t[::1] _bl,
numpy.uint64_t[:, ::1] _ants,
int nbls
):
cdef Py_ssize_t i
Expand All @@ -92,37 +94,41 @@ cdef inline void _bl_to_ant_2048(
_ants[0, i] = (_bl[i] - 2 ** 16 - (_ants[1, i])) // 2048
return

# defining these constants helps cython not cast the large
# numbers as python ints
cdef numpy.uint64_t bl_large = 2 ** 16 + 2 ** 22
cdef numpy.uint64_t large_mod = 2147483648

@cython.boundscheck(False)
@cython.wraparound(False)
@cython.cdivision(True)
cdef inline void _bl_to_ant_2147483648(
numpy.int64_t[::1] _bl,
numpy.int64_t[:, ::1] _ants,
numpy.uint64_t[::1] _bl,
numpy.uint64_t[:, ::1] _ants,
int nbls
):
cdef Py_ssize_t i
for i in range(nbls):
_ants[1, i] = (_bl[i] - 2 ** 16 - 2 ** 22) % 2147483648
_ants[0, i] = (_bl[i] - 2 ** 16 - 2 ** 22 - (_ants[1, i])) // 2147483648
_ants[1, i] = (_bl[i] - bl_large) % large_mod
_ants[0, i] = (_bl[i] - bl_large - (_ants[1, i])) // large_mod
return


@cython.boundscheck(False)
@cython.wraparound(False)
cpdef numpy.ndarray[dtype=numpy.int64_t, ndim=2] baseline_to_antnums(
numpy.int64_t[::1] _bl
cpdef numpy.ndarray[dtype=numpy.uint64_t, ndim=2] baseline_to_antnums(
numpy.uint64_t[::1] _bl
):
cdef numpy.int64_t _min = arraymin(_bl)
cdef bint use2147483648 = _min >= (2 ** 16 + 2 ** 22)
cdef bint use2048 = _min >= 2 ** 16
cdef numpy.uint64_t _min = arraymin(_bl)
cdef long nbls = _bl.shape[0]
cdef int ndim = 2
cdef numpy.npy_intp * dims = [2, <numpy.npy_intp> nbls]
cdef numpy.ndarray[ndim=2, dtype=numpy.int64_t] ants = numpy.PyArray_EMPTY(ndim, dims, numpy.NPY_INT64, 0)
cdef numpy.int64_t[:, ::1] _ants = ants
cdef numpy.ndarray[ndim=2, dtype=numpy.uint64_t] ants = numpy.PyArray_EMPTY(ndim, dims, numpy.NPY_UINT64, 0)
cdef numpy.uint64_t[:, ::1] _ants = ants

if use2147483648:
if _min >= (2 ** 16 + 2 ** 22):
_bl_to_ant_2147483648(_bl, _ants, nbls)
elif use2048:
elif _min >= 2 ** 16:
_bl_to_ant_2048(_bl, _ants, nbls)
else:
_bl_to_ant_256(_bl, _ants, nbls)
Expand All @@ -131,24 +137,24 @@ cpdef numpy.ndarray[dtype=numpy.int64_t, ndim=2] baseline_to_antnums(
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline void _antnum_to_bl_2147483648(
numpy.int64_t[::1] ant1,
numpy.int64_t[::1] ant2,
numpy.int64_t[::1] baselines,
numpy.uint64_t[::1] ant1,
numpy.uint64_t[::1] ant2,
numpy.uint64_t[::1] baselines,
int nbls,
):
cdef Py_ssize_t i

for i in range(nbls):
baselines[i] = 2147483648 * (ant1[i]) + (ant2[i]) + 2 ** 16 + 2 ** 22
baselines[i] = large_mod * (ant1[i]) + (ant2[i]) + bl_large
return


@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline void _antnum_to_bl_2048(
numpy.int64_t[::1] ant1,
numpy.int64_t[::1] ant2,
numpy.int64_t[::1] baselines,
numpy.uint64_t[::1] ant1,
numpy.uint64_t[::1] ant2,
numpy.uint64_t[::1] baselines,
int nbls,
):
cdef Py_ssize_t i
Expand All @@ -160,9 +166,9 @@ cdef inline void _antnum_to_bl_2048(
@cython.boundscheck(False)
@cython.wraparound(False)
cdef inline void _antnum_to_bl_256(
numpy.int64_t[::1] ant1,
numpy.int64_t[::1] ant2,
numpy.int64_t[::1] baselines,
numpy.uint64_t[::1] ant1,
numpy.uint64_t[::1] ant2,
numpy.uint64_t[::1] baselines,
int nbls,
):
cdef Py_ssize_t i
Expand All @@ -172,17 +178,17 @@ cdef inline void _antnum_to_bl_256(
baselines[i] = 256 * (ant1[i]) + (ant2[i])
return

cpdef numpy.ndarray[dtype=numpy.int64_t] antnums_to_baseline(
numpy.int64_t[::1] ant1,
numpy.int64_t[::1] ant2,
cpdef numpy.ndarray[dtype=numpy.uint64_t] antnums_to_baseline(
numpy.uint64_t[::1] ant1,
numpy.uint64_t[::1] ant2,
bint attempt256=False,
bint nants_less2048=True
):
cdef int ndim = 1
cdef int nbls = ant1.shape[0]
cdef numpy.npy_intp * dims = [<numpy.npy_intp>nbls]
cdef numpy.ndarray[ndim=1, dtype=numpy.int64_t] baseline = numpy.PyArray_EMPTY(ndim, dims, numpy.NPY_INT64, 0)
cdef numpy.int64_t[::1] _bl = baseline
cdef numpy.ndarray[ndim=1, dtype=numpy.uint64_t] baseline = numpy.PyArray_EMPTY(ndim, dims, numpy.NPY_UINT64, 0)
cdef numpy.uint64_t[::1] _bl = baseline
cdef bint less255
cdef bint ants_less2048
# to ensure baseline numbers are unambiguous,
Expand Down
13 changes: 11 additions & 2 deletions pyuvdata/uvcal/initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -478,10 +478,19 @@ def new_uvcal_from_uvdata(
"ant_array", np.union1d(uvdata.ant_1_array, uvdata.ant_2_array)
)

# Just in case a user inputs their own ant_array kwarg
# make sure this is a numpy array for the following interactions
if not isinstance(ant_array, np.ndarray):
ant_array = np.asarray(ant_array)

if antenna_numbers is not None:
ant_array = np.intersect1d(ant_array, antenna_numbers)
ant_array = np.intersect1d(
ant_array, np.asarray(antenna_numbers, dtype=ant_array.dtype)
)
elif isinstance(antenna_positions, dict):
ant_array = np.intersect1d(ant_array, list(antenna_positions.keys()))
ant_array = np.intersect1d(
ant_array, np.asarray(list(antenna_positions.keys()), dtype=ant_array.dtype)
)

if jones_array is None:
if np.all(uvdata.polarization_array < -4):
Expand Down
6 changes: 6 additions & 0 deletions pyuvdata/uvcal/tests/test_initializers.py
Original file line number Diff line number Diff line change
Expand Up @@ -234,3 +234,9 @@ def test_new_uvcal_with_history(uvd_kw, uvc_only_kw):
uvd = new_uvdata(**uvd_kw)
uvc = new_uvcal_from_uvdata(uvd, history="my substring", **uvc_only_kw)
assert "my substring" in uvc.history


def test_new_uvcal_ant_array_list(uvd_kw, uvc_only_kw):
uvd = new_uvdata(**uvd_kw)
uvc = new_uvcal_from_uvdata(uvd, ant_array=[1, 2, 3], **uvc_only_kw)
assert np.array_equal(np.array([1, 2], dtype=np.uint64), uvc.ant_array)
4 changes: 2 additions & 2 deletions pyuvdata/uvdata/uvdata.py
Original file line number Diff line number Diff line change
Expand Up @@ -8854,7 +8854,7 @@ def upsample_in_time(

temp_Nblts = np.sum(n_new_samples)

temp_baseline = np.zeros((temp_Nblts,), dtype=np.int64)
temp_baseline = np.zeros((temp_Nblts,), dtype=np.uint64)
temp_id_array = np.zeros((temp_Nblts,), dtype=int)
if initial_nphase_ids > 1 and initial_driftscan:
temp_initial_ids = np.zeros((temp_Nblts,), dtype=int)
Expand Down Expand Up @@ -9240,7 +9240,7 @@ def downsample_in_time(
self.phase_to_time(phase_time)

# make temporary arrays
temp_baseline = np.zeros((temp_Nblts,), dtype=np.int64)
temp_baseline = np.zeros((temp_Nblts,), dtype=np.uint64)
temp_id_array = np.zeros((temp_Nblts,), dtype=int)
temp_time = np.zeros((temp_Nblts,))
temp_int_time = np.zeros((temp_Nblts,))
Expand Down
Loading