From 172ce41a05615f83d654c2f0faf4e9125c4390d4 Mon Sep 17 00:00:00 2001 From: apasarkar Date: Sat, 16 Sep 2023 07:20:26 +0800 Subject: [PATCH] Fixes split calculation bug and updates setup and workflow --- .github/workflows/dev.yml | 2 +- jnormcorre/motion_correction.py | 5 +++-- setup.py | 2 +- 3 files changed, 5 insertions(+), 4 deletions(-) diff --git a/.github/workflows/dev.yml b/.github/workflows/dev.yml index 7030418..fefc0cf 100644 --- a/.github/workflows/dev.yml +++ b/.github/workflows/dev.yml @@ -5,7 +5,7 @@ name: Python package on: push: - branches: [ "dev*"] + branches: [ "main"] jobs: build: diff --git a/jnormcorre/motion_correction.py b/jnormcorre/motion_correction.py index 9f57196..6ef1315 100644 --- a/jnormcorre/motion_correction.py +++ b/jnormcorre/motion_correction.py @@ -83,6 +83,7 @@ from functools import partial import time +import random import multiprocessing @@ -2720,7 +2721,7 @@ def motion_correction_piecewise(fname, splits, strides, overlaps, add_to_movie=0 shape_mov = (np.prod(dims), T) if num_splits is not None: num_splits = min(num_splits, len(idxs)) - idxs = np.array(idxs)[np.random.randint(0, len(idxs), num_splits)] + idxs = random.sample(idxs, num_splits) save_movie = False if save_movie: @@ -2745,7 +2746,7 @@ def motion_correction_piecewise(fname, splits, strides, overlaps, add_to_movie=0 pars = [] for idx in idxs: logging.debug('Processing: frames: {}'.format(idx)) - pars.append([fname, fname_tot, idx, shape_mov, template, strides, overlaps, max_shifts, np.array( + pars.append([fname, fname_tot, np.array(idx), shape_mov, template, strides, overlaps, max_shifts, np.array( add_to_movie, dtype=np.float32), max_deviation_rigid, upsample_factor_grid, newoverlaps, newstrides, nonneg_movie, is_fiji, var_name_hdf5, indices, filter_kernel]) diff --git a/setup.py b/setup.py index 529fa43..1edb00a 100644 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ version="0.0.6", description="Jax-accelerated implementation of normcorre", packages=setuptools.find_packages(), - install_requires=["future","numpy", "scipy", "h5py", "tqdm", "matplotlib", "opencv-python", "tifffile", "typing", "torch", "pynwb", "pillow", "scikit-image", "jax", "jaxlib"], + install_requires=["future","numpy", "scipy", "h5py", "tqdm", "matplotlib", "opencv-python", "tifffile", "typing", "torch", "pynwb", "pillow", "scikit-image", "jax", "jaxlib", "pytest"], classifiers=( "Programming Language :: Python :: 3", ),