diff --git a/examples/example_c_api.c b/examples/example_c_api.c index e5dc91c..e378b71 100644 --- a/examples/example_c_api.c +++ b/examples/example_c_api.c @@ -3,7 +3,14 @@ int main() { - tensor_t *pt = new_tensor(dtypes.u8, 1, 1); - del_tensor(pt); + { + tensor_t *pt = new_tensor(dtypes.u8, 1, 1); + del_tensor(pt); + } + { + int dims[1] = {1}; + tensor_t *pt = new_tensor1(dtypes.u8, 1, dims); + del_tensor(pt); + } return 0; } diff --git a/include/tensor.h b/include/tensor.h index 4574615..13a1700 100644 --- a/include/tensor.h +++ b/include/tensor.h @@ -23,8 +23,13 @@ typedef struct tensor_s tensor_t; extern tensor_t *new_tensor(uint8_t /*! value_type */, int /*! rank */, ...); +extern tensor_t *new_tensor1(uint8_t /*! value_type */, int /*! rank */, + const int * /* dims */); + extern void del_tensor(const tensor_t * /*! p_tensor_t */); +extern void *tensor_data(tensor_t * /*! p_tensor_t */); + #ifdef __cplusplus } #endif diff --git a/src/tensor.cpp b/src/tensor.cpp index d552820..770f3ea 100644 --- a/src/tensor.cpp +++ b/src/tensor.cpp @@ -37,4 +37,16 @@ tensor_t *new_tensor(uint8_t value_type, int rank, ...) return new tensor_s(value_type, shape); } +tensor_t *new_tensor1(uint8_t value_type, int rank, const int *_dims) +{ + using raw_shape = raw_tensor::shape_type; + using dim_t = raw_shape::dimension_type; + std::vector dims(rank); + std::copy(_dims, _dims + rank, dims.begin()); + raw_shape shape(dims); + return new tensor_s(value_type, shape); +} + void del_tensor(const tensor_t *pt) { delete pt; } + +void *tensor_data(tensor_t *pt) { return pt->data(); } diff --git a/tests/test_tensor_c_api.cpp b/tests/test_tensor_c_api.cpp index 408c08b..a5f1bc5 100644 --- a/tests/test_tensor_c_api.cpp +++ b/tests/test_tensor_c_api.cpp @@ -17,8 +17,21 @@ TEST(c_api_test, test1) tensor_t *pt = new_tensor(dt, 2, 2, 3); del_tensor(pt); } + { + int dims[1] = {1}; + tensor_t *pt = new_tensor1(dt, 1, dims); + del_tensor(pt); + } + { + int dims[2] = {2, 3}; + tensor_t *pt = new_tensor1(dt, 2, dims); + del_tensor(pt); + } } +} +TEST(c_api_test, test2) +{ using scalar_encoding = raw_tensor::encoder_type; ASSERT_EQ(scalar_encoding::value(), dtypes.u8);