Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 3 additions & 5 deletions src/backends/onnxruntime.c
Original file line number Diff line number Diff line change
Expand Up @@ -91,10 +91,6 @@ OrtValue *RAI_OrtValueFromTensors(RAI_Tensor **ts, size_t count, RAI_Error *erro
return NULL;
}

if (count == 0) {
return NULL;
}

size_t batch_size = 0;
size_t batch_byte_size = 0;

Expand Down Expand Up @@ -328,6 +324,9 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
ort->ReleaseSessionOptions(session_options);
goto error;
}
ort->SetIntraOpNumThreads(session_options, (int)opts.backends_intra_op_parallelism);
ort->SetInterOpNumThreads(session_options, (int)opts.backends_inter_op_parallelism);

// TODO: we will need to propose a more dynamic way to request a specific provider,
// e.g. given the name, in ONNXRuntime
#if RAI_ONNXRUNTIME_USE_CUDA
Expand All @@ -344,7 +343,6 @@ RAI_Model *RAI_ModelCreateORT(RAI_Backend backend, const char *devicestr, RAI_Mo
#endif

OrtSession *session;

status = ort->CreateSessionFromArray(env, modeldef, modellen, session_options, &session);

ort->ReleaseSessionOptions(session_options);
Expand Down
38 changes: 38 additions & 0 deletions tests/flow/tests_onnx.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
import subprocess
import redis
from includes import *
from RLTest import Env

'''
python -m RLTest --test tests_onnx.py --module path/to/redisai.so
Expand Down Expand Up @@ -413,3 +414,40 @@ def tests_onnx_info(env):
ret = con.execute_command('AI.INFO')
env.assertEqual(8, len(ret))
env.assertEqual(b'ONNX version', ret[6])


def _get_load_time_config(con):
    """Parse the module section of `INFO MODULES` into a {key: value} dict."""
    info = con.execute_command("INFO MODULES").decode()
    return {k.split(":")[0]: k.split(":")[1]
            for k in info.split("#")[3].split()[1:]}


def test_parallelism():
    """Verify that INTRA/INTER_OP_PARALLELISM module args are applied and reported.

    Starts the module with parallelism 1, runs an mnist inference to confirm the
    backend still works, checks the reported config, then restarts with
    parallelism 2 and checks the reported config again.
    """
    env = Env(moduleArgs='INTRA_OP_PARALLELISM 1 INTER_OP_PARALLELISM 1')
    if not TEST_ONNX:
        env.debugPrint("skipping {} since TEST_ONNX=0".format(sys._getframe().f_code.co_name), force=True)
        return

    con = env.getConnection()
    test_data_path = os.path.join(os.path.dirname(__file__), 'test_data')
    model_filename = os.path.join(test_data_path, 'mnist.onnx')
    sample_filename = os.path.join(test_data_path, 'one.raw')
    with open(model_filename, 'rb') as f:
        model_pb = f.read()
    with open(sample_filename, 'rb') as f:
        sample_raw = f.read()

    ret = con.execute_command('AI.MODELSET', 'm{1}', 'ONNX', DEVICE, 'BLOB', model_pb)
    env.assertEqual(ret, b'OK')
    con.execute_command('AI.TENSORSET', 'a{1}', 'FLOAT', 1, 1, 28, 28, 'BLOB', sample_raw)

    con.execute_command('AI.MODELRUN', 'm{1}', 'INPUTS', 'a{1}', 'OUTPUTS', 'b{1}')
    ensureSlaveSynced(con, env)
    values = con.execute_command('AI.TENSORGET', 'b{1}', 'VALUES')
    # The sample image is a "1"; its logit must be the largest.
    argmax = max(range(len(values)), key=lambda i: values[i])
    env.assertEqual(argmax, 1)

    load_time_config = _get_load_time_config(con)
    env.assertEqual(load_time_config["ai_inter_op_parallelism"], "1")
    env.assertEqual(load_time_config["ai_intra_op_parallelism"], "1")

    # Restart with parallelism 2. NOTE: a fresh connection is required — the
    # old `con` points at the previous environment, so reusing it would read
    # the stale (or torn-down) server's config instead of the new one.
    env = Env(moduleArgs='INTRA_OP_PARALLELISM 2 INTER_OP_PARALLELISM 2')
    con = env.getConnection()
    load_time_config = _get_load_time_config(con)
    env.assertEqual(load_time_config["ai_inter_op_parallelism"], "2")
    env.assertEqual(load_time_config["ai_intra_op_parallelism"], "2")