Skip to content

Commit 821244b

Browse files
author
hhsecond
committed
supporting np dtypes
1 parent 8a9b28a commit 821244b

File tree

4 files changed

+20
-10
lines changed

4 files changed

+20
-10
lines changed

example.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import numpy as np
12
from redisai import Client, DType, Device, Backend
23
import ml2rt
34

@@ -7,15 +8,16 @@
78
print(t.value)
89

910
model = ml2rt.load_model('test/testdata/graph.pb')
10-
client.tensorset('a', (2, 3), dtype=DType.float)
11-
client.tensorset('b', (12, 10), dtype=DType.float)
11+
tensor1 = np.array([2, 3], dtype=np.float)
12+
client.tensorset('a', tensor1)
13+
client.tensorset('b', (12, 10), dtype=np.float)
1214
client.modelset('m', Backend.tf,
1315
Device.cpu,
1416
inputs=['a', 'b'],
1517
outputs='mul',
1618
data=model)
1719
client.modelrun('m', ['a', 'b'], ['mul'])
18-
print(client.tensorget('mul').value)
20+
print(client.tensorget('mul'))
1921

2022
# Try with a script
2123
script = ml2rt.load_script('test/testdata/script.txt')

redisai/client.py

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99

1010
from .constants import Backend, Device, DType
1111
from .utils import str_or_strsequence, to_string
12-
from . import tensorize
12+
from . import convert
1313

1414

1515
class Client(StrictRedis):
@@ -70,7 +70,7 @@ def tensorset(self,
7070
key: AnyStr,
7171
tensor: Union[np.ndarray, list, tuple],
7272
shape: Union[Sequence[int], None] = None,
73-
dtype: Union[DType, None] = None) -> Any:
73+
dtype: Union[DType, type, None] = None) -> Any:
7474
"""
7575
Set the values of the tensor on the server using the provided Tensor object
7676
:param key: The name of the tensor
@@ -79,12 +79,14 @@ def tensorset(self,
7979
:param dtype: data type of the tensor. Required if `tensor` is list or tuple
8080
"""
8181
if np and isinstance(tensor, np.ndarray):
82-
tensor = tensorize.from_numpy(tensor)
82+
tensor = convert.from_numpy(tensor)
8383
args = ['AI.TENSORSET', key, tensor.dtype.value, *tensor.shape, tensor.argname, tensor.value]
8484
elif isinstance(tensor, (list, tuple)):
8585
if shape is None:
8686
shape = (len(tensor),)
87-
tensor = tensorize.from_sequence(tensor, shape, dtype)
87+
if not isinstance(dtype, DType):
88+
dtype = DType.__members__[np.dtype(dtype).name]
89+
tensor = convert.from_sequence(tensor, shape, dtype)
8890
args = ['AI.TENSORSET', key, tensor.dtype.value, *tensor.shape, tensor.argname, *tensor.value]
8991
return self.execute_command(*args)
9092

@@ -112,11 +114,11 @@ def tensorget(self,
112114
res = self.execute_command('AI.TENSORGET', key, argname)
113115
dtype, shape = to_string(res[0]), res[1]
114116
if meta_only:
115-
return tensorize.to_sequence([], shape, dtype)
117+
return convert.to_sequence([], shape, dtype)
116118
if as_numpy is True:
117-
return tensorize.to_numpy(res[2], shape, dtype)
119+
return convert.to_numpy(res[2], shape, dtype)
118120
else:
119-
return tensorize.to_sequence(res[2], shape, dtype)
121+
return convert.to_sequence(res[2], shape, dtype)
120122

121123
def scriptset(self, name: AnyStr, device: Device, script: AnyStr) -> AnyStr:
122124
return self.execute_command('AI.SCRIPTSET', name, device.value, script)
File renamed without changes.

test/test.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,12 @@ def test_set_non_numpy_tensor(self):
2929
self.assertEqual([2, 3, 4, 5], result.value)
3030
self.assertEqual((2, 2), result.shape)
3131

32+
con.tensorset('x', (2, 3, 4, 5), dtype=np.int16, shape=(2, 2))
33+
result = con.tensorget('x', as_numpy=False)
34+
self.assertEqual(DType.int16, result.dtype)
35+
self.assertEqual([2, 3, 4, 5], result.value)
36+
self.assertEqual((2, 2), result.shape)
37+
3238
with self.assertRaises(AttributeError):
3339
con.tensorset('x', (2, 3, 4), dtype=DType.int)
3440

0 commit comments

Comments
 (0)