33#include "background_workers.h"
44#include "util/string_utils.h"
55
6- void _DAG_SetTensorsInLocalContext (RedisAI_RunInfo * rinfo ) {
7- for (size_t i = 0 ; i < rinfo -> dagOpCount ; i ++ ) {
8- RAI_DagOp * op = rinfo -> dagOps [i ];
9- if (op -> commandType == REDISAI_DAG_CMD_TENSORSET ) {
10- // Insert the tensor with its mangled (unique) name.
11- void * t = (void * )RAI_TensorGetShallowCopy (op -> outTensor );
12- AI_dictReplace (rinfo -> dagTensorsContext , (void * )op -> outkeys [0 ], t );
13- }
14- }
15- }
16-
17- int MangleTensorsNames (RedisAI_RunInfo * rinfo ) {
18-
19- int res = REDISMODULE_ERR ;
20- AI_dict * mangled_tensors = AI_dictCreate (& AI_dictTypeHeapRStrings , NULL );
6+ int ValidatePersistKeys (RedisAI_RunInfo * rinfo , AI_dict * tensorsNamesToInd ,
7+ AI_dict * persistTensorsNames ) {
218
229 {
23- AI_dictIterator * iter = AI_dictGetSafeIterator (rinfo -> dagTensorsContext );
24- AI_dictEntry * entry = AI_dictNext ( iter ) ;
25- while (entry ) {
26- RedisModuleString * key = (RedisModuleString * )AI_dictGetKey (entry );
27- size_t key_len ;
28- const char * key_str = RedisModule_StringPtrLen ( key , & key_len );
29- RedisModuleString * demangled_key = RedisModule_CreateString ( NULL , key_str , key_len - 4 );
30- int * instance = RedisModule_Alloc ( sizeof ( int ) );
31- * instance = 1 ;
32- AI_dictAdd ( mangled_tensors , ( void * ) demangled_key , ( void * ) instance );
33- RedisModule_FreeString ( NULL , demangled_key );
34- entry = AI_dictNext ( iter );
10+ AI_dictIterator * iter = AI_dictGetSafeIterator (persistTensorsNames );
11+ AI_dictEntry * persist_entry ;
12+ while (( persist_entry = AI_dictNext ( iter )) ) {
13+ RedisModuleString * persist_key = (RedisModuleString * )AI_dictGetKey (persist_entry );
14+ AI_dictEntry * entry = AI_dictFind ( tensorsNamesToInd , persist_key ) ;
15+ if (! entry ) {
16+ RAI_SetError ( rinfo -> err , RAI_EDAGRUN , "ERR PERSIST key cannot be found in DAG" );
17+ AI_dictReleaseIterator ( iter );
18+ return REDISMODULE_ERR ;
19+ }
20+ size_t index = ( size_t ) AI_dictGetVal ( entry );
21+ AI_dictReplace ( persistTensorsNames , ( void * ) persist_key , ( void * ) index );
3522 }
3623 AI_dictReleaseIterator (iter );
3724 }
25+ return REDISMODULE_OK ;
26+ }
27+
28+ int MapTensorsKeysToIndices (RedisAI_RunInfo * rinfo , AI_dict * tensorsNamesToInd ) {
3829
3930 for (long long i = 0 ; i < array_len (rinfo -> dagOps ); i ++ ) {
4031 RAI_DagOp * currentOp = rinfo -> dagOps [i ];
4132
42- RedisModuleString * * mangled_inkeys =
43- array_new (RedisModuleString * , array_len (currentOp -> inkeys ));
4433 for (long long j = 0 ; j < array_len (currentOp -> inkeys ); j ++ ) {
4534 RedisModuleString * key = currentOp -> inkeys [j ];
46- AI_dictEntry * entry = AI_dictFind (mangled_tensors , key );
35+ AI_dictEntry * entry = AI_dictFind (tensorsNamesToInd , key );
4736 if (!entry ) {
48- array_free (mangled_inkeys );
4937 RAI_SetError (rinfo -> err , RAI_EDAGRUN , "ERR INPUT key cannot be found in DAG" );
50- goto cleanup ;
38+ return REDISMODULE_ERR ;
5139 }
52- int * instance = AI_dictGetVal (entry );
53- char buf [16 ];
54- sprintf (buf , "%04d" , * instance );
55- RedisModuleString * mangled_key = RedisModule_CreateStringFromString (NULL , key );
56- RedisModule_StringAppendBuffer (NULL , mangled_key , buf , strlen (buf ));
57- mangled_inkeys = array_append (mangled_inkeys , mangled_key );
40+ size_t ind = (size_t )AI_dictGetVal (entry );
41+ currentOp -> inkeys_indices = array_append (currentOp -> inkeys_indices , ind );
5842 }
5943
60- RedisModuleString * * mangled_outkeys =
61- array_new (RedisModuleString * , array_len (currentOp -> outkeys ));
6244 for (long long j = 0 ; j < array_len (currentOp -> outkeys ); j ++ ) {
6345 RedisModuleString * key = currentOp -> outkeys [j ];
64- AI_dictEntry * entry = AI_dictFind (mangled_tensors , key );
65- int * instance = NULL ;
66- if (entry ) {
67- instance = AI_dictGetVal (entry );
68- * instance += 1 ;
69- } else {
70- instance = RedisModule_Alloc (sizeof (int ));
71- * instance = 1 ;
72- AI_dictAdd (mangled_tensors , (void * )key , (void * )instance );
73- }
74- char buf [16 ];
75- sprintf (buf , "%04d" , * instance );
76- RedisModuleString * mangled_key = RedisModule_CreateStringFromString (NULL , key );
77- RedisModule_StringAppendBuffer (NULL , mangled_key , buf , strlen (buf ));
78- mangled_outkeys = array_append (mangled_outkeys , mangled_key );
79- }
80-
81- if (currentOp -> inkeys ) {
82- for (size_t j = 0 ; j < array_len (currentOp -> inkeys ); j ++ ) {
83- RedisModule_FreeString (NULL , currentOp -> inkeys [j ]);
84- }
85- array_free (currentOp -> inkeys );
86- }
87-
88- if (currentOp -> outkeys ) {
89- for (size_t j = 0 ; j < array_len (currentOp -> outkeys ); j ++ ) {
90- RedisModule_FreeString (NULL , currentOp -> outkeys [j ]);
91- }
92- array_free (currentOp -> outkeys );
93- }
94-
95- currentOp -> inkeys = mangled_inkeys ;
96- currentOp -> outkeys = mangled_outkeys ;
97- }
46+ size_t ind = array_len (rinfo -> dagSharedTensors );
9847
99- AI_dict * mangled_persisted = AI_dictCreate (& AI_dictTypeHeapRStrings , NULL );
100- {
101- AI_dictIterator * iter = AI_dictGetSafeIterator (rinfo -> dagTensorsPersistedContext );
102- AI_dictEntry * entry = AI_dictNext (iter );
103- while (entry ) {
104- RedisModuleString * key = (RedisModuleString * )AI_dictGetKey (entry );
105- AI_dictEntry * mangled_entry = AI_dictFind (mangled_tensors , key );
106- if (!mangled_entry ) {
107- AI_dictRelease (mangled_persisted );
108- AI_dictReleaseIterator (iter );
109- RAI_SetError (rinfo -> err , RAI_EDAGRUN , "ERR PERSIST key cannot be found in DAG" );
110- goto cleanup ;
111- }
112- if (AI_dictFind (mangled_persisted , key ) != NULL ) {
113- AI_dictRelease (mangled_persisted );
114- AI_dictReleaseIterator (iter );
115- RAI_SetError (rinfo -> err , RAI_EDAGRUN , "ERR PERSIST keys must be unique" );
116- goto cleanup ;
48+ // Add a new empty place holder in the array for an output tensor.
49+ // If this is a TENSORSET op, the tensor is already realized.
50+ if (currentOp -> commandType == REDISAI_DAG_CMD_TENSORSET ) {
51+ RAI_Tensor * t = RAI_TensorGetShallowCopy (currentOp -> outTensor );
52+ rinfo -> dagSharedTensors = array_append (rinfo -> dagSharedTensors , t );
53+ } else {
54+ rinfo -> dagSharedTensors = array_append (rinfo -> dagSharedTensors , NULL );
11755 }
118- int * instance = AI_dictGetVal (mangled_entry );
119- char buf [16 ];
120- sprintf (buf , "%04d" , * instance );
121- RedisModuleString * mangled_key = RedisModule_CreateStringFromString (NULL , key );
122- RedisModule_StringAppendBuffer (NULL , mangled_key , buf , strlen (buf ));
123- AI_dictAdd (mangled_persisted , (void * )mangled_key , (void * )1 );
124- RedisModule_FreeString (NULL , mangled_key );
125- entry = AI_dictNext (iter );
56+ currentOp -> outkeys_indices = array_append (currentOp -> outkeys_indices , ind );
57+ AI_dictReplace (tensorsNamesToInd , (void * )key , (void * )ind );
12658 }
127- AI_dictReleaseIterator (iter );
12859 }
129-
130- AI_dictRelease (rinfo -> dagTensorsPersistedContext );
131- rinfo -> dagTensorsPersistedContext = mangled_persisted ;
132-
133- for (long long i = 0 ; i < array_len (rinfo -> dagOps ); i ++ ) {
134- if (rinfo -> dagOps [i ]-> devicestr == NULL ) {
135- rinfo -> dagOps [i ]-> devicestr = "CPU" ;
136- }
137- }
138- // Tensors from TENSORSET ops are ready to be put in DAG local context under their mangled
139- // names.
140- _DAG_SetTensorsInLocalContext (rinfo );
141- res = REDISMODULE_OK ;
142-
143- cleanup : {
144- AI_dictIterator * iter = AI_dictGetSafeIterator (mangled_tensors );
145- AI_dictEntry * entry = AI_dictNext (iter );
146- while (entry ) {
147- int * val = (int * )AI_dictGetVal (entry );
148- RedisModule_Free (val );
149- entry = AI_dictNext (iter );
150- }
151- AI_dictReleaseIterator (iter );
152- }
153- AI_dictRelease (mangled_tensors );
154- return res ;
60+ return REDISMODULE_OK ;
15561}
15662
15763// Add Shallow copies of the DAG run info to the devices' queues.
@@ -242,7 +148,7 @@ int RAI_DAGRun(RAI_DAGRunCtx *run_info, RAI_OnFinishCB DAGAsyncFinish, void *pri
242148 }
243149 // Make the inkeys and outkeys of the DAG ops unique, to ensure that the operations
244150 // will be execute in the right order.
245- if (MangleTensorsNames (rinfo ) != REDISMODULE_OK ) {
151+ if (MapTensorsKeysToIndices (rinfo , rinfo -> tensorsNamesToIndices ) != REDISMODULE_OK ) {
246152 RAI_SetError (err , rinfo -> err -> code , rinfo -> err -> detail );
247153 return REDISMODULE_ERR ;
248154 }
@@ -269,16 +175,13 @@ size_t RAI_DAGNumOutputs(RAI_OnFinishCtx *finish_ctx) {
269175const RAI_Tensor * RAI_DAGOutputTensor (RAI_OnFinishCtx * finish_ctx , size_t index ) {
270176 size_t tensor_get_op_ind = -1 ;
271177 RedisAI_RunInfo * rinfo = (RedisAI_RunInfo * )finish_ctx ;
178+
272179 for (size_t i = 0 ; i < rinfo -> dagOpCount ; i ++ ) {
273180 RAI_DagOp * op = rinfo -> dagOps [i ];
274181 if (op -> commandType == REDISAI_DAG_CMD_TENSORGET ) {
275182 tensor_get_op_ind ++ ;
276183 if (tensor_get_op_ind == index ) {
277- RAI_Tensor * t ;
278- int res = RAI_getTensorFromLocalContext (rinfo -> dagTensorsContext , op -> inkeys [0 ], & t ,
279- op -> err );
280- RedisModule_Assert (res == REDISMODULE_OK );
281- return t ;
184+ return Dag_GetTensorFromGlobalCtx (rinfo , op -> inkeys_indices [0 ]);
282185 }
283186 }
284187 }
0 commit comments