Skip to content

Commit

Permalink
Merge pull request #34 from WojciechMigda/GH33-data-processing-diet
Browse files Browse the repository at this point in the history
Calculate clause output and train automata only using data for the two selected labels
  • Loading branch information
WojciechMigda authored Jun 29, 2020
2 parents f7ca6e2 + 582ffab commit c48cf77
Show file tree
Hide file tree
Showing 3 changed files with 93 additions and 33 deletions.
88 changes: 66 additions & 22 deletions lib/src/tsetlini.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -208,15 +208,35 @@ void classifier_update_impl(
int clause_output_tile_size
)
{
calculate_clause_output(
X,
cache.clause_output,
number_of_clauses,
number_of_features,
ta_state,
n_jobs,
clause_output_tile_size
);
{
auto const [clause_ix_begin, clause_ix_end] = clause_range_for_label(target_label, number_of_pos_neg_clauses_per_label);

calculate_clause_output(
X,
cache.clause_output,
clause_ix_begin,
clause_ix_end,
number_of_features,
ta_state,
n_jobs,
clause_output_tile_size
);
}

{
auto const [clause_ix_begin, clause_ix_end] = clause_range_for_label(opposite_label, number_of_pos_neg_clauses_per_label);

calculate_clause_output(
X,
cache.clause_output,
clause_ix_begin,
clause_ix_end,
number_of_features,
ta_state,
n_jobs,
clause_output_tile_size
);
}

sum_up_label_votes(
cache.clause_output,
Expand Down Expand Up @@ -246,19 +266,43 @@ void classifier_update_impl(

const auto S_inv = ONE / s;

train_classifier_automata(
ta_state,
number_of_clauses,
cache.feedback_to_clauses.data(),
cache.clause_output.data(),
number_of_features,
number_of_states,
S_inv,
X.data(),
boost_true_positive_feedback,
fgen,
cache.fcache
);
{
auto const [clause_ix_begin, clause_ix_end] = clause_range_for_label(target_label, number_of_pos_neg_clauses_per_label);

train_classifier_automata(
ta_state,
clause_ix_begin,
clause_ix_end,
cache.feedback_to_clauses.data(),
cache.clause_output.data(),
number_of_features,
number_of_states,
S_inv,
X.data(),
boost_true_positive_feedback,
fgen,
cache.fcache
);
}

{
auto const [clause_ix_begin, clause_ix_end] = clause_range_for_label(opposite_label, number_of_pos_neg_clauses_per_label);

train_classifier_automata(
ta_state,
clause_ix_begin,
clause_ix_end,
cache.feedback_to_clauses.data(),
cache.clause_output.data(),
number_of_features,
number_of_states,
S_inv,
X.data(),
boost_true_positive_feedback,
fgen,
cache.fcache
);
}
}


Expand Down
36 changes: 26 additions & 10 deletions lib/src/tsetlini_algo.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,15 @@ int neg_clause_index(int target_label, int j, int number_of_pos_neg_clauses_per_
}


inline
auto clause_range_for_label(int label, int number_of_pos_neg_clauses_per_label) -> std::pair<int, int>
{
auto const begin = pos_clause_index(label, 0, number_of_pos_neg_clauses_per_label);

return std::make_pair(begin, begin + 2 * number_of_pos_neg_clauses_per_label);
}


inline
void sum_up_label_votes(
aligned_vector_char const & clause_output,
Expand Down Expand Up @@ -254,7 +263,8 @@ inline
void calculate_clause_output_T(
aligned_vector_char const & X,
aligned_vector_char & clause_output,
int const number_of_clauses,
int const clause_begin_ix,
int const clause_end_ix,
int const number_of_features,
numeric_matrix<state_type> const & ta_state,
int const n_jobs)
Expand All @@ -263,7 +273,7 @@ void calculate_clause_output_T(

if (number_of_features < (int)BATCH_SZ)
{
for (int j = 0; j < number_of_clauses; ++j)
for (int j = clause_begin_ix; j < clause_end_ix; ++j)
{
bool output = true;

Expand All @@ -284,7 +294,7 @@ void calculate_clause_output_T(
else
{
#pragma omp parallel for if (n_jobs > 1) num_threads(n_jobs)
for (int j = 0; j < number_of_clauses; ++j)
for (int j = clause_begin_ix; j < clause_end_ix; ++j)
{
char toggle_output = 0;

Expand Down Expand Up @@ -350,7 +360,8 @@ inline
void calculate_clause_output(
RowType const & X,
aligned_vector_char & clause_output,
int const number_of_clauses,
int const clause_begin_ix,
int const clause_end_ix,
int const number_of_features,
numeric_matrix<state_type> const & ta_state,
int const n_jobs,
Expand All @@ -362,7 +373,8 @@ void calculate_clause_output(
calculate_clause_output_T<state_type, 128>(
X,
clause_output,
number_of_clauses,
clause_begin_ix,
clause_end_ix,
number_of_features,
ta_state,
n_jobs
Expand All @@ -372,7 +384,8 @@ void calculate_clause_output(
calculate_clause_output_T<state_type, 64>(
X,
clause_output,
number_of_clauses,
clause_begin_ix,
clause_end_ix,
number_of_features,
ta_state,
n_jobs
Expand All @@ -382,7 +395,8 @@ void calculate_clause_output(
calculate_clause_output_T<state_type, 32>(
X,
clause_output,
number_of_clauses,
clause_begin_ix,
clause_end_ix,
number_of_features,
ta_state,
n_jobs
Expand All @@ -395,7 +409,8 @@ void calculate_clause_output(
calculate_clause_output_T<state_type, 16>(
X,
clause_output,
number_of_clauses,
clause_begin_ix,
clause_end_ix,
number_of_features,
ta_state,
n_jobs
Expand Down Expand Up @@ -542,7 +557,8 @@ void block3(
template<typename state_type>
void train_classifier_automata(
numeric_matrix<state_type> & ta_state,
int const number_of_clauses,
int const clause_begin_ix,
int const clause_end_ix,
feedback_vector_type::value_type const * __restrict feedback_to_clauses,
char const * __restrict clause_output,
int const number_of_features,
Expand All @@ -556,7 +572,7 @@ void train_classifier_automata(
{
float const * fcache_ = assume_aligned<alignment>(fcache.m_fcache.data());

for (int j = 0; j < number_of_clauses; ++j)
for (int j = clause_begin_ix; j < clause_end_ix; ++j)
{
state_type * ta_state_pos_j = ::assume_aligned<alignment>(ta_state.row_data(2 * j + 0));
state_type * ta_state_neg_j = ::assume_aligned<alignment>(ta_state.row_data(2 * j + 1));
Expand Down
2 changes: 1 addition & 1 deletion lib/tests/src/test_algo_classic.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -314,7 +314,7 @@ TEST(CalculateClauseOutput, replicates_result_of_CAIR_code)
Tsetlini::aligned_vector_char clause_output(number_of_clauses);

CAIR::calculate_clause_output(X, clause_output_CAIR, number_of_clauses, number_of_features, ta_state, false);
Tsetlini::calculate_clause_output(X, clause_output, number_of_clauses, number_of_features, ta_state, 1, 16);
Tsetlini::calculate_clause_output(X, clause_output, 0, number_of_clauses, number_of_features, ta_state, 1, 16);

if (0 != std::accumulate(clause_output_CAIR.cbegin(), clause_output_CAIR.cend(), 0u))
{
Expand Down

0 comments on commit c48cf77

Please sign in to comment.