@@ -122,10 +122,25 @@ RAI_Tensor *RAI_TensorCreateWithDLDataType(DLDataType dtype, long long *dims, in
122122 strides [i ] *= strides [i + 1 ] * shape [i + 1 ];
123123 }
124124
125- DLContext ctx = (DLContext ){.device_type = kDLCPU , .device_id = 0 };
126125 void * data = RedisModule_Calloc (len , dtypeSize );
127-
128- ret -> tensor = (DLManagedTensor ){.dl_tensor = (DLTensor ){.ctx = ctx ,
126+ DLDevice device = (DLDevice ){.device_type = kDLCPU , .device_id = 0 };
127+ void * data = NULL ;
128+ switch (tensorAllocMode ) {
129+ case TENSORALLOC_ALLOC :
130+ data = RedisModule_Alloc (len * dtypeSize );
131+ break ;
132+ case TENSORALLOC_CALLOC :
133+ data = RedisModule_Calloc (len , dtypeSize );
134+ break ;
135+ case TENSORALLOC_NONE :
136+ /* shallow copy no alloc */
137+ default :
138+ /* assume TENSORALLOC_NONE
139+ shallow copy no alloc */
140+ break ;
141+ }
142+
143+ ret -> tensor = (DLManagedTensor ){.dl_tensor = (DLTensor ){.device = device ,
129144 .data = data ,
130145 .ndim = ndims ,
131146 .dtype = dtype ,
@@ -170,7 +185,7 @@ RAI_Tensor *_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, size_t dtype
170185 strides [i ] *= strides [i + 1 ] * shape [i + 1 ];
171186 }
172187
173- DLContext ctx = (DLContext ){.device_type = kDLCPU , .device_id = 0 };
188+ DLDevice device = (DLDevice ){.device_type = kDLCPU , .device_id = 0 };
174189 size_t nbytes = len * dtypeSize ;
175190
176191 size_t blob_len ;
@@ -186,7 +201,7 @@ RAI_Tensor *_TensorCreateWithDLDataTypeAndRString(DLDataType dtype, size_t dtype
186201 RAI_HoldString (NULL , rstr );
187202
188203 RAI_Tensor * ret = RAI_TensorNew ();
189- ret -> tensor = (DLManagedTensor ){.dl_tensor = (DLTensor ){.ctx = ctx ,
204+ ret -> tensor = (DLManagedTensor ){.dl_tensor = (DLTensor ){.device = device ,
190205 .data = data ,
191206 .ndim = ndims ,
192207 .dtype = dtype ,
@@ -327,7 +342,7 @@ RAI_Tensor *RAI_TensorCreateFromDLTensor(DLManagedTensor *dl_tensor) {
327342 RAI_Tensor * ret = RAI_TensorNew ();
328343
329344 ret -> tensor =
330- (DLManagedTensor ){.dl_tensor = (DLTensor ){.ctx = dl_tensor -> dl_tensor .ctx ,
345+ (DLManagedTensor ){.dl_tensor = (DLTensor ){.device = dl_tensor -> dl_tensor .device ,
331346 .data = dl_tensor -> dl_tensor .data ,
332347 .ndim = dl_tensor -> dl_tensor .ndim ,
333348 .dtype = dl_tensor -> dl_tensor .dtype ,
0 commit comments