diff --git a/example.py b/example.py index 75dfceb..cc0f7c8 100644 --- a/example.py +++ b/example.py @@ -1,32 +1,26 @@ -from redisai import Client, Tensor, \ - BlobTensor, DType, Device, Backend +import numpy as np +from redisai import Client, DType, Device, Backend import ml2rt client = Client() -client.tensorset('x', Tensor(DType.float, [2], [2, 3])) +client.tensorset('x', [2, 3], dtype=DType.float) t = client.tensorget('x') print(t.value) model = ml2rt.load_model('test/testdata/graph.pb') -client.tensorset('a', Tensor.scalar(DType.float, 2, 3)) -client.tensorset('b', Tensor.scalar(DType.float, 12, 10)) +tensor1 = np.array([2, 3], dtype=np.float) +client.tensorset('a', tensor1) +client.tensorset('b', (12, 10), dtype=np.float) client.modelset('m', Backend.tf, Device.cpu, - input=['a', 'b'], - output='mul', + inputs=['a', 'b'], + outputs='mul', data=model) client.modelrun('m', ['a', 'b'], ['mul']) -print(client.tensorget('mul').value) +print(client.tensorget('mul')) # Try with a script script = ml2rt.load_script('test/testdata/script.txt') client.scriptset('ket', Device.cpu, script) -client.scriptrun('ket', 'bar', input=['a', 'b'], output='c') +client.scriptrun('ket', 'bar', inputs=['a', 'b'], outputs='c') -b1 = client.tensorget('c', as_type=BlobTensor) -b2 = client.tensorget('c', as_type=BlobTensor) - -client.tensorset('d', BlobTensor(DType.float, b1.shape, b1, b2)) - -tnp = b1.to_numpy() -client.tensorset('e', tnp) diff --git a/redisai/__init__.py b/redisai/__init__.py index 03b6d95..cf7a51b 100644 --- a/redisai/__init__.py +++ b/redisai/__init__.py @@ -1,4 +1,3 @@ from .version import __version__ from .client import Client -from .tensor import Tensor, BlobTensor from .constants import DType, Device, Backend diff --git a/redisai/client.py b/redisai/client.py index f5e80a3..3392e41 100644 --- a/redisai/client.py +++ b/redisai/client.py @@ -1,5 +1,6 @@ from redis import StrictRedis -from typing import Union, Any, AnyStr, ByteString, Sequence, Type +from typing import Union, Any, AnyStr, ByteString, Sequence +from .containers import Script, Model, Tensor try: import numpy as np @@ -8,7 +9,7 @@ from .constants import Backend, Device, DType from .utils import str_or_strsequence, to_string -from .tensor import Tensor, BlobTensor +from . import convert class Client(StrictRedis): @@ -45,13 +46,12 @@ def modelset(self, args += [data] return self.execute_command(*args) - def modelget(self, name: AnyStr) -> dict: + def modelget(self, name: AnyStr) -> Model: rv = self.execute_command('AI.MODELGET', name) - return { - 'backend': Backend(to_string(rv[0])), - 'device': Device(to_string(rv[1])), - 'data': rv[2] - } + return Model( + rv[2], + Device(to_string(rv[1])), + Backend(to_string(rv[0]))) def modeldel(self, name: AnyStr) -> AnyStr: return self.execute_command('AI.MODELDEL', name) @@ -68,71 +68,66 @@ def modelrun(self, def tensorset(self, key: AnyStr, - tensor: Union[Tensor, np.ndarray, list, tuple], + tensor: Union[np.ndarray, list, tuple], shape: Union[Sequence[int], None] = None, - dtype: Union[DType, None] = None) -> Any: + dtype: Union[DType, type, None] = None) -> Any: """ Set the values of the tensor on the server using the provided Tensor object :param key: The name of the tensor - :param tensor: a `Tensor` object - :param shape: Shape of the tensor - :param dtype: data type of the tensor. Required if input is a sequence of ints/floats + :param tensor: a `np.ndarray` object or python list or tuple + :param shape: Shape of the tensor. Required if `tensor` is list or tuple + :param dtype: data type of the tensor. Required if `tensor` is list or tuple """ - # TODO: tensorset will not accept BlobTensor or Tensor object in the future. - # Keeping it in the current version for compatibility with the example repo if np and isinstance(tensor, np.ndarray): - tensor = BlobTensor.from_numpy(tensor) + tensor = convert.from_numpy(tensor) + args = ['AI.TENSORSET', key, tensor.dtype.value, *tensor.shape, tensor.argname, tensor.value] elif isinstance(tensor, (list, tuple)): if shape is None: shape = (len(tensor),) - tensor = Tensor(dtype, shape, tensor) - args = ['AI.TENSORSET', key, tensor.type.value] - args += tensor.shape - args += [tensor.ARGNAME] - args += tensor.value + if not isinstance(dtype, DType): + dtype = DType.__members__[np.dtype(dtype).name] + tensor = convert.from_sequence(tensor, shape, dtype) + args = ['AI.TENSORSET', key, tensor.dtype.value, *tensor.shape, tensor.argname, *tensor.value] return self.execute_command(*args) def tensorget(self, - key: AnyStr, as_type: Type[Tensor] = None, - meta_only: bool = False) -> Union[Tensor, BlobTensor]: + key: AnyStr, as_numpy: bool = True, + meta_only: bool = False) -> Union[Tensor, np.ndarray]: """ Retrieve the value of a tensor from the server. By default it returns the numpy array but it can be controlled using `as_type` argument and `meta_only` argument. :param key: the name of the tensor - :param as_type: the resultant tensor type. Returns numpy array if None + :param as_numpy: Should it return data as numpy.ndarray. + Wraps with namedtuple if False. This flag also decides how to fetch the + value from RedisAI server and could have performance implications :param meta_only: if true, then the value is not retrieved, only the shape and the type :return: an instance of as_type """ - # TODO; We might remove Tensor & BlobTensor in the future and `tensorget` will return - # python list or numpy arrays or a namedtuple if meta_only: argname = 'META' - elif as_type is None: - argname = BlobTensor.ARGNAME + elif as_numpy is True: + argname = 'BLOB' else: - argname = as_type.ARGNAME + argname = 'VALUES' res = self.execute_command('AI.TENSORGET', key, argname) dtype, shape = to_string(res[0]), res[1] - dt = DType.__members__[dtype.lower()] if meta_only: - return Tensor(dt, shape, []) - elif as_type is None: - return BlobTensor.from_resp(dt, shape, res[2]).to_numpy() + return convert.to_sequence([], shape, dtype) + if as_numpy is True: + return convert.to_numpy(res[2], shape, dtype) else: - return as_type.from_resp(dt, shape, res[2]) + return convert.to_sequence(res[2], shape, dtype) def scriptset(self, name: AnyStr, device: Device, script: AnyStr) -> AnyStr: return self.execute_command('AI.SCRIPTSET', name, device.value, script) - def scriptget(self, name: AnyStr) -> dict: + def scriptget(self, name: AnyStr) -> Script: r = self.execute_command('AI.SCRIPTGET', name) - device = Device(to_string(r[0])) - return { - 'device': device, - 'script': to_string(r[1]) - } + return Script( + to_string(r[1]), + Device(to_string(r[0]))) def scriptdel(self, name): return self.execute_command('AI.SCRIPTDEL', name) diff --git a/redisai/constants.py b/redisai/constants.py index 82f9f81..a271136 100644 --- a/redisai/constants.py +++ b/redisai/constants.py @@ -10,6 +10,7 @@ class Backend(Enum): tf = 'TF' torch = 'TORCH' onnx = 'ONNX' + tflite = 'TFLITE' class DType(Enum): diff --git a/redisai/containers.py b/redisai/containers.py new file mode 100644 index 0000000..a5bba64 --- /dev/null +++ b/redisai/containers.py @@ -0,0 +1,5 @@ +from collections import namedtuple + +Tensor = namedtuple('Tensor', field_names=['value', 'shape', 'dtype', 'argname']) +Script = namedtuple('Script', field_names=['script', 'device']) +Model = namedtuple('Model', field_names=['data', 'device', 'backend']) diff --git a/redisai/convert.py b/redisai/convert.py new file mode 100644 index 0000000..5e41cc5 --- /dev/null +++ b/redisai/convert.py @@ -0,0 +1,43 @@ +from typing import Union, ByteString, Sequence +from .utils import convert_to_num +from .constants import DType +from .containers import Tensor +try: + import numpy as np +except (ImportError, ModuleNotFoundError): + np = None + + +def from_numpy(tensor: np.ndarray) -> Tensor: + """ Convert the numpy input from user to `Tensor` """ + dtype = DType.__members__[str(tensor.dtype)] + shape = tensor.shape + blob = bytes(tensor.data) + return Tensor(blob, shape, dtype, 'BLOB') + + +def from_sequence(tensor: Sequence, shape: Union[list, tuple], dtype: DType) -> Tensor: + """ Convert the `list`/`tuple` input from user to `Tensor` """ + return Tensor(tensor, shape, dtype, 'VALUES') + + +def to_numpy(value: ByteString, shape: Union[list, tuple], dtype: DType) -> np.ndarray: + """ Convert `BLOB` result from RedisAI to `np.ndarray` """ + dtype = DType.__members__[dtype.lower()].value + mm = { + 'FLOAT': 'float32', + 'DOUBLE': 'float64' + } + if dtype in mm: + dtype = mm[dtype] + else: + dtype = dtype.lower() + a = np.frombuffer(value, dtype=dtype) + return a.reshape(shape) + + +def to_sequence(value: list, shape: list, dtype: DType) -> Tensor: + """ Convert `VALUES` result from RedisAI to `Tensor` """ + dtype = DType.__members__[dtype.lower()] + convert_to_num(dtype, value) + return Tensor(value, tuple(shape), dtype, 'VALUES') diff --git a/redisai/tensor.py b/redisai/tensor.py deleted file mode 100644 index 7927855..0000000 --- a/redisai/tensor.py +++ /dev/null @@ -1,120 +0,0 @@ -from typing import Union, ByteString, Sequence -import warnings -from .utils import convert_to_num -from .constants import DType -try: - import numpy as np -except (ImportError, ModuleNotFoundError): - np = None - - -class Tensor(object): - ARGNAME = 'VALUES' - - def __init__(self, - dtype: DType, - shape: Sequence[int], - value): - warnings.warn("Tensor APIs are depricated and " - "will be removed from the future release.", UserWarning) - """ - Declare a tensor suitable for passing to tensorset - :param dtype: The type the values should be stored as. - This can be one of Tensor.FLOAT, tensor.DOUBLE, etc. - :param shape: An array describing the shape of the tensor. For an - image 250x250 with three channels, this would be [250, 250, 3] - :param value: The value for the tensor. Can be an array. - The contents must coordinate with the shape, meaning that the - overall length needs to be the product of all figures in the - shape. There is no verification to ensure that each dimension - is correct. Your application must ensure that the ordering - is always consistent. - """ - self.type = dtype - self.shape = list(shape) - self.value = value - if not isinstance(value, (list, tuple)): - self.value = [value] - - def __repr__(self): - return '<{c.__class__.__name__}(shape={s} type={t}) at 0x{id:x}>'.format( - c=self, - s=self.shape, - t=self.type, - id=id(self)) - - @classmethod - def from_resp(cls, dtype: DType, shape: Sequence[int], value) -> 'Tensor': - convert_to_num(dtype, value) - return cls(dtype, shape, value) - - @classmethod - def scalar(cls, dtype: DType, *items) -> 'Tensor': - """ - Create a tensor from a list of numbers - :param dtype: Type to use for storage - :param items: One or more items - :return: Tensor - """ - return cls(dtype, [len(items)], items) - - -class BlobTensor(Tensor): - ARGNAME = 'BLOB' - - def __init__(self, - dtype: DType, - shape: Sequence[int], - *blobs: Union['BlobTensor', ByteString] - ): - """ - Create a tensor from a binary blob - :param dtype: The datatype, one of Tensor.FLOAT, Tensor.DOUBLE, etc. - :param shape: An array - :param blobs: One or more blobs to assign to the tensor. - """ - if len(blobs) > 1: - blobarr = bytearray() - for b in blobs: - if isinstance(b, BlobTensor): - b = b.value[0] - blobarr += b - size = len(blobs) - ret_blobs = bytes(blobarr) - shape = [size] + list(shape) - else: - ret_blobs = bytes(blobs[0]) - - super(BlobTensor, self).__init__(dtype, shape, ret_blobs) - - @classmethod - def from_numpy(cls, *nparrs) -> 'BlobTensor': - blobs = [] - for arr in nparrs: - blobs.append(arr.data) - dt = DType.__members__[str(nparrs[0].dtype)] - return cls(dt, nparrs[0].shape, *blobs) - - @property - def blob(self): - return self.value[0] - - def to_numpy(self) -> np.array: - a = np.frombuffer(self.value[0], dtype=self._to_numpy_type(self.type)) - return a.reshape(self.shape) - - @staticmethod - def _to_numpy_type(t): - if isinstance(t, DType): - t = t.value - mm = { - 'FLOAT': 'float32', - 'DOUBLE': 'float64' - } - if t in mm: - return mm[t] - return t.lower() - - @classmethod - def from_resp(cls, dtype, shape, value) -> 'BlobTensor': - return cls(dtype, shape, value) diff --git a/redisai/version.py b/redisai/version.py index fa88a30..84d9a97 100644 --- a/redisai/version.py +++ b/redisai/version.py @@ -2,4 +2,4 @@ # 1) we don't load dependencies by storing it in __init__.py # 2) we can import it in setup.py for the same reason # 3) we can import it into your module module -__version__ = '0.4.1' +__version__ = '0.5.0' diff --git a/setup.py b/setup.py index 3235d7c..df0b604 100644 --- a/setup.py +++ b/setup.py @@ -1,4 +1,3 @@ - #!/usr/bin/env python from setuptools import setup, find_packages diff --git a/test/test.py b/test/test.py index 33c1c06..a24c412 100644 --- a/test/test.py +++ b/test/test.py @@ -1,7 +1,7 @@ from unittest import TestCase import numpy as np import os.path -from redisai import Client, DType, Backend, Device, Tensor, BlobTensor +from redisai import Client, DType, Backend, Device from ml2rt import load_model from redis.exceptions import ResponseError @@ -9,14 +9,6 @@ MODEL_DIR = os.path.dirname(os.path.abspath(__file__)) + '/testdata' -class TensorTestCase(TestCase): - def testTensorShapes(self): - t = Tensor(DType.float, [4], [1, 2, 3, 4]) - self.assertEqual([4], t.shape) - t = BlobTensor.from_numpy(np.array([[1, 2, 3], [4, 5, 6]])) - self.assertEqual([2, 3], t.shape) - - class ClientTestCase(TestCase): def setUp(self): super(ClientTestCase, self).setUp() @@ -25,35 +17,49 @@ def setUp(self): def get_client(self): return Client() - def test_set_tensor(self): + def test_set_non_numpy_tensor(self): con = self.get_client() - con.tensorset('x', (2, 3), dtype=DType.float) - values = con.tensorget('x', as_type=Tensor) - self.assertEqual([2, 3], values.value) - - con.tensorset('x', Tensor.scalar(DType.int32, 2, 3)) - values = con.tensorget('x', as_type=Tensor).value - self.assertEqual([2, 3], values) - meta = con.tensorget('x', meta_only=True) - self.assertTrue('