Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion src/DAG/dag_builder.c
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ int RAI_DAGAddRunOp(RAI_DAGRunCtx *run_info, RAI_DAGRunOp *DAGop, RAI_Error *err
return REDISMODULE_OK;
}

int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err) {
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name) {

RedisAI_RunInfo *rinfo = (RedisAI_RunInfo *)run_info;
RAI_DagOp *op;
Expand Down
2 changes: 1 addition & 1 deletion src/DAG/dag_builder.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ int RAI_DAGAddTensorSet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Tensor
* @param runInfo The DAG to append this op into.
* @param tensor The tensor to set.
*/
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name, RAI_Error *err);
int RAI_DAGAddTensorGet(RAI_DAGRunCtx *run_info, const char *t_name);

/**
* @brief Add ops to a DAG from string (according to the command syntax). In case of a valid
Expand Down
7 changes: 2 additions & 5 deletions src/command_parser.c
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@ static int _ModelRunCommand_ParseArgs(RedisModuleCtx *ctx, int argc, RedisModule
return REDISMODULE_ERR;
}
size_t argpos = 1;
RedisModuleKey *modelKey;
const int status =
RAI_GetModelFromKeyspace(ctx, argv[argpos], &modelKey, model, REDISMODULE_READ, error);
const int status = RAI_GetModelFromKeyspace(ctx, argv[argpos], model, REDISMODULE_READ, error);
if (status == REDISMODULE_ERR) {
return REDISMODULE_ERR;
}
Expand Down Expand Up @@ -172,9 +170,8 @@ static int _ScriptRunCommand_ParseArgs(RedisModuleCtx *ctx, RedisModuleString **
return REDISMODULE_ERR;
}
size_t argpos = 1;
RedisModuleKey *scriptKey;
const int status =
RAI_GetScriptFromKeyspace(ctx, argv[argpos], &scriptKey, script, REDISMODULE_READ, error);
RAI_GetScriptFromKeyspace(ctx, argv[argpos], script, REDISMODULE_READ, error);
if (status == REDISMODULE_ERR) {
return REDISMODULE_ERR;
}
Expand Down
18 changes: 9 additions & 9 deletions src/model.c
Original file line number Diff line number Diff line change
Expand Up @@ -25,21 +25,21 @@
/* Return REDISMODULE_ERR if there was an error getting the Model.
* Return REDISMODULE_OK if the model value stored at key was correctly
* returned and available at *model variable. */
int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key,
RAI_Model **model, int mode, RAI_Error *err) {
*key = RedisModule_OpenKey(ctx, keyName, mode);
if (RedisModule_KeyType(*key) == REDISMODULE_KEYTYPE_EMPTY) {
RedisModule_CloseKey(*key);
int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RAI_Model **model,
int mode, RAI_Error *err) {
RedisModuleKey *key = RedisModule_OpenKey(ctx, keyName, mode);
if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) {
RedisModule_CloseKey(key);
RAI_SetError(err, RAI_EMODELRUN, "ERR model key is empty");
return REDISMODULE_ERR;
}
if (RedisModule_ModuleTypeGetType(*key) != RedisAI_ModelType) {
RedisModule_CloseKey(*key);
if (RedisModule_ModuleTypeGetType(key) != RedisAI_ModelType) {
RedisModule_CloseKey(key);
RAI_SetError(err, RAI_EMODELRUN, REDISMODULE_ERRORMSG_WRONGTYPE);
return REDISMODULE_ERR;
}
*model = RedisModule_ModuleTypeGetValue(*key);
RedisModule_CloseKey(*key);
*model = RedisModule_ModuleTypeGetValue(key);
RedisModule_CloseKey(key);
return REDISMODULE_OK;
}

Expand Down
6 changes: 2 additions & 4 deletions src/model.h
Original file line number Diff line number Diff line change
Expand Up @@ -112,17 +112,15 @@ int RAI_ModelSerialize(RAI_Model *model, char **buffer, size_t *len, RAI_Error *
*
* @param ctx Context in which Redis modules operate
* @param keyName key name
* @param key models's key handle. On success it contains an handle representing
* a Redis key with the requested access mode
* @param model destination model structure
* @param mode key access mode
* @param error contains the error in case of problem with retrival
* @return REDISMODULE_OK if the model value stored at key was correctly
* returned and available at *model variable, or REDISMODULE_ERR if there was
* an error getting the Model
*/
int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key,
RAI_Model **model, int mode, RAI_Error *err);
int RAI_GetModelFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RAI_Model **model,
int mode, RAI_Error *err);

