Skip to content

Commit

Permalink
fix monitor controller impl & unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
yczhang-nv committed Sep 27, 2024
1 parent 547f276 commit 9847952
Show file tree
Hide file tree
Showing 3 changed files with 84 additions and 50 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <rxcpp/rx.hpp> // for trace_activity, decay_t, from

#include <cstddef>
#include <memory>
#include <optional>
#include <string>
#include <type_traits>
Expand All @@ -43,6 +44,7 @@ namespace morpheus {
* @file
*/


/**
* @brief
*/
Expand Down Expand Up @@ -77,6 +79,15 @@ MonitorController<MessageT>::MonitorController(const std::string& description,
m_determine_count_fn(determine_count_fn),
m_count(0)
{
if (!m_determine_count_fn)
{
m_determine_count_fn = auto_count_fn();
if (!m_determine_count_fn)
{
throw std::runtime_error("No count function provided and no default count function available");
}
}

m_progress_bar.set_option(indicators::option::BarWidth{50});
m_progress_bar.set_option(indicators::option::Start{"["});
m_progress_bar.set_option(indicators::option::Fill(""));
Expand All @@ -93,45 +104,37 @@ MonitorController<MessageT>::MonitorController(const std::string& description,
template <typename MessageT>
MessageT MonitorController<MessageT>::progress_sink(MessageT msg)
{
if (m_determine_count_fn == std::nullopt)
{
m_determine_count_fn = auto_count_fn(msg);
}

m_count += (*m_determine_count_fn)(msg);
m_progress_bar.set_progress(m_count);

return msg;
}

template <typename T>
struct is_vector : std::false_type
{};

template <typename T, typename U>
struct is_vector<std::vector<T, U>> : std::true_type
{};

template <typename MessageT>
auto MonitorController<MessageT>::auto_count_fn() -> std::optional<std::function<size_t(MessageT)>>
{
if constexpr (std::is_same_v<MessageT, std::shared_ptr<cudf::table>>)
if constexpr (std::is_same_v<MessageT, std::shared_ptr<MessageMeta>>)
{
return [](MessageT msg) {
return msg->num_rows();
return [](std::shared_ptr<MessageMeta> msg) {
return msg->count();
};
}

if constexpr (std::is_same_v<MessageT, std::shared_ptr<MessageMeta>>)
if constexpr (std::is_same_v<MessageT, std::vector<std::shared_ptr<MessageMeta>>>)
{
return [](MessageT msg) {
return msg->count();
return [](std::vector<std::shared_ptr<MessageMeta>> msg) {
auto item_count_fn = [](std::shared_ptr<MessageMeta> msg) {
return msg->count();
};
return std::accumulate(msg.begin(), msg.end(), 0, [&](int sum, const auto& item) {
return sum + (*item_count_fn)(item);
});
};
}

if constexpr (std::is_same_v<MessageT, std::shared_ptr<ControlMessage>>)
{
return [](MessageT msg) {
return [](std::shared_ptr<ControlMessage> msg) {
if (!msg->payload())
{
return 0;
Expand All @@ -140,26 +143,23 @@ auto MonitorController<MessageT>::auto_count_fn() -> std::optional<std::function
};
}

if constexpr (is_vector<MessageT>::value)
if constexpr (std::is_same_v<MessageT, std::vector<std::shared_ptr<ControlMessage>>>)
{
// if (msg.empty())
// {
// return std::nullopt;
// }

return [this](MessageT msg) {
auto item_count_fn = auto_count_fn<MessageT>(msg[0]);
if (item_count_fn == std::nullopt)
{
return 0;
}
return std::accumulate(msg.begin(), msg.end(), 0, [item_count_fn](int sum, const auto& item) {
return [](std::vector<std::shared_ptr<ControlMessage>> msg) {
auto item_count_fn = [](std::shared_ptr<ControlMessage> msg) {
if (!msg->payload())
{
return 0;
}
return msg->payload()->count();
};
return std::accumulate(msg.begin(), msg.end(), 0, [&](int sum, const auto& item) {
return sum + (*item_count_fn)(item);
});
};
}

throw std::runtime_error("Unsupported message type received for MonitorController");
return std::nullopt;
}

template <typename MessageT>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,15 @@

#include "morpheus/controllers/monitor_controller.hpp"

#include "morpheus/messages/control.hpp"

#include <cudf/column/column.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/table/table.hpp>



namespace morpheus {

// Component public implementations
// ****************** MonitorController ************************ //



} // namespace morpheus
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
#include "../test_utils/common.hpp" // for get_morpheus_root, TEST_CLASS_WITH_PYTHON, morpheus

#include "morpheus/controllers/monitor_controller.hpp" // for MonitorController
#include "morpheus/stages/monitor.hpp" // for MonitorStage
#include "morpheus/messages/control.hpp" // for ControlMessage

#include <cudf/column/column.hpp>
#include <cudf/column/column_factories.hpp>
#include <cudf/column/column_view.hpp>
#include <cudf/filling.hpp>
#include <cudf/io/types.hpp>
#include <cudf/scalar/scalar.hpp>
#include <cudf/table/table.hpp>
#include <cudf/types.hpp>
Expand All @@ -16,39 +17,74 @@

#include <memory>
#include <numeric>
#include <stdexcept>
#include <vector>

using namespace morpheus;

TEST_CLASS(MonitorController);
TEST_CLASS_WITH_PYTHON(MonitorController);

std::shared_ptr<cudf::table> create_cudf_table(int rows, int cols)
cudf::io::table_with_metadata create_cudf_table_with_metadata(int rows, int cols)
{
std::vector<std::unique_ptr<cudf::column>> columns;

for (int i = 0; i < cols; ++i)
{
// Create a numeric column of type INT32 with 'rows' elements
auto col = cudf::make_numeric_column(cudf::data_type{cudf::type_id::INT32}, rows);
auto col_view = col->mutable_view();

// Fill the column with range [0, rows - 1]
std::vector<int32_t> data(rows);
std::iota(data.begin(), data.end(), 0);
cudaMemcpy(col_view.data<int32_t>(), data.data(), data.size() * sizeof(int32_t), cudaMemcpyHostToDevice);

// Add the column to the vector
columns.push_back(std::move(col));
}

// Create and return the table
return std::make_shared<cudf::table>(std::move(columns));
auto table = std::make_unique<cudf::table>(std::move(columns));

auto index_info = cudf::io::column_name_info{""};
auto column_names = std::vector<cudf::io::column_name_info>(cols, index_info);
auto metadata = cudf::io::table_metadata{std::move(column_names), {}, {}};

return cudf::io::table_with_metadata{std::move(table), metadata};
}

TEST_F(TestMonitorController, TestAutoCountFn)
{
auto test_mc_cudf = MonitorController<std::shared_ptr<cudf::table>>("test_cudf_table");
auto cudf_auto_count_fn = test_mc_cudf.auto_count_fn();
auto cudf_table = create_cudf_table(10, 2);
assert((*cudf_auto_count_fn)(cudf_table) == 10);
auto message_meta_mc = MonitorController<std::shared_ptr<MessageMeta>>("test_message_meta");
auto message_meta_auto_count_fn = message_meta_mc.auto_count_fn();
auto meta = MessageMeta::create_from_cpp(std::move(create_cudf_table_with_metadata(10, 2)));
EXPECT_EQ((*message_meta_auto_count_fn)(meta), 10);

auto control_message_mc = MonitorController<std::shared_ptr<ControlMessage>>("test_control_message");
auto control_message_auto_count_fn = control_message_mc.auto_count_fn();
auto control_message = std::make_shared<ControlMessage>();
auto cm_meta = MessageMeta::create_from_cpp(std::move(create_cudf_table_with_metadata(20, 3)));
control_message->payload(cm_meta);
EXPECT_EQ((*control_message_auto_count_fn)(control_message), 20);

auto message_meta_vector_mc =
MonitorController<std::vector<std::shared_ptr<MessageMeta>>>("test_message_meta_vector");
auto message_meta_vector_auto_count_fn = message_meta_vector_mc.auto_count_fn();
std::vector<std::shared_ptr<MessageMeta>> meta_vector;
for (int i = 0; i < 5; ++i)
{
meta_vector.emplace_back(MessageMeta::create_from_cpp(std::move(create_cudf_table_with_metadata(5, 2))));
}
EXPECT_EQ((*message_meta_vector_auto_count_fn)(meta_vector), 25);

auto control_message_vector_mc =
MonitorController<std::vector<std::shared_ptr<ControlMessage>>>("test_control_message_vector");
auto control_message_vector_auto_count_fn = control_message_vector_mc.auto_count_fn();
std::vector<std::shared_ptr<ControlMessage>> control_message_vector;
for (int i = 0; i < 5; ++i)
{
auto cm = std::make_shared<ControlMessage>();
cm->payload(MessageMeta::create_from_cpp(std::move(create_cudf_table_with_metadata(6, 2))));
control_message_vector.emplace_back(cm);
}
EXPECT_EQ((*control_message_vector_auto_count_fn)(control_message_vector), 30);

// Test invalid message type
EXPECT_THROW(MonitorController<int>("invalid message type"), std::runtime_error);
}

0 comments on commit 9847952

Please sign in to comment.