Skip to content

Commit

Permalink
Merge pull request #136 from mrava87/feature-simplexcuda
Browse files Browse the repository at this point in the history
Feature: added cuda version of Simplex proximal
  • Loading branch information
mrava87 authored Sep 10, 2023
2 parents 516afb8 + 32eae63 commit 488dcbd
Show file tree
Hide file tree
Showing 2 changed files with 134 additions and 7 deletions.
65 changes: 58 additions & 7 deletions pyproximal/proximal/Simplex.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,15 @@
import logging
import numpy as np

from pylops.utils.backend import get_array_module, to_cupy_conditional
from pyproximal.ProxOperator import _check_tau
from pyproximal import ProxOperator
from pyproximal.projection import SimplexProj

try:
from numba import jit
from ._Simplex_numba import bisect_jit, simplex_jit, fun_jit
from ._Simplex_cuda import bisect_jit_cuda, simplex_jit_cuda, fun_jit_cuda
except ModuleNotFoundError:
jit = None
jit_message = 'Numba not available, reverting to numpy.'
Expand All @@ -20,8 +23,8 @@
class _Simplex(ProxOperator):
"""Simplex operator (numpy version)
"""
def __init__(self, n, radius, dims=None, axis=-1, maxiter=100, xtol=1e-8,
call=True):
def __init__(self, n, radius, dims=None, axis=-1,
maxiter=100, xtol=1e-8, call=True):
super().__init__(None, False)
if dims is not None and len(dims) != 2:
raise ValueError('provide only 2 dimensions, or None')
Expand Down Expand Up @@ -90,6 +93,7 @@ def __init__(self, n, radius, dims=None, axis=-1,
self.xtol = xtol
self.call = call

@_check_tau
def prox(self, x, tau):
if self.dims is None:
bisect_lower = -1
Expand All @@ -113,6 +117,50 @@ def prox(self, x, tau):
return y.ravel()


class _Simplex_cuda(_Simplex):
    """Simplex operator (cuda version)

    Applies the ``simplex_jit_cuda`` kernel to the input, using one CUDA
    thread per row of the (reshaped) input array.

    This implementation is adapted from
    https://github.com/DIG-Kaust/HPC_Hackathon_DIG.

    Parameters
    ----------
    n : :obj:`int`
        Number of elements of the input vector (sets the size of the
        coefficient vector when ``dims`` is ``None``)
    radius : :obj:`float`
        Radius, passed to the kernel as the scalar of the hyperplane
    dims : :obj:`tuple` or :obj:`list`, optional
        Two dimensions the input is reshaped to, or ``None`` to treat the
        input as a single vector
    axis : :obj:`int`, optional
        Axis of ``dims`` along which the simplex is computed
    maxiter : :obj:`int`, optional
        Maximum number of iterations used by bisection
    ftol : :obj:`float`, optional
        Function tolerance in bisection
    xtol : :obj:`float`, optional
        Solution absolute tolerance in bisection
    call : :obj:`bool`, optional
        Stored on the instance (not used by :func:`prox`)
    num_threads_per_blocks : :obj:`int`, optional
        Number of threads per block used when launching the kernel

    Raises
    ------
    ValueError
        If ``dims`` is provided with more or less than 2 elements

    """
    def __init__(self, n, radius, dims=None, axis=-1,
                 maxiter=100, ftol=1e-8, xtol=1e-8, call=False,
                 num_threads_per_blocks=32):
        # NOTE(review): super() is _Simplex here, so this call passes n=None
        # and radius=False to it; every attribute is (re)assigned below.
        # Confirm that this is the intended way to initialise the base class.
        super().__init__(None, False)
        if dims is not None and len(dims) != 2:
            raise ValueError('provide only 2 dimensions, or None')
        self.n = n
        # coefficients are created with numpy and moved to the array backend
        # of the input the first time prox is called
        self.coeffs = np.ones(self.n if dims is None else dims[axis])
        self.radius = radius
        self.dims = dims
        self.axis = axis
        self.otheraxis = 1 if axis == 0 else 0
        self.maxiter = maxiter
        self.ftol = ftol
        self.xtol = xtol
        self.call = call
        self.num_threads_per_blocks = num_threads_per_blocks

    @_check_tau
    def prox(self, x, tau):
        """Apply the simplex kernel to ``x``

        Parameters
        ----------
        x : array
            Input vector
        tau : :obj:`float`
            Positive scalar weight (not used inside this method)

        Returns
        -------
        y : array
            Flattened output of the kernel, of the same array module as ``x``

        """
        ncp = get_array_module(x)
        # the kernel indexes its input as x[i][j]: a flat vector (no dims)
        # is therefore processed as a single row
        transpose = self.dims is not None and self.axis == 0
        if self.dims is None:
            x = x.reshape(1, -1)
        else:
            x = x.reshape(self.dims)
        if transpose:
            # bring the axis the simplex is computed over to the last position
            x = x.T
        if type(self.coeffs) is not type(x):
            self.coeffs = to_cupy_conditional(x, self.coeffs)

        y = ncp.empty_like(x)
        # one thread per row, rounding the number of blocks up
        num_blocks = (x.shape[0] + self.num_threads_per_blocks - 1) \
            // self.num_threads_per_blocks
        simplex_jit_cuda[num_blocks, self.num_threads_per_blocks](
            x, self.coeffs, self.radius,
            0, 10000000000,  # lower and upper bounds of the box
            self.maxiter, self.ftol, self.xtol, y)
        if transpose:
            y = y.T
        return y.ravel()


def Simplex(n, radius, dims=None, axis=-1, maxiter=100,
ftol=1e-8, xtol=1e-8, call=True, engine='numpy'):
r"""Simplex proximal operator.
Expand All @@ -137,18 +185,18 @@ def Simplex(n, radius, dims=None, axis=-1, maxiter=100,
maxiter : :obj:`int`, optional
Maximum number of iterations used by bisection
ftol : :obj:`float`, optional
Function tolerance in bisection (only with ``engine='numba'``)
Function tolerance in bisection (only with ``engine='numba'`` or ``engine='cuda'``)
xtol : :obj:`float`, optional
Solution absolute tolerance in bisection
call : :obj:`bool`, optional
Evaluate call method (``True``) or not (``False``)
engine : :obj:`str`, optional
Engine used for simplex computation (``numpy`` or ``numba``).
Engine used for simplex computation (``numpy``, ``numba`` or ``cuda``).
Raises
------
KeyError
If ``engine`` is neither ``numpy`` nor ``numba``
If ``engine`` is neither ``numpy`` nor ``numba`` nor ``cuda``
ValueError
If ``dims`` is provided as a list (or tuple) with more or less than
2 elements
Expand All @@ -163,12 +211,15 @@ def Simplex(n, radius, dims=None, axis=-1, maxiter=100,
positive number can be provided.
"""
if not engine in ['numpy', 'numba']:
raise KeyError('engine must be numpy or numba')
if not engine in ['numpy', 'numba', 'cuda']:
raise KeyError('engine must be numpy or numba or cuda')

if engine == 'numba' and jit is not None:
s = _Simplex_numba(n, radius, dims=dims, axis=axis,
maxiter=maxiter, ftol=ftol, xtol=xtol, call=call)
elif engine == 'cuda' and jit is not None:
s = _Simplex_cuda(n, radius, dims=dims, axis=axis,
maxiter=maxiter, ftol=ftol, xtol=xtol, call=call)
else:
if engine == 'numba' and jit is None:
logging.warning(jit_message)
Expand Down
76 changes: 76 additions & 0 deletions pyproximal/proximal/_Simplex_cuda.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
from numba import cuda


@cuda.jit(device=True)
def fun_jit_cuda(mu, x, coeffs, scalar, lower, upper):
    """Bisection function

    Accumulates ``coeffs[k] * min(max(x[k] - mu * coeffs[k], lower), upper)``
    over all entries of ``coeffs`` and returns the sum minus ``scalar``.
    """
    total = 0
    nelem = coeffs.shape[0]
    for k in range(nelem):
        weight = coeffs[k]
        # shift the k-th entry and clip it to the box [lower, upper]
        shifted = x[k] - mu * weight
        clipped = min(max(shifted, lower), upper)
        total += weight * clipped
    return total - scalar


@cuda.jit(device=True)
def bisect_jit_cuda(x, coeffs, scalar, lower, upper, bisect_lower, bisect_upper,
                    maxiter, ftol, xtol):
    """Bisection method (See _Simplex_numba for details).

    Searches for a zero of ``fun_jit_cuda`` in the interval
    ``[bisect_lower, bisect_upper]`` by repeatedly halving it.

    Parameters
    ----------
    x, coeffs, scalar, lower, upper
        Arguments forwarded to ``fun_jit_cuda``
    bisect_lower, bisect_upper
        End points of the initial search interval
    maxiter : :obj:`int`
        Maximum number of iterations
    ftol : :obj:`float`
        Function tolerance (stop when ``abs(f(c)) < ftol``)
    xtol : :obj:`float`
        Solution absolute tolerance (stop when half of the interval
        is smaller than ``xtol``)

    Returns
    -------
    c : :obj:`float`
        Mid-point of the last interval considered

    """
    a, b = bisect_lower, bisect_upper
    fa = fun_jit_cuda(a, x, coeffs, scalar, lower, upper)
    # define the mid-point up-front so that a value is returned also
    # when no iteration is performed (maxiter=0)
    c = (a + b) / 2.
    for iiter in range(maxiter):
        c = (a + b) / 2.
        if (b - a) / 2 < xtol:
            return c
        fc = fun_jit_cuda(c, x, coeffs, scalar, lower, upper)
        if abs(fc) < ftol:
            return c
        # same-sign test written as a product: unlike f / abs(f) it never
        # divides by zero when an end point is exactly a root, and in that
        # case it still falls through to shrinking the interval from above
        if fc * fa > 0:
            a = c
            fa = fc
        else:
            b = c
    return c


@cuda.jit
def simplex_jit_cuda(x, coeffs, scalar, lower, upper, maxiter, ftol, xtol, y):
    """Simplex proximal

    CUDA kernel in which each thread handles one row of ``x``: it brackets
    a zero of ``fun_jit_cuda``, refines it with ``bisect_jit_cuda`` and
    writes the shifted and clipped row into ``y``.

    Parameters
    ----------
    x : :obj:`np.ndarray`
        Input array (indexed as ``x[i][j]``)
    coeffs : :obj:`np.ndarray`
        Vector of coefficients used in the definition of the hyperplane
    scalar : :obj:`float`
        Scalar used in the definition of the hyperplane
    lower : :obj:`float`
        Lower bound of Box
    upper : :obj:`float`
        Upper bound of Box
    maxiter : :obj:`int`
        Maximum number of iterations
    ftol : :obj:`float`
        Function tolerance
    xtol : :obj:`float`
        Solution absolute tolerance
    y : :obj:`np.ndarray`
        Output array (same indexing as ``x``), filled in place

    """
    # absolute index of this thread in the one-dimensional grid
    irow = cuda.grid(1)

    # threads beyond the last row have nothing to do
    if irow >= x.shape[0]:
        return

    row = x[irow]

    # grow the bracket downwards until the function is non-negative
    mu_lo = -1
    while fun_jit_cuda(mu_lo, row, coeffs, scalar, lower, upper) < 0:
        mu_lo *= 2

    # grow the bracket upwards until the function is non-positive
    mu_hi = 1
    while fun_jit_cuda(mu_hi, row, coeffs, scalar, lower, upper) > 0:
        mu_hi *= 2

    mu = bisect_jit_cuda(row, coeffs, scalar, lower, upper,
                         mu_lo, mu_hi, maxiter, ftol, xtol)

    # shift by the bisection result and clip to the box
    ncols = coeffs.shape[0]
    for icol in range(ncols):
        shifted = x[irow][icol] - mu * coeffs[icol]
        y[irow][icol] = min(max(shifted, lower), upper)

0 comments on commit 488dcbd

Please sign in to comment.