Skip to content

Commit

Permalink
changed shift to be single value and added memory management features
Browse files Browse the repository at this point in the history
  • Loading branch information
brunomsaraiva committed Oct 12, 2023
1 parent 95a275b commit 6823ab9
Show file tree
Hide file tree
Showing 2 changed files with 89 additions and 62 deletions.
137 changes: 82 additions & 55 deletions src/nanopyx/core/transform/_le_interpolation_nearest_neighbor.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ class ShiftAndMagnify(LiquidEngine):

# tag-start: _le_interpolation_nearest_neighbor.ShiftAndMagnify._run_opencl
#@LiquidEngine._logger(logger)
def _run_opencl(self, image, shift_row, shift_col, float magnification_row, float magnification_col, dict device) -> np.ndarray:
def _run_opencl(self, image, shift_row, shift_col, float magnification_row, float magnification_col, dict device, int mem_div=1) -> np.ndarray:

# QUEUE AND CONTEXT
cl_ctx = cl.Context([device['device']])
Expand All @@ -95,14 +95,13 @@ class ShiftAndMagnify(LiquidEngine):
output_shape = (image.shape[0], int(image.shape[1]*magnification_row), int(image.shape[2]*magnification_col))
image_out = np.zeros(output_shape, dtype=np.float32)

