Skip to content

Commit

Permalink
fix: profiling cli
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-murray committed Dec 13, 2023
1 parent 1dfb360 commit 92ee8f7
Show file tree
Hide file tree
Showing 2 changed files with 119 additions and 47 deletions.
159 changes: 114 additions & 45 deletions src/matvis/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,51 +47,7 @@
main = click.Group()


@main.command()
@click.option("-A/-I", "--analytic-beam/--interpolated-beam", default=True)
@click.option("-f", "--nfreq", default=1)
@click.option(
"-t",
"--ntimes",
default=1,
)
@click.option(
"-a",
"--nants",
default=1,
)
@click.option(
"-b",
"--nbeams",
default=1,
)
@click.option(
"-s",
"--nsource",
default=1,
)
@click.option(
"-g/-c",
"--gpu/--cpu",
default=False,
)
@click.option(
"--method",
default="MatMul",
type=click.Choice(["MatMul", "VectorDot"]),
)
@click.option(
"-v/-V", "--verbose/--not-verbose", default=False, help="Print verbose output"
)
@click.option(
"-l",
"--log-level",
default="INFO",
type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
)
@click.option("-o", "--outdir", default=".")
@click.option("--double-precision/--single-precision", default=True)
def profile(
def run_profile(
analytic_beam,
nfreq,
ntimes,
Expand All @@ -104,6 +60,7 @@ def profile(
verbose,
log_level,
method,
pairs=None,
):
"""Run the script."""
if not HAVE_GPU and gpu:
Expand Down Expand Up @@ -134,6 +91,7 @@ def profile(
print(f" DOUBLE-PRECISION: {double_precision:>7}")
print(f" ANALYTIC-BEAM: {analytic_beam:>7}")
print(f" METHOD: {method:>7}")
print(f" NPAIRS: {len(pairs) if pairs is not None else nants**2:>7}")
print("---------------------------------")

if gpu:
Expand All @@ -156,6 +114,7 @@ def profile(
use_gpu=gpu,
beam_idx=beam_idx,
matprod_method=f"{'GPU' if gpu else 'CPU'}{method}",
antpairs=pairs,
)

outdir = Path(outdir).expanduser().absolute()
Expand Down Expand Up @@ -183,6 +142,116 @@ def profile(
pickle.dump(thing_stats, fl)


common_profile_options = [
click.option("-A/-I", "--analytic-beam/--interpolated-beam", default=True),
click.option("-f", "--nfreq", default=1),
click.option(
"-t",
"--ntimes",
default=1,
),
click.option(
"-b",
"--nbeams",
default=1,
),
click.option(
"-g/-c",
"--gpu/--cpu",
default=False,
),
click.option(
"--method",
default="MatMul",
type=click.Choice(["MatMul", "VectorDot"]),
),
click.option(
"-v/-V", "--verbose/--not-verbose", default=False, help="Print verbose output"
),
click.option(
"-l",
"--log-level",
default="INFO",
type=click.Choice(["DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]),
),
click.option("-o", "--outdir", default="."),
click.option("--double-precision/--single-precision", default=True),
]


def add_common_options(func):
"""Add common profiling options to a function."""
for option in reversed(common_profile_options):
func = option(func)
return func


@main.command()
@click.option(
"-s",
"--nsource",
default=1,
)
@click.option(
"-a",
"--nants",
default=1,
)
@add_common_options
def profile(**kwargs):
"""Run the script."""
run_profile(**kwargs)


def get_redundancies(bls, ndecimals: int = 2):
"""Find redundant baselines."""
uvbins = set()
pairs = []

# Everything here is in wavelengths
bls = np.round(bls, decimals=ndecimals)
nant = bls.shape[0]

# group redundant baselines
for i in range(nant):
for j in range(i + 1, nant):
u, v = bls[i, j]
if (u, v) not in uvbins and (-u, -v) not in uvbins:
uvbins.add((u, v))
pairs.append([i, j])

return pairs


@main.command()
@click.option(
"-a",
"--hex-num",
default=11,
)
@click.option(
"-s",
"--nside",
default=64,
)
@click.option("-k", "--keep-ants", type=str, default="")
@click.option("--outriggers/--no-outriggers", default=False)
@add_common_options
def hera_profile(hex_num, nside, keep_ants, outriggers, **kwargs):
"""Run profiling of matvis with a HERA-like array."""
from py21cmsense.antpos import hera

antpos = hera(hex_num=hex_num, split_core=True, outriggers=2 if outriggers else 0)
if keep_ants:
keep_ants = [int(i) for i in keep_ants.split(",")]
antpos = antpos[keep_ants]

bls = antpos[np.newaxis, :, :2] - antpos[:, np.newaxis, :2]
pairs = np.array(get_redundancies(bls.value))

run_profile(nsource=12 * nside**2, nants=antpos.shape[0], pairs=pairs, **kwargs)


def get_line_based_stats(lstats) -> tuple[dict, float]:
"""Convert the line-number based stats into line-based stats."""
time_unit = lstats.unit
Expand Down
7 changes: 5 additions & 2 deletions src/matvis/wrapper.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ def simulate_vis(
use_gpu: bool = False,
beam_spline_opts: dict | None = None,
beam_idx: np.ndarray | None = None,
antpairs: np.ndarray | list[tuple[int, int]] | None = None,
**backend_kwargs,
):
"""
Expand Down Expand Up @@ -126,12 +127,13 @@ def simulate_vis(
for beam in beams
]

npairs = len(antpairs) if antpairs is not None else nants * nants
if polarized:
vis = np.zeros(
(freqs.size, lsts.size, nfeeds, nfeeds, nants * nants), dtype=complex_dtype
(freqs.size, lsts.size, nfeeds, nfeeds, npairs), dtype=complex_dtype
)
else:
vis = np.zeros((freqs.size, lsts.size, nants * nants), dtype=complex_dtype)
vis = np.zeros((freqs.size, lsts.size, npairs), dtype=complex_dtype)

# Loop over frequencies and call matvis_cpu/gpu
for i, freq in enumerate(freqs):
Expand All @@ -146,6 +148,7 @@ def simulate_vis(
polarized=polarized,
beam_spline_opts=beam_spline_opts,
beam_idx=beam_idx,
antpairs=antpairs,
**backend_kwargs,
)
return vis

0 comments on commit 92ee8f7

Please sign in to comment.