/**
* When a module command is called in order to obtain the position of
Expand Down
18 changes: 7 additions & 11 deletions src/redisai.c
Original file line number Diff line number Diff line change
Expand Up @@ -420,8 +420,7 @@ int RedisAI_ModelGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,

RAI_Error err = {0};
RAI_Model *mto;
RedisModuleKey *key;
const int status = RAI_GetModelFromKeyspace(ctx, argv[1], &key, &mto, REDISMODULE_READ, &err);
const int status = RAI_GetModelFromKeyspace(ctx, argv[1], &mto, REDISMODULE_READ, &err);
if (status == REDISMODULE_ERR) {
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&err));
RAI_ClearError(&err);
Expand Down Expand Up @@ -521,17 +520,16 @@ int RedisAI_ModelDel_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv,
return RedisModule_WrongArity(ctx);

RAI_Model *mto;
RedisModuleKey *key;
RAI_Error err = {0};
const int status = RAI_GetModelFromKeyspace(ctx, argv[1], &key, &mto,
REDISMODULE_READ | REDISMODULE_WRITE, &err);
const int status =
RAI_GetModelFromKeyspace(ctx, argv[1], &mto, REDISMODULE_READ | REDISMODULE_WRITE, &err);
if (status == REDISMODULE_ERR) {
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&err));
RAI_ClearError(&err);
return REDISMODULE_ERR;
}

key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE);
RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE);
RedisModule_DeleteKey(key);
RedisModule_CloseKey(key);
RedisModule_ReplicateVerbatim(ctx);
Expand Down Expand Up @@ -605,9 +603,8 @@ int RedisAI_ScriptGet_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
return RedisModule_WrongArity(ctx);

RAI_Script *sto;
RedisModuleKey *key;
RAI_Error err = {0};
const int status = RAI_GetScriptFromKeyspace(ctx, argv[1], &key, &sto, REDISMODULE_READ, &err);
const int status = RAI_GetScriptFromKeyspace(ctx, argv[1], &sto, REDISMODULE_READ, &err);
if (status == REDISMODULE_ERR) {
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&err));
RAI_ClearError(&err);
Expand Down Expand Up @@ -656,15 +653,14 @@ int RedisAI_ScriptDel_RedisCommand(RedisModuleCtx *ctx, RedisModuleString **argv
return RedisModule_WrongArity(ctx);

RAI_Script *sto;
RedisModuleKey *key;
RAI_Error err = {0};
const int status = RAI_GetScriptFromKeyspace(ctx, argv[1], &key, &sto, REDISMODULE_WRITE, &err);
const int status = RAI_GetScriptFromKeyspace(ctx, argv[1], &sto, REDISMODULE_WRITE, &err);
if (status == REDISMODULE_ERR) {
RedisModule_ReplyWithError(ctx, RAI_GetErrorOneLine(&err));
RAI_ClearError(&err);
return REDISMODULE_ERR;
}
key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE);
RedisModuleKey *key = RedisModule_OpenKey(ctx, argv[1], REDISMODULE_WRITE);
RedisModule_DeleteKey(key);
RedisModule_CloseKey(key);

Expand Down
4 changes: 1 addition & 3 deletions src/redisai.h
Original file line number Diff line number Diff line change
Expand Up @@ -107,7 +107,6 @@ REDISAI_API void MODULE_API_FUNC(RedisAI_ModelFree)(RAI_Model *model, RAI_Error
REDISAI_API RAI_ModelRunCtx *MODULE_API_FUNC(RedisAI_ModelRunCtxCreate)(RAI_Model *model);
REDISAI_API int MODULE_API_FUNC(RedisAI_GetModelFromKeyspace)(RedisModuleCtx *ctx,
RedisModuleString *keyName,
RedisModuleKey **key,
RAI_Model **model, int mode,
RAI_Error *err);
REDISAI_API int MODULE_API_FUNC(RedisAI_ModelRunCtxAddInput)(RAI_ModelRunCtx *mctx,
Expand Down Expand Up @@ -136,7 +135,6 @@ REDISAI_API RAI_Script *MODULE_API_FUNC(RedisAI_ScriptCreate)(char *devicestr, c
RAI_Error *err);
REDISAI_API int MODULE_API_FUNC(RedisAI_GetScriptFromKeyspace)(RedisModuleCtx *ctx,
RedisModuleString *keyName,
RedisModuleKey **key,
RAI_Script **script, int mode,
RAI_Error *err);
REDISAI_API void MODULE_API_FUNC(RedisAI_ScriptFree)(RAI_Script *script, RAI_Error *err);
Expand Down Expand Up @@ -175,7 +173,7 @@ REDISAI_API int MODULE_API_FUNC(RedisAI_DAGLoadTensor)(RAI_DAGRunCtx *run_info,
REDISAI_API int MODULE_API_FUNC(RedisAI_DAGAddTensorSet)(RAI_DAGRunCtx *run_info,
const char *t_name, RAI_Tensor *tensor);
REDISAI_API int MODULE_API_FUNC(RedisAI_DAGAddTensorGet)(RAI_DAGRunCtx *run_info,
const char *t_name, RAI_Error *err);
const char *t_name);
REDISAI_API int MODULE_API_FUNC(RedisAI_DAGAddOpsFromString)(RAI_DAGRunCtx *run_info,
const char *dag, RAI_Error *err);
REDISAI_API size_t MODULE_API_FUNC(RedisAI_DAGNumOps)(RAI_DAGRunCtx *run_info);
Expand Down
18 changes: 9 additions & 9 deletions src/script.c
Original file line number Diff line number Diff line change
Expand Up @@ -154,21 +154,21 @@ RAI_Script *RAI_ScriptGetShallowCopy(RAI_Script *script) {
/* Return REDISMODULE_ERR if there was an error getting the Script.
* Return REDISMODULE_OK if the model value stored at key was correctly
* returned and available at *model variable. */
int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key,
RAI_Script **script, int mode, RAI_Error *err) {
*key = RedisModule_OpenKey(ctx, keyName, mode);
if (RedisModule_KeyType(*key) == REDISMODULE_KEYTYPE_EMPTY) {
RedisModule_CloseKey(*key);
int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RAI_Script **script,
int mode, RAI_Error *err) {
RedisModuleKey *key = RedisModule_OpenKey(ctx, keyName, mode);
if (RedisModule_KeyType(key) == REDISMODULE_KEYTYPE_EMPTY) {
RedisModule_CloseKey(key);
RAI_SetError(err, RAI_ESCRIPTRUN, "ERR script key is empty");
return REDISMODULE_ERR;
}
if (RedisModule_ModuleTypeGetType(*key) != RedisAI_ScriptType) {
RedisModule_CloseKey(*key);
if (RedisModule_ModuleTypeGetType(key) != RedisAI_ScriptType) {
RedisModule_CloseKey(key);
RAI_SetError(err, RAI_ESCRIPTRUN, REDISMODULE_ERRORMSG_WRONGTYPE);
return REDISMODULE_ERR;
}
*script = RedisModule_ModuleTypeGetValue(*key);
RedisModule_CloseKey(*key);
*script = RedisModule_ModuleTypeGetValue(key);
RedisModule_CloseKey(key);
return REDISMODULE_OK;
}

Expand Down
6 changes: 2 additions & 4 deletions src/script.h
Original file line number Diff line number Diff line change
Expand Up @@ -153,16 +153,14 @@ RAI_Script *RAI_ScriptGetShallowCopy(RAI_Script *script);
*
* @param ctx Context in which Redis modules operate
* @param keyName key name
* @param key script's key handle. On success it contains an handle representing
* a Redis key with the requested access mode
* @param script destination script structure
* @param mode key access mode
* @return REDISMODULE_OK if the script value stored at key was correctly
* returned and available at *script variable, or REDISMODULE_ERR if there was
* an error getting the Script
*/
int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RedisModuleKey **key,
RAI_Script **script, int mode, RAI_Error *err);
int RAI_GetScriptFromKeyspace(RedisModuleCtx *ctx, RedisModuleString *keyName, RAI_Script **script,
int mode, RAI_Error *err);

/**
* When a module command is called in order to obtain the position of
Expand Down
12 changes: 6 additions & 6 deletions tests/module/DAG_utils.c
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ int testKeysMismatchError(RedisModuleCtx *ctx) {
RAI_Tensor *t = (RAI_Tensor *)_getFromKeySpace(ctx, "a{1}");
RedisAI_DAGLoadTensor(run_info, "input", t);

RedisAI_DAGAddTensorGet(run_info, "non existing tensor", err);
RedisAI_DAGAddTensorGet(run_info, "non existing tensor");
int status = RedisAI_DAGRun(run_info, _DAGFinishFuncError, NULL, err);
if(!_assertError(err, status, "ERR INPUT key cannot be found in DAG")) {
goto cleanup;
Expand Down Expand Up @@ -183,7 +183,7 @@ int testBuildDAGFromString(RedisModuleCtx *ctx) {
goto cleanup;
}
RedisModule_Assert(RedisAI_DAGNumOps(run_info) == 3);
RedisAI_DAGAddTensorGet(run_info, "input1", results.error);
RedisAI_DAGAddTensorGet(run_info, "input1");
RedisModule_Assert(RedisAI_DAGNumOps(run_info) == 4);

pthread_mutex_lock(&global_lock);
Expand Down Expand Up @@ -227,7 +227,7 @@ int testSimpleDAGRun(RedisModuleCtx *ctx) {
goto cleanup;
}

RedisAI_DAGAddTensorGet(run_info, "output", results.error);
RedisAI_DAGAddTensorGet(run_info, "output");
pthread_mutex_lock(&global_lock);
if (RedisAI_DAGRun(run_info, _DAGFinishFunc, &results, results.error) != REDISMODULE_OK) {
pthread_mutex_unlock(&global_lock);
Expand Down Expand Up @@ -280,7 +280,7 @@ int testSimpleDAGRun2(RedisModuleCtx *ctx) {
goto cleanup;
}

RedisAI_DAGAddTensorGet(run_info, "output", results.error);
RedisAI_DAGAddTensorGet(run_info, "output");
pthread_mutex_lock(&global_lock);
if (RedisAI_DAGRun(run_info, _DAGFinishFunc, &results, results.error) != REDISMODULE_OK) {
pthread_mutex_unlock(&global_lock);
Expand Down Expand Up @@ -330,7 +330,7 @@ int testSimpleDAGRun2Error(RedisModuleCtx *ctx) {
goto cleanup;
}

RedisAI_DAGAddTensorGet(run_info, "output", results.error);
RedisAI_DAGAddTensorGet(run_info, "output");
pthread_mutex_lock(&global_lock);
if (RedisAI_DAGRun(run_info, _DAGFinishFunc, &results, results.error) != REDISMODULE_OK) {
pthread_mutex_unlock(&global_lock);
Expand Down Expand Up @@ -392,7 +392,7 @@ int testDAGResnet(RedisModuleCtx *ctx) {
RedisAI_DAGRunOpAddOutput(script_op, "output:{{1}}");
RedisAI_DAGAddRunOp(run_info, script_op, results.error);

RedisAI_DAGAddTensorGet(run_info, "output:{{1}}", results.error);
RedisAI_DAGAddTensorGet(run_info, "output:{{1}}");

pthread_mutex_lock(&global_lock);
if (RedisAI_DAGRun(run_info, _DAGFinishFunc, &results, results.error) != REDISMODULE_OK) {
Expand Down