# TODO 3 is a magic number
max_slices = int((dc.global_mem_size // (image_out[0,:,:].nbytes + image[0,:,:].nbytes))/3)
# TODO add exception if max_slices < 1
max_slices = int((dc.global_mem_size // (image_out[0,:,:].nbytes + image[0,:,:].nbytes))/mem_div)
max_slices = self._check_max_slices(image, max_slices)

mf = cl.mem_flags
input_opencl = cl.Buffer(cl_ctx, mf.READ_ONLY, image[0:max_slices,:,:].nbytes)
cl.enqueue_copy(cl_queue, input_opencl, image[0:max_slices,:,:]).wait()
output_opencl = cl.Buffer(cl_ctx, mf.WRITE_ONLY, image_out[0:max_slices,:,:].nbytes)
cl.enqueue_copy(cl_queue, input_opencl, image[0:max_slices,:,:]).wait()

code = self._get_cl_code("_le_interpolation_nearest_neighbor_.cl", device['DP'])
prg = cl.Program(cl_ctx, code).build()
Expand Down Expand Up @@ -355,44 +354,56 @@ class ShiftScaleRotate(LiquidEngine):


# tag-start: _le_interpolation_nearest_neighbor.ShiftScaleRotate._run_opencl
def _run_opencl(self, image, shift_row, shift_col, float scale_row, float scale_col, float angle, dict device) -> np.ndarray:
def _run_opencl(self, image, shift_row, shift_col, float scale_row, float scale_col, float angle, dict device, int mem_div=1) -> np.ndarray:

# QUEUE AND CONTEXT
cl_ctx = cl.Context([device['device']])
dc = device["device"]
cl_queue = cl.CommandQueue(cl_ctx)

code = self._get_cl_code("_le_interpolation_nearest_neighbor_.cl", device['DP'])
output_shape = (image.shape[0], int(image.shape[1]), int(image.shape[2]))
image_out = np.zeros(output_shape, dtype=np.float32)

cdef int nFrames = image.shape[0]
cdef int rowsM = image.shape[1]
cdef int colsM = image.shape[2]
max_slices = int((dc.global_mem_size // (image_out[0,:,:].nbytes + image[0,:,:].nbytes))/mem_div)
max_slices = self._check_max_slices(image, max_slices)

image_in = cl_array.to_device(cl_queue, image)
shift_col_in = cl_array.to_device(cl_queue, shift_col)
shift_row_in = cl_array.to_device(cl_queue, shift_row)
image_out = cl_array.zeros(cl_queue, (nFrames, rowsM, colsM), dtype=np.float32)
mf = cl.mem_flags
input_opencl = cl.Buffer(cl_ctx, mf.READ_ONLY, image[0:max_slices,:,:].nbytes)
output_opencl = cl.Buffer(cl_ctx, mf.WRITE_ONLY, image_out[0:max_slices,:,:].nbytes)
cl.enqueue_copy(cl_queue, input_opencl, image[0:max_slices,:,:]).wait()

# Create the program
code = self._get_cl_code("_le_interpolation_nearest_neighbor_.cl", device['DP'])
prg = cl.Program(cl_ctx, code).build()
knl = prg.shiftScaleRotate

for i in range(0, image.shape[0], max_slices):
if image.shape[0] - i >= max_slices:
n_slices = max_slices
else:
n_slices = image.shape[0] - i
knl(
cl_queue,
(n_slices, int(image.shape[1]), int(image.shape[2])),
self.get_work_group(dc, (n_slices, image.shape[1], image.shape[2])),
input_opencl,
output_opencl,
np.float32(shift_row),
np.float32(shift_col),
np.float32(scale_row),
np.float32(scale_col),
np.float32(angle)
).wait()

cl.enqueue_copy(cl_queue, image_out[i:i+n_slices,:,:], output_opencl).wait()
if i<=image.shape[0]-max_slices:
cl.enqueue_copy(cl_queue, input_opencl, image[i+n_slices:i+2*n_slices,:,:]).wait()

cl_queue.finish()

# Run the kernel
prg.shiftScaleRotate(
cl_queue,
image_out.shape,
None,
image_in.data,
image_out.data,
shift_row_in.data,
shift_col_in.data,
np.float32(scale_row),
np.float32(scale_col),
np.float32(angle)
)

# Wait for queue to finish
cl_queue.finish()

return np.asarray(image_out.get(),dtype=np.float32)
input_opencl.release()
output_opencl.release()

return image_out

# tag-end

Expand Down Expand Up @@ -646,44 +657,60 @@ class PolarTransform(LiquidEngine):
# tag-end

# tag-start: _le_interpolation_nearest_neighbor.PolarTransform._run_opencl
def _run_opencl(self, image, int nrow, int ncol, str scale, dict device):
def _run_opencl(self, image, int nrow, int ncol, str scale, dict device, int mem_div=1):

# QUEUE AND CONTEXT
cl_ctx = cl.Context([device['device']])
cl_queue = cl.CommandQueue(cl_ctx)

code = self._get_cl_code("_le_interpolation_nearest_neighbor_.cl", device['DP'])

cdef int nFrames = image.shape[0]
cdef int nRows = image.shape[1]
cdef int nCols = image.shape[2]

image_in = cl_array.to_device(cl_queue, image)
image_out = cl_array.zeros(cl_queue, (nFrames, nrow, ncol), dtype=np.float32)
output = np.zeros((nFrames, nrow, ncol), dtype=np.float32)

max_slices = int((device["device"].global_mem_size // (output[0,:,:].nbytes + image[0,:,:].nbytes))/mem_div)
max_slices = self._check_max_slices(image, max_slices)
image_in = cl.Buffer(cl_ctx, cl.mem_flags.READ_ONLY, image[0:max_slices,:,:].nbytes)
image_out = cl.Buffer(cl_ctx, cl.mem_flags.WRITE_ONLY, output[0:max_slices,:,:].nbytes)
cl.enqueue_copy(cl_queue, image_in, image[0:max_slices,:,:]).wait()

cdef int scale_int = 0
if scale == 'log':
scale_int = 1

# Create the program
# Create the program
code = self._get_cl_code("_le_interpolation_nearest_neighbor_.cl", device['DP'])
prg = cl.Program(cl_ctx, code).build()
knl = prg.PolarTransform

for i in range(0, image.shape[0], max_slices):
if image.shape[0] - i >= max_slices:
n_slices = max_slices
else:
n_slices = image.shape[0] - i

knl(
cl_queue,
(n_slices, nrow, ncol),
self.get_work_group(device["device"], (n_slices, nrow, ncol)),
image_in,
image_out,
np.int32(nRows),
np.int32(nCols),
np.int32(scale_int)
)

cl.enqueue_copy(cl_queue, output[i:i+n_slices,:,:], image_out).wait()
if i<=image.shape[0]-max_slices:
cl.enqueue_copy(cl_queue, image_in, image[i+n_slices:i+2*n_slices,:,:]).wait()

cl_queue.finish()

image_in.release()
image_out.release()

# Run the kernel
prg.PolarTransform(
cl_queue,
image_out.shape,
None,
image_in.data,
image_out.data,
np.int32(nRows),
np.int32(nCols),
np.int32(scale_int)
)

# Wait for queue to finish
cl_queue.finish()

return np.asarray(image_out.get(), dtype=np.float32)
return output
# tag-end

# tag-start: _le_interpolation_nearest_neighbor.PolarTransform._run_unthreaded
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -35,9 +35,9 @@ shiftAndMagnify(__global float *image_in, __global float *image_out,

__kernel void shiftScaleRotate(__global float *image_in,
__global float *image_out,
__global float *shift_row,
__global float *shift_col, float scale_row,
float scale_col, float angle) {
float shift_row,
float shift_col, float scale_row,
float scale_col, float angle) {
// these are the indexes of the loop
int f = get_global_id(0);
int rM = get_global_id(1);
Expand All @@ -58,11 +58,11 @@ __kernel void shiftScaleRotate(__global float *image_in,

int nPixels = rows * cols;

float col = (a * (cM - center_col - shift_col[f]) +
b * (rM - center_row - shift_row[f])) +
float col = (a * (cM - center_col - shift_col) +
b * (rM - center_row - shift_row)) +
center_col;
float row = (c * (cM - center_col - shift_col[f]) +
d * (rM - center_row - shift_row[f])) +
float row = (c * (cM - center_col - shift_col) +
d * (rM - center_row - shift_row)) +
center_row;

image_out[f * nPixels + rM * cols + cM] =
Expand Down

0 comments on commit 6823ab9

Please sign in to comment.