From cfb51534da8955e3e5918795a331cc068b9187d3 Mon Sep 17 00:00:00 2001 From: Giovanni De Toni Date: Sun, 9 Jun 2019 12:30:03 +0200 Subject: [PATCH 1/2] Remove CHECK_TYPE_HISTO from TBOutputFormat. --- src/shogun/io/TBOutputFormat.cpp | 53 +++++------------------- tests/unit/io/TBOutputFormat_unittest.cc | 21 +++++----- 2 files changed, 22 insertions(+), 52 deletions(-) diff --git a/src/shogun/io/TBOutputFormat.cpp b/src/shogun/io/TBOutputFormat.cpp index efed1c36b8c..df996f986ad 100644 --- a/src/shogun/io/TBOutputFormat.cpp +++ b/src/shogun/io/TBOutputFormat.cpp @@ -37,28 +37,12 @@ #include #include -#include #include #include #include -#include using namespace shogun; -#define CHECK_TYPE_HISTO(type) \ - else if ( \ - value.first->get_any().type_info().hash_code() == \ - typeid(type).hash_code()) \ - { \ - tensorflow::histogram::Histogram h; \ - tensorflow::HistogramProto* hp = new tensorflow::HistogramProto(); \ - auto v = any_cast(value.first->get_any()); \ - for (auto value_v : v) \ - h.Add(value_v); \ - h.EncodeToProto(hp, true); \ - summaryValue->set_allocated_histo(hp); \ - } - TBOutputFormat::TBOutputFormat(){}; TBOutputFormat::~TBOutputFormat(){}; @@ -98,33 +82,18 @@ tensorflow::Event TBOutputFormat::convert_vector( summaryValue->set_tag(value.first->get("name")); summaryValue->set_node_name(node_name); - if (value.first->get_any().type_info().hash_code() == - typeid(std::vector).hash_code()) - { - tensorflow::histogram::Histogram h; - tensorflow::HistogramProto* hp = new tensorflow::HistogramProto(); - auto v = any_cast>(value.first->get_any()); - for (auto value_v : v) + tensorflow::histogram::Histogram h; + tensorflow::HistogramProto* hp = new tensorflow::HistogramProto(); + + auto write_summary = [&h](auto val) { + for (auto value_v : val) h.Add(value_v); - h.EncodeToProto(hp, true); - summaryValue->set_allocated_histo(hp); - } - CHECK_TYPE_HISTO(std::vector) - CHECK_TYPE_HISTO(std::vector) - CHECK_TYPE_HISTO(std::vector) - CHECK_TYPE_HISTO(std::vector) - CHECK_TYPE_HISTO(std::vector) - CHECK_TYPE_HISTO(std::vector) - CHECK_TYPE_HISTO(std::vector) - CHECK_TYPE_HISTO(std::vector) - CHECK_TYPE_HISTO(std::vector) - CHECK_TYPE_HISTO(std::vector) - CHECK_TYPE_HISTO(std::vector) - else - { - SG_ERROR( - "Unsupported type %s", value.first->get_any().type_info().name()); - } + }; + + sg_any_dispatch(value.first->get_any(), sg_all_typemap, None{}, write_summary); + + h.EncodeToProto(hp, true); + summaryValue->set_allocated_histo(hp); return e; } diff --git a/tests/unit/io/TBOutputFormat_unittest.cc b/tests/unit/io/TBOutputFormat_unittest.cc index c5abd5c5da0..5388b8036f1 100644 --- a/tests/unit/io/TBOutputFormat_unittest.cc +++ b/tests/unit/io/TBOutputFormat_unittest.cc @@ -46,6 +46,7 @@ #include #include #include +#include using namespace shogun; @@ -90,7 +91,7 @@ void test_case_scalar_error(T value_val) } template -void test_case_vector(std::vector v) +void test_case_vector(SGVector v) { tensorflow::Event event_ex; auto summary = event_ex.mutable_summary(); @@ -109,7 +110,7 @@ void test_case_vector(std::vector v) time_point timestamp; Some emitted_value = Some::from_raw( - new ObservedValueTemplated>( + new ObservedValueTemplated>( 1, "test", "test description", v)); std::string node_name = "node"; @@ -125,13 +126,13 @@ void test_case_vector(std::vector v) } template -void test_case_vector_error(std::vector v) +void test_case_vector_error(SGVector v) { TBOutputFormat tmp; time_point timestamp; Some emitted_value = Some::from_raw( - new ObservedValueTemplated>( + new ObservedValueTemplated>( 1, "test", "test_description", v)); std::string node_name = "node"; @@ -159,17 +160,17 @@ TEST(TBOutputFormatTest, fail_convert_scalar) TYPED_TEST(TBOutputFormatTest, convert_all_types_histo) { - std::vector v; - v.push_back((TypeParam)1); - v.push_back((TypeParam)2); + SGVector v(2); + v[0] = ((TypeParam)1); + v[1] = ((TypeParam)2); test_case_vector(v); }; TEST(TBOutputFormat, fail_convert_histo) { - std::vector v; - v.push_back((complex128_t)1); - v.push_back((complex128_t)2); + SGVector v(2); + v[0] = ((complex128_t)1); + v[1] = ((complex128_t)2); test_case_vector_error(v); } From 0b7c5aed895b79e63f17d61902a12c357e920f70 Mon Sep 17 00:00:00 2001 From: Giovanni De Toni Date: Sun, 9 Jun 2019 15:27:05 +0200 Subject: [PATCH 2/2] Add missing types to the sg_any_dispatch method. This is done in order to fully support TBOutputFormat features. --- src/shogun/lib/type_case.h | 59 ++++++++++++++++++++++++++++---------- 1 file changed, 44 insertions(+), 15 deletions(-) diff --git a/src/shogun/lib/type_case.h b/src/shogun/lib/type_case.h index 4c67c0f76ee..298477ac37d 100644 --- a/src/shogun/lib/type_case.h +++ b/src/shogun/lib/type_case.h @@ -21,10 +21,11 @@ namespace shogun { typedef Types< bool, char, int8_t, uint8_t, int16_t, uint16_t, int32_t, uint32_t, - int64_t, uint64_t, float32_t, float64_t, floatmax_t, SGVector, - SGVector, SGVector, SGVector, - SGVector, SGMatrix, SGMatrix, - SGMatrix, SGMatrix, SGMatrix> + int64_t, uint64_t, float32_t, float64_t, floatmax_t, SGVector, + SGVector, SGVector, SGVector, SGVector, + SGVector, SGVector, SGVector, SGVector, + SGVector, SGVector, SGVector, SGMatrix, + SGMatrix, SGMatrix, SGMatrix, SGMatrix> SG_TYPES; enum class TYPE @@ -44,17 +45,24 @@ namespace shogun T_FLOATMAX = 13, T_SGOBJECT = 14, T_COMPLEX128 = 15, - T_SGVECTOR_FLOAT32 = 16, - T_SGVECTOR_FLOAT64 = 17, - T_SGVECTOR_FLOATMAX = 18, - T_SGVECTOR_INT32 = 19, - T_SGVECTOR_INT64 = 20, - T_SGMATRIX_FLOAT32 = 21, - T_SGMATRIX_FLOAT64 = 22, - T_SGMATRIX_FLOATMAX = 23, - T_SGMATRIX_INT32 = 24, - T_SGMATRIX_INT64 = 25, - T_UNDEFINED = 26 + T_SGVECTOR_CHAR = 16, + T_SGVECTOR_FLOAT32 = 17, + T_SGVECTOR_FLOAT64 = 18, + T_SGVECTOR_FLOATMAX = 19, + T_SGVECTOR_UINT8 = 20, + T_SGVECTOR_INT8 = 21, + T_SGVECTOR_INT16 = 22, + T_SGVECTOR_UINT16 = 23, + T_SGVECTOR_INT32 = 24, + T_SGVECTOR_UINT32 = 25, + T_SGVECTOR_INT64 = 26, + T_SGVECTOR_UINT64 = 27, + T_SGMATRIX_FLOAT32 = 28, + T_SGMATRIX_FLOAT64 = 29, + T_SGMATRIX_FLOATMAX = 30, + T_SGMATRIX_INT32 = 31, + T_SGMATRIX_INT64 = 32, + T_UNDEFINED = 33 }; typedef std::unordered_map typemap; namespace type_internal @@ -128,11 +136,18 @@ namespace shogun SG_ADD_PRIMITIVE_TYPE(float64_t, TYPE::T_FLOAT64) SG_ADD_PRIMITIVE_TYPE(floatmax_t, TYPE::T_FLOATMAX) SG_ADD_PRIMITIVE_TYPE(complex128_t, TYPE::T_COMPLEX128) + SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_CHAR) SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_FLOAT32) SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_FLOAT64) SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_FLOATMAX) + SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_INT8) + SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_INT16) + SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_UINT16) + SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_UINT8) SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_INT32) SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_INT64) + SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_UINT32) + SG_ADD_SGVECTOR_TYPE(SGVector, TYPE::T_SGVECTOR_UINT64) SG_ADD_SGMATRIX_TYPE(SGMatrix, TYPE::T_SGMATRIX_FLOAT32) SG_ADD_SGMATRIX_TYPE(SGMatrix, TYPE::T_SGMATRIX_FLOAT64) SG_ADD_SGMATRIX_TYPE(SGMatrix, TYPE::T_SGMATRIX_FLOATMAX) @@ -416,11 +431,18 @@ static const typemap sg_all_typemap = { ADD_TYPE_TO_MAP(float64_t , TYPE::T_FLOAT64) ADD_TYPE_TO_MAP(floatmax_t , TYPE::T_FLOATMAX) ADD_TYPE_TO_MAP(complex128_t, TYPE::T_COMPLEX128) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_CHAR) ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_FLOAT32) ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_FLOAT64) ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_FLOATMAX) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_INT8) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_INT16) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_UINT8) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_UINT16) ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_INT32) ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_INT64) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_UINT32) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_UINT64) ADD_TYPE_TO_MAP(SGMatrix, TYPE::T_SGMATRIX_FLOAT32) ADD_TYPE_TO_MAP(SGMatrix, TYPE::T_SGMATRIX_FLOAT64) ADD_TYPE_TO_MAP(SGMatrix, TYPE::T_SGMATRIX_FLOATMAX) @@ -428,11 +450,18 @@ static const typemap sg_all_typemap = { ADD_TYPE_TO_MAP(SGMatrix, TYPE::T_SGMATRIX_INT64) }; static const typemap sg_vector_typemap = { + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_CHAR) ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_FLOAT32) ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_FLOAT64) ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_FLOATMAX) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_INT8) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_INT16) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_UINT8) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_UINT16) ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_INT32) ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_INT64) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_UINT32) + ADD_TYPE_TO_MAP(SGVector, TYPE::T_SGVECTOR_UINT64) }; static const typemap sg_matrix_typemap = { ADD_TYPE_TO_MAP(SGMatrix, TYPE::T_SGMATRIX_FLOAT32)