Skip to content

Commit

Permalink
Slightly simplify json_dtype
Browse files Browse the repository at this point in the history
Reduce redundant code by extracting dtypes into a table.
  • Loading branch information
zeux committed Apr 24, 2024
1 parent d70d030 commit e90082e
Showing 1 changed file with 23 additions and 30 deletions.
53 changes: 23 additions & 30 deletions src/tensors.c
Original file line number Diff line number Diff line change
Expand Up @@ -67,38 +67,31 @@ static char* json_array(char* json, long long* res, int size) {
}

static int json_dtype(const char* str, enum DType* dtype, int* dsize) {
if (strcmp(str, "F32") == 0) {
*dtype = dt_f32;
*dsize = 4;
} else if (strcmp(str, "F16") == 0) {
*dtype = dt_f16;
*dsize = 2;
} else if (strcmp(str, "BF16") == 0) {
*dtype = dt_bf16;
*dsize = 2;
} else if (strcmp(str, "F8_E5M2") == 0) {
*dtype = dt_f8e5m2;
*dsize = 1;
} else if (strcmp(str, "F8_E4M3") == 0) {
*dtype = dt_f8e4m3;
*dsize = 1;
} else if (strcmp(str, "I32") == 0) {
*dtype = dt_i32;
*dsize = 4;
} else if (strcmp(str, "I16") == 0) {
*dtype = dt_i16;
*dsize = 2;
} else if (strcmp(str, "I8") == 0) {
*dtype = dt_i8;
*dsize = 1;
} else if (strcmp(str, "U8") == 0) {
*dtype = dt_u8;
*dsize = 1;
} else {
return -1;
static const struct {
const char* str;
enum DType dtype;
int dsize;
} dtypes[] = {
{"F32", dt_f32, 4},
{"F16", dt_f16, 2},
{"BF16", dt_bf16, 2},
{"F8_E5M2", dt_f8e5m2, 1},
{"F8_E4M3", dt_f8e4m3, 1},
{"I32", dt_i32, 4},
{"I16", dt_i16, 2},
{"I8", dt_i8, 1},
{"U8", dt_u8, 1},
};

for (size_t i = 0; i < sizeof(dtypes) / sizeof(dtypes[0]); ++i) {
if (strcmp(str, dtypes[i].str) == 0) {
*dtype = dtypes[i].dtype;
*dsize = dtypes[i].dsize;
return 0;
}
}

return 0;
return -1;
}

static bool validate_shape(int dsize, int shape[4], size_t length) {
Expand Down

0 comments on commit e90082e

Please sign in to comment.