99
1010from .constants import Backend , Device , DType
1111from .utils import str_or_strsequence , to_string
12- from . import tensorize
12+ from . import convert
1313
1414
1515class Client (StrictRedis ):
@@ -70,7 +70,7 @@ def tensorset(self,
7070 key : AnyStr ,
7171 tensor : Union [np .ndarray , list , tuple ],
7272 shape : Union [Sequence [int ], None ] = None ,
73- dtype : Union [DType , None ] = None ) -> Any :
73+ dtype : Union [DType , type , None ] = None ) -> Any :
7474 """
7575 Set the values of the tensor on the server using the provided Tensor object
7676 :param key: The name of the tensor
@@ -79,12 +79,14 @@ def tensorset(self,
7979 :param dtype: data type of the tensor. Required if `tensor` is list or tuple
8080 """
8181 if np and isinstance (tensor , np .ndarray ):
82- tensor = tensorize .from_numpy (tensor )
82+ tensor = convert .from_numpy (tensor )
8383 args = ['AI.TENSORSET' , key , tensor .dtype .value , * tensor .shape , tensor .argname , tensor .value ]
8484 elif isinstance (tensor , (list , tuple )):
8585 if shape is None :
8686 shape = (len (tensor ),)
87- tensor = tensorize .from_sequence (tensor , shape , dtype )
87+ if not isinstance (dtype , DType ):
88+ dtype = DType .__members__ [np .dtype (dtype ).name ]
89+ tensor = convert .from_sequence (tensor , shape , dtype )
8890 args = ['AI.TENSORSET' , key , tensor .dtype .value , * tensor .shape , tensor .argname , * tensor .value ]
8991 return self .execute_command (* args )
9092
@@ -112,11 +114,11 @@ def tensorget(self,
112114 res = self .execute_command ('AI.TENSORGET' , key , argname )
113115 dtype , shape = to_string (res [0 ]), res [1 ]
114116 if meta_only :
115- return tensorize .to_sequence ([], shape , dtype )
117+ return convert .to_sequence ([], shape , dtype )
116118 if as_numpy is True :
117- return tensorize .to_numpy (res [2 ], shape , dtype )
119+ return convert .to_numpy (res [2 ], shape , dtype )
118120 else :
119- return tensorize .to_sequence (res [2 ], shape , dtype )
121+ return convert .to_sequence (res [2 ], shape , dtype )
120122
121123 def scriptset (self , name : AnyStr , device : Device , script : AnyStr ) -> AnyStr :
122124 return self .execute_command ('AI.SCRIPTSET' , name , device .value , script )
0 commit comments