diff --git a/src/tensors.c b/src/tensors.c index b2e54ac..3b97bc9 100644 --- a/src/tensors.c +++ b/src/tensors.c @@ -67,38 +67,31 @@ static char* json_array(char* json, long long* res, int size) { } static int json_dtype(const char* str, enum DType* dtype, int* dsize) { - if (strcmp(str, "F32") == 0) { - *dtype = dt_f32; - *dsize = 4; - } else if (strcmp(str, "F16") == 0) { - *dtype = dt_f16; - *dsize = 2; - } else if (strcmp(str, "BF16") == 0) { - *dtype = dt_bf16; - *dsize = 2; - } else if (strcmp(str, "F8_E5M2") == 0) { - *dtype = dt_f8e5m2; - *dsize = 1; - } else if (strcmp(str, "F8_E4M3") == 0) { - *dtype = dt_f8e4m3; - *dsize = 1; - } else if (strcmp(str, "I32") == 0) { - *dtype = dt_i32; - *dsize = 4; - } else if (strcmp(str, "I16") == 0) { - *dtype = dt_i16; - *dsize = 2; - } else if (strcmp(str, "I8") == 0) { - *dtype = dt_i8; - *dsize = 1; - } else if (strcmp(str, "U8") == 0) { - *dtype = dt_u8; - *dsize = 1; - } else { - return -1; + static const struct { + const char* str; + enum DType dtype; + int dsize; + } dtypes[] = { + {"F32", dt_f32, 4}, + {"F16", dt_f16, 2}, + {"BF16", dt_bf16, 2}, + {"F8_E5M2", dt_f8e5m2, 1}, + {"F8_E4M3", dt_f8e4m3, 1}, + {"I32", dt_i32, 4}, + {"I16", dt_i16, 2}, + {"I8", dt_i8, 1}, + {"U8", dt_u8, 1}, + }; + + for (size_t i = 0; i < sizeof(dtypes) / sizeof(dtypes[0]); ++i) { + if (strcmp(str, dtypes[i].str) == 0) { + *dtype = dtypes[i].dtype; + *dsize = dtypes[i].dsize; + return 0; + } } - return 0; + return -1; } static bool validate_shape(int dsize, int shape[4], size_t length) {