Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: use a very large coil dimension for 2D stacked NUFFT #39

Merged
merged 15 commits into from
Oct 17, 2023
Merged
1 change: 1 addition & 0 deletions .github/workflows/test-ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,7 @@ jobs:
cmake -DFINUFFT_USE_CUDA=1 ../ && cmake --build . && cp libcufinufft.so ../python/cufinufft/.
# enter venv
source $RUNNER_WORKSPACE/venv/bin/activate
pip install cupy-cuda11x
cd $RUNNER_WORKSPACE/finufft/python/cufinufft
python setup.py develop
# FIXME: This is hardcoded
Expand Down
56 changes: 24 additions & 32 deletions src/mrinufft/operators/interfaces/cufinufft.py
Original file line number Diff line number Diff line change
Expand Up @@ -368,7 +368,8 @@ def _op_sense_host(self, data, ksp=None):
coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype)
dataf = data.reshape((B, *XYZ))
data_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype)
ksp = ksp or np.empty((B, C, K), dtype=self.cpx_dtype)
if ksp is None:
ksp = np.empty((B, C, K), dtype=self.cpx_dtype)
ksp = ksp.reshape((B * C, K))
ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype)

Expand Down Expand Up @@ -408,8 +409,9 @@ def _op_calibless_host(self, data, ksp=None):

coil_img_d = cp.empty(np.prod(XYZ) * T, dtype=self.cpx_dtype)
ksp_d = cp.empty((T, K), dtype=self.cpx_dtype)

ksp = np.zeros((B * C, K), dtype=self.cpx_dtype)
if ksp is None:
ksp = np.zeros((B * C, K), dtype=self.cpx_dtype)
ksp = ksp.reshape((B * C, K))
# TODO: Add concurrency compute batch n while copying batch n+1 to device
# and batch n-1 to host
dataf = data.flatten()
Expand Down Expand Up @@ -504,13 +506,13 @@ def _adj_op_sense_host(self, coeffs, img_d=None):
# Define short name
T, B, C = self.n_trans, self.n_batchs, self.n_coils
K, XYZ = self.n_samples, self.shape

coeffs_f = coeffs.flatten()
# Allocate memory
coil_img_d = cp.empty((T, *XYZ), dtype=self.cpx_dtype)
if img_d is None:
img_d = cp.zeros((B, *XYZ), dtype=self.cpx_dtype)

smaps_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype)
coeffs_f = coeffs.flatten()
ksp_batched = cp.empty((T, K), dtype=self.cpx_dtype)
if self.uses_density:
density_batched = cp.repeat(self.density[None, :], T, axis=0)
Expand All @@ -533,23 +535,16 @@ def _adj_op_sense_host(self, coeffs, img_d=None):
return img

def _adj_op_calibless_device(self, coeffs, img_d=None):
T, B, C = self.n_trans, self.n_batchs, self.n_coils
K, XYZ = self.n_samples, self.shape
coeffs_f = coeffs.flatten()
n_trans_samples = self.n_trans * self.n_samples
ksp_batched = cp.empty(n_trans_samples, dtype=self.cpx_dtype)
ksp_batched = cp.empty(T * K, dtype=self.cpx_dtype)
if self.uses_density:
density_batched = cp.repeat(
self.density[None, :], self.n_trans, axis=0
).flatten()
img_d = img_d or cp.empty(
(self.n_batchs, self.n_coils, *self.shape),
dtype=self.cpx_dtype,
)
for i in range((self.n_coils * self.n_batchs) // self.n_trans):
density_batched = cp.repeat(self.density[None, :], T, axis=0).flatten()
img_d = img_d or cp.empty((B, C, *XYZ), dtype=self.cpx_dtype)
for i in range((B * C) // T):
if self.uses_density:
cp.copyto(
ksp_batched,
coeffs_f[i * n_trans_samples : (i + 1) * n_trans_samples],
)
cp.copyto(ksp_batched, coeffs_f[i * T * K : (i + 1) * T * K])
ksp_batched *= density_batched
self.__adj_op(get_ptr(ksp_batched), get_ptr(img_d) + i * self.bsize_img)
else:
Expand All @@ -560,28 +555,25 @@ def _adj_op_calibless_device(self, coeffs, img_d=None):
return img_d

def _adj_op_calibless_host(self, coeffs, img_batched=None):
T, B, C = self.n_trans, self.n_batchs, self.n_coils
K, XYZ = self.n_samples, self.shape
coeffs_f = coeffs.flatten()
n_trans_samples = self.n_trans * self.n_samples
ksp_batched = cp.empty(n_trans_samples, dtype=self.cpx_dtype)
ksp_batched = cp.empty(T * K, dtype=self.cpx_dtype)
if self.uses_density:
density_batched = cp.repeat(
self.density[None, :], self.n_trans, axis=0
).flatten()
density_batched = cp.repeat(self.density[None, :], T, axis=0).flatten()

img = np.zeros(
(self.n_batchs * self.n_coils, *self.shape), dtype=self.cpx_dtype
)
img = np.zeros((B * C, *XYZ), dtype=self.cpx_dtype)
if img_batched is None:
img_batched = cp.empty((self.n_trans, *self.shape), dtype=self.cpx_dtype)
img_batched = cp.empty((T, *XYZ), dtype=self.cpx_dtype)
# TODO: Add concurrency compute batch n while copying batch n+1 to device
# and batch n-1 to host
for i in range((self.n_batchs * self.n_coils) // self.n_trans):
ksp_batched.set(coeffs_f[i * n_trans_samples : (i + 1) * n_trans_samples])
for i in range((B * C) // T):
ksp_batched.set(coeffs_f[i * T * K : (i + 1) * T * K])
if self.uses_density:
ksp_batched *= density_batched
self.__adj_op(get_ptr(ksp_batched), get_ptr(img_batched))
img[i * self.n_trans : (i + 1) * self.n_trans] = img_batched.get()
img = img.reshape((self.n_batchs, self.n_coils, *self.shape))
img[i * T : (i + 1) * T] = img_batched.get()
img = img.reshape((B, C, *XYZ))
return img

@nvtx_mark()
Expand Down
Loading