Skip to content

Commit

Permalink
Fixing preprocessing C++ shape
Browse files Browse the repository at this point in the history
  • Loading branch information
mdemoret-nv committed Nov 29, 2023
1 parent 1f5718c commit a69a10c
Showing 1 changed file with 21 additions and 8 deletions.
29 changes: 21 additions & 8 deletions morpheus/_lib/src/stages/preprocess_fil.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -101,13 +101,15 @@ PreprocessFILStage::subscribe_fn_t PreprocessFILStage::build_operator()
}
}

// Need to do a transpose here
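// The packed data is row-major with shape [fea_len, row_count] (equivalently, column-major [row_count, fea_len]).
// Easiest way to get a row-major [row_count, fea_len] layout is to transpose the data from [fea_len, row_count] to [row_count, fea_len]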
auto transposed_data =
MatxUtil::transpose(DevMemInfo{packed_data,
TypeId::FLOAT32,
{x->mess_count, static_cast<TensorIndex>(m_fea_cols.size())},
{1, x->mess_count}});
{static_cast<TensorIndex>(m_fea_cols.size()), x->mess_count},
{x->mess_count, 1}});

// Create the tensor which will be row-major and size [row_count, fea_len]
auto input__0 = Tensor::create(transposed_data,
DType::create<float>(),
{x->mess_count, static_cast<TensorIndex>(m_fea_cols.size())},
Expand All @@ -121,8 +123,8 @@ PreprocessFILStage::subscribe_fn_t PreprocessFILStage::build_operator()
input__0.get_memory(),
x->mess_offset),
seq_id_dtype,
{x->mess_count, 3},
{},
{x->mess_count, 3},
{},
0);

// Build the results
Expand Down Expand Up @@ -152,11 +154,22 @@ TableInfo PreprocessFILStage::fix_bad_columns(sink_type_t x)
auto mutable_info = x->meta->get_mutable_info();
auto df_meta_col_names = mutable_info.get_column_names();

for (size_t i = 0; i < mutable_info.num_columns(); ++i)
// Only check the feature columns. Leave the rest unchanged
for (auto& fea_col : m_fea_cols)
{
if (mutable_info.get_column(i).type().id() == cudf::type_id::STRING)
// Find the index of the column in the dataframe
auto col_idx =
std::find(df_meta_col_names.begin(), df_meta_col_names.end(), fea_col) - df_meta_col_names.begin();

if (col_idx == df_meta_col_names.size())
{
// This feature was not found. Ignore it.
continue;
}

if (mutable_info.get_column(col_idx).type().id() == cudf::type_id::STRING)
{
bad_cols.push_back(df_meta_col_names[i]);
bad_cols.push_back(fea_col);
}
}

Expand Down

0 comments on commit a69a10c

Please sign in to comment.