Skip to content

Commit

Permalink
fix: fix VowpalWabbit#4669 by handling empty decision scores elements
Browse files Browse the repository at this point in the history
  • Loading branch information
jackgerrits committed Nov 27, 2023
1 parent 2849b3b commit 56d41da
Show file tree
Hide file tree
Showing 9 changed files with 194 additions and 3 deletions.
28 changes: 28 additions & 0 deletions test/core.vwtest.json
Original file line number Diff line number Diff line change
Expand Up @@ -6073,5 +6073,33 @@
"depends_on": [
467
]
},
{
"id": 469,
"desc": "https://github.com/VowpalWabbit/vowpal_wabbit/issues/4669",
"vw_command": "--ccb_explore_adf -q UA --all_slots_loss -f issue4669.model -d train-sets/issue4669.txt",
"diff_files": {
"stderr": "train-sets/ref/issue4669_train.stderr",
"stdout": "train-sets/ref/issue4669_train.stdout"
},
"input_files": [
"train-sets/issue4669.txt"
]
},
{
"id": 470,
"desc": "https://github.com/VowpalWabbit/vowpal_wabbit/issues/4669",
"vw_command": "--ccb_explore_adf -q UA --all_slots_loss -i issue4669.model -t -d train-sets/issue4669.txt",
"diff_files": {
"stderr": "train-sets/ref/issue4669_test.stderr",
"stdout": "train-sets/ref/issue4669_test.stdout"
},
"input_files": [
"train-sets/issue4669.txt",
"issue4669.model"
],
"depends_on": [
469
]
}
]
50 changes: 50 additions & 0 deletions test/train-sets/issue4669.txt
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
ccb shared |User userID='aUser'
ccb action |Action contentId='a'
ccb action |Action contentId='b'
ccb action |Action contentId='c'
ccb action |Action contentId='d'
ccb action |Action contentId='e'
ccb action |Action contentId='f'
ccb action |Action contentId='g'
ccb action |Action contentId='h'
ccb action |Action contentId='i'
ccb action |Action contentId='j'
ccb action |Action contentId='k'
ccb action |Action contentId='l'
ccb action |Action contentId='m'
ccb action |Action contentId='n'
ccb action |Action contentId='o'
ccb action |Action contentId='p'
ccb action |Action contentId='q'
ccb action |Action contentId='r'
ccb action |Action contentId='s'
ccb action |Action contentId='t'
ccb action |Action contentId='x'
ccb action |Action contentId='y'
ccb action |Action contentId='z'
ccb action |Action contentId='aa'
ccb action |Action contentId='ab'
ccb action |Action contentId='ac'
ccb action |Action contentId='ad'
ccb action |Action contentId='ae'
ccb action |Action contentId='af'
ccb action |Action contentId='ag'
ccb slot 7:0:0.2814157009124756,18:0.4986087679862976,0:0.09111946821212769,13:0.034919969737529755,17:0.022231325507164,24:0.021717887371778488,12:0.02091880328953266,22:0.008943566121160984,21:0.0035167494788765907,4:0.003486328525468707,8:0.00306233623996377,14:0.002975673880428076,25:0.002446186961606145,11:0.0013868052046746016,6:0.0007644615834578872,23:0.0007121390081010759,10:0.0006081887986510992,20:0.00046663329703733325,29:0.00024663639487698674,5:0.00020450385636650026,19:6.42634549876675e-05,15:6.378938996931538e-05,9:4.367829387774691e-05,27:2.616703204694204e-05,1:1.8396624000160955e-05,2:1.1368916602805257e-05,26:9.699509064375889e-06,16:5.553145911108004e-06,28:2.4986536573123885e-06,3:2.245711812065565e-06 7,18,0,13,17,24,12,22,21,4,8,14,25,11,6,23,10,20,29,5,19,15,9,27,1,2,26,16,28,3 |
ccb slot 28:0:0.38746780157089233,12:0.4452095031738281,17:0.11129308491945267,13:0.021261485293507576,2:0.010781776160001755,5:0.006623556837439537,14:0.006618801970034838,0:0.00422333087772131,3:0.0029810182750225067,25:0.0011069459142163396,22:0.0010891001438722014,21:0.00043250489397905767,20:0.00032753165578469634,29:0.0002478094829712063,24:0.0001928548444993794,18:3.78785262000747e-05,1:2.8776834369637072e-05,8:2.2285950763034634e-05,10:1.975510713236872e-05,19:1.1510705917316955e-05,15:1.0838572052307427e-05,4:4.2137894524785224e-06,11:2.6999800866178703e-06,16:1.461011038372817e-06,6:1.387319798595854e-06,27:1.0769275604616269e-06,23:9.243367458111607e-07,9:1.2605019605871348e-07,26:6.900691684741389e-10 28,12,17,13,2,5,14,0,3,25,22,21,20,29,24,18,1,8,10,19,15,4,11,16,6,27,23,9,26 |
ccb slot 21:0:0.009457200765609741,17:0.13830232620239258,20:0.06671705096960068,18:0.060711901634931564,13:0.04832380264997482,12:0.02896728552877903,15:0.027813201770186424,1:0.01238782238215208,5:0.5844042301177979,24:0.007916729897260666,22:0.005313423927873373,3:0.0030796348582953215,8:0.0026463039685040712,2:0.001610504579730332,4:0.0004394367279019207,14:0.0003743874840438366,11:0.0003052169340662658,9:0.0002891358162742108,10:0.0002627648937050253,0:0.0002489395847078413,25:0.00015549163799732924,29:0.00012897276610601693,23:7.323167665163055e-05,6:4.6114979340927675e-05,16:1.574967973283492e-05,27:4.545572664937936e-06,26:4.3639242903736886e-06,19:9.780344356613568e-08 21,17,20,18,13,12,15,1,5,24,22,3,8,2,4,14,11,9,10,0,25,29,23,6,16,27,26,19 |
ccb slot 25:0:0.04852084442973137,4:0.18842479586601257,0:0.11689016968011856,11:0.09394704550504684,15:0.07111208140850067,14:0.06413374841213226,1:0.2144840508699417,17:0.04394223168492317,22:0.04030191898345947,6:0.025946179404854774,20:0.02485107257962227,23:0.02306993119418621,8:0.01103772409260273,18:0.009649633429944515,24:0.0051767947152256966,3:0.004639864899218082,5:0.0036481451243162155,13:0.003012213623151183,27:0.0018589160172268748,2:0.0014349337434396148,12:0.0013534734025597572,16:0.0009505416383035481,26:0.0006534871645271778,19:0.0006136561860330403,29:0.00025192913017235696,9:4.8225138016277924e-05,10:4.6369124902412295e-05 25,4,0,11,15,14,1,17,22,6,20,23,8,18,24,3,5,13,27,2,12,16,26,19,29,9,10 |
ccb slot 22:-1:0.2308475524187088,29:0.7056121826171875,24:0.024420326575636864,13:0.017714975401759148,5:0.006036168430000544,15:0.004493136424571276,17:0.002671352354809642,14:0.0016715804813429713,18:0.001555807190015912,12:0.0014371434226632118,19:0.0010492069413885474,10:0.0006786655867472291,0:0.0003868465428240597,2:0.0003725361602846533,3:0.0003565707884263247,4:0.00022809540678281337,6:0.00016483885701745749,11:0.00013846883666701615,1:3.974044739152305e-05,16:3.603336153901182e-05,8:3.055339402635582e-05,27:1.7396419934812002e-05,20:1.5432331565534696e-05,23:1.4172852388583124e-05,9:1.0329640645068139e-05,26:8.735177061680588e-07 22,29,24,13,5,15,17,14,18,12,19,10,0,2,3,4,6,11,1,16,8,27,20,23,9,26 |
ccb slot 12:0:0.9093628525733948,0:0.04226174205541611,17:0.018896808847784996,13:0.009150444529950619,14:0.005060152616351843,24:0.0038528849836438894,11:0.0037376414984464645,1:0.0027163256891071796,20:0.0019985607359558344,5:0.0015763206174597144,6:0.0005092258215881884,18:0.0003184932575095445,2:0.0002330887655261904,8:0.00017497778753750026,23:4.659605838241987e-05,10:4.272087971912697e-05,29:2.479964132362511e-05,3:8.60487580212066e-06,27:6.880749879201176e-06,16:5.5458899623772595e-06,15:4.874308160651708e-06,9:3.708174290295574e-06,26:3.1738302368466975e-06,19:2.801355094561586e-06,4:8.652203291603655e-07 12,0,17,13,14,24,11,1,20,5,6,18,2,8,23,10,29,3,27,16,15,9,26,19,4 |
ccb slot 5:0:0.28719252347946167,17:0.23356057703495026,29:0.17160779237747192,20:0.1328369677066803,0:0.07422535866498947,23:0.050119683146476746,1:0.021847186610102654,13:0.012630262412130833,2:0.004507853649556637,8:0.0024876149836927652,3:0.0022544104140251875,9:0.0015914351679384708,18:0.0013845351058989763,10:0.001123852445743978,4:0.0008636291604489088,15:0.0004619551182258874,16:0.0004587841685861349,6:0.0004078721103724092,11:0.00021510760416276753,14:0.00014530900807585567,24:3.345161894685589e-05,26:2.9229657229734585e-05,27:1.3438097084872425e-05,19:1.2410271210683277e-06 5,17,29,20,0,23,1,13,2,8,3,9,18,10,4,15,16,6,11,14,24,26,27,19 |
ccb slot 20:0:0.028376348316669464,0:0.1662101000547409,1:0.046988535672426224,17:0.6560971140861511,3:0.02463809959590435,2:0.022943779826164246,8:0.02021339163184166,4:0.011480750516057014,6:0.006486969999969006,27:0.005383896175771952,18:0.005046996288001537,10:0.0017316940939053893,16:0.0012900123838335276,9:0.0011327711399644613,13:0.000544899667147547,23:0.000498446635901928,11:0.00048043631250038743,14:0.00018104366608895361,19:0.00012088919174857438,24:6.899452273501083e-05,29:5.0261642172699794e-05,15:3.0958537536207587e-05,26:3.547019559846376e-06 20,0,1,17,3,2,8,4,6,27,18,10,16,9,13,23,11,14,19,24,29,15,26 |
ccb slot 29:0:0.14386186003684998,4:0.18966703116893768,11:0.2729467749595642,1:0.10815022885799408,17:0.10176316648721695,24:0.048249997198581696,3:0.034515734761953354,2:0.0317024290561676,6:0.029643042013049126,14:0.01419538538902998,10:0.008138692937791348,0:0.00682934420183301,13:0.003933712374418974,9:0.0023981353733688593,8:0.0013491454301401973,23:0.000978886615484953,15:0.0006223535747267306,27:0.00042564209434203804,18:0.0003683240502141416,19:0.00018579690367914736,26:5.759065970778465e-05,16:1.6690515622030944e-05 29,4,11,1,17,24,3,2,6,14,10,0,13,9,8,23,15,27,18,19,26,16 |
ccb slot 16:0:0.8358462452888489,18:0.14180654287338257,8:0.006622648332268,0:0.0048446557484567165,13:0.003651235019788146,17:0.002547427313402295,2:0.0023207622580230236,14:0.0007652127533219755,9:0.0003962431219406426,3:0.00039094872772693634,6:0.0003010313375853002,1:0.00015582605556119233,4:0.0001078778732335195,27:8.33657686598599e-05,11:6.325534195639193e-05,26:5.2438019338296726e-05,24:2.5939621991710737e-05,10:1.2438776138878893e-05,15:5.886784038011683e-06,23:1.0229091351732222e-08,19:7.580823080388654e-09 16,18,8,0,13,17,2,14,9,3,6,1,4,27,11,26,24,10,15,23,19 |
ccb slot 17:0:0.5127939581871033,2:0.4746686518192291,19:0.006616346072405577,1:0.001669980469159782,3:0.0010488936677575111,24:0.0009213163866661489,6:0.0007942019728943706,14:0.0006952053518034518,23:0.0005310842534527183,27:0.00010622834088280797,0:5.494915967574343e-05,8:2.4259865313069895e-05,26:2.4085793484118767e-05,10:1.763248656061478e-05,13:1.2499597687565256e-05,4:9.418990885023959e-06,11:6.62406728224596e-06,15:4.250239726388827e-06,9:3.168663909036695e-07 17,2,19,1,3,24,6,14,23,27,0,8,26,10,13,4,11,15,9 |
ccb slot 2:0:0.3009277284145355,11:0.2405196875333786,0:0.12718695402145386,1:0.10248230397701263,4:0.10110407322645187,27:0.0488734170794487,14:0.031138606369495392,23:0.027005093172192574,15:0.01463699247688055,13:0.0021626290399581194,9:0.001511908951215446,24:0.0014617646811529994,26:0.000723238626960665,8:0.00015963710029609501,10:6.141114135971293e-05,3:4.065587563673034e-05,6:2.3371310362563236e-06,19:1.6080829254860873e-06 2,11,0,1,4,27,14,23,15,13,9,24,26,8,10,3,6,19 |
ccb slot 3:0:0.057185981422662735,27:0.14134737849235535,24:0.11761204898357391,11:0.09371144324541092,0:0.5287960171699524,6:0.03242357447743416,8:0.011251688934862614,14:0.007142344955354929,1:0.0070684002712368965,23:0.0012641034554690123,9:0.0005852219182997942,13:0.0005638026632368565,19:0.0004846873343922198,15:0.0002985078317578882,26:0.00022339983843266964,10:3.1262432457879186e-05,4:1.0203940291830804e-05 3,27,24,11,0,6,8,14,1,23,9,13,19,15,26,10,4 |
ccb slot 1:0:0.1249222531914711,24:0.25583717226982117,14:0.12867441773414612,23:0.26537543535232544,13:0.08874479681253433,15:0.04395920783281326,11:0.04107809066772461,26:0.026856984943151474,0:0.01694386824965477,6:0.005332316737622023,9:0.0009168571559712291,8:0.0005701580666936934,10:0.00035485764965415,4:0.00030593250994570553,27:9.87778403214179e-05,19:2.8945323720108718e-05 1,24,14,23,13,15,11,26,0,6,9,8,10,4,27,19 |
ccb slot 0:0:0.7005100846290588,15:0.15935854613780975,26:0.10210220515727997,6:0.01650484837591648,27:0.012556466273963451,14:0.003727053524926305,10:0.0025744284503161907,23:0.0013873651623725891,11:0.0004973539034835994,13:0.00025413508410565555,24:0.0002187014470109716,8:0.00016339073772542179,19:8.715951116755605e-05,4:4.481669020606205e-05,9:1.3443200259644073e-05 0,15,26,6,27,14,10,23,11,13,24,8,19,4,9 |
ccb slot 8:0:0.9874762892723083,6:0.00831417366862297,24:0.0015088269719853997,19:0.0014199139550328255,13:0.0005207555368542671,4:0.0003495477430988103,23:0.00022233030176721513,11:0.0001492148294346407,10:1.3149796359357424e-05,27:9.886987754725851e-06,15:9.333793059340678e-06,9:2.8232302611286286e-06,14:2.0402019345056033e-06,26:1.7364958466714597e-06 8,6,24,19,13,4,23,11,10,27,15,9,14,26 |
ccb slot 14:0:0.12711836397647858,9:0.23751074075698853,10:0.3793737292289734,27:0.08924281597137451,24:0.07040458917617798,15:0.06249556690454483,11:0.02001812309026718,4:0.009080302901566029,26:0.0017807148396968842,23:0.0015691084554418921,19:0.0009121194598264992,13:0.0004805738863069564,6:1.3344042599783279e-05 14,9,10,27,24,15,11,4,26,23,19,13,6 |
ccb slot 24:0:0.9453869462013245,19:0.029324114322662354,23:0.017052331939339638,13:0.004684425424784422,11:0.0012589030666276813,10:0.0009999927133321762,26:0.0009100232855416834,15:0.00019236477965023369,6:0.00014357283362187445,27:4.12082408729475e-05,9:3.6369422105053673e-06,4:2.426173068670323e-06 24,19,23,13,11,10,26,15,6,27,9,4 |
ccb slot 19:0:0.07647673040628433,13:0.26458504796028137,27:0.08122498542070389,26:0.548874020576477,11:0.025309467688202858,15:0.0020130074117332697,6:0.00121505802962929,23:0.00012354919454082847,9:9.067041537491605e-05,4:4.7242716391338035e-05,10:4.006700328318402e-05 19,13,27,26,11,15,6,23,9,4,10 |
23 changes: 23 additions & 0 deletions test/train-sets/ref/issue4669_test.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
creating quadratic features for pairs: UA
only testing
using no cache
Reading datafile = train-sets/issue4669.txt
num sources = 1
Num weight bits = 18
learning rate = 0.5
initial_t = 1
power_t = 0.5
cb_type = mtr
Enabled learners: gd, generate_interactions, scorer-identity, csoaa_ldf-rank, cb_adf, cb_explore_adf_greedy, cb_sample, shared_feature_merger, ccb_explore_adf
Input label = CCB
Output pred = DECISION_PROBS
average since example example current current current
loss last counter weight label predict features
0.000000 0.000000 1 1.0 7:0,28:0,... 22,4,6,25,2... 2059

finished run
number of examples = 1
weighted example sum = 1.000000
weighted label sum = 0.000000
average loss = 0.000000
total feature number = 2059
1 change: 1 addition & 0 deletions test/train-sets/ref/issue4669_test.stdout
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
[warning] model file has set of {-q, --cubic, --interactions} settings stored, but they'll be OVERRIDDEN by set of {-q, --cubic, --interactions} settings from command line.
23 changes: 23 additions & 0 deletions test/train-sets/ref/issue4669_train.stderr
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
creating quadratic features for pairs: UA
final_regressor = issue4669.model
using no cache
Reading datafile = train-sets/issue4669.txt
num sources = 1
Num weight bits = 18
learning rate = 0.5
initial_t = 0
power_t = 0.5
cb_type = mtr
Enabled learners: gd, generate_interactions, scorer-identity, csoaa_ldf-rank, cb_adf, cb_explore_adf_greedy, cb_sample, shared_feature_merger, ccb_explore_adf
Input label = CCB
Output pred = DECISION_PROBS
average since example example current current current
loss last counter weight label predict features
-0.16661 -0.16661 1 1.0 7:0,28:0,... 7,28,21,25,... 3120

finished run
number of examples = 1
weighted example sum = 1.000000
weighted label sum = 0.000000
average loss = -0.166610
total feature number = 3120
Empty file.
3 changes: 2 additions & 1 deletion vowpalwabbit/core/src/decision_scores.cc
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,8 @@ void print_update(VW::workspace& all, const VW::multi_ex& slots, const VW::decis
std::string delim;
for (const auto& slot : decision_scores)
{
pred_ss << delim << slot[0].action;
if (slot.empty()) { pred_ss << delim << "None"; }
else { pred_ss << delim << slot[0].action; }
delim = ",";
}
all.sd->print_update(*all.output_runtime.trace_message, all.passes_config.holdout_set_off,
Expand Down
13 changes: 11 additions & 2 deletions vowpalwabbit/core/src/reductions/conditional_contextual_bandit.cc
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include "vw/core/reductions/conditional_contextual_bandit.h"

#include "vw/config/options.h"
#include "vw/core/cb.h"
#include "vw/core/ccb_label.h"
#include "vw/core/ccb_reduction_features.h"
#include "vw/core/constant.h"
Expand Down Expand Up @@ -213,8 +214,12 @@ void clear_pred_and_label(ccb_data& data)
data.actions[data.action_with_label]->l.cb.costs.clear();
}

// true if there exists at least 1 action in the cb multi-example
bool has_action(VW::multi_ex& cb_ex) { return !cb_ex.empty(); }
// true if there exists at least 2 examples (since there can only be up to 1
// shared example), or the 0th example is not shared.
bool has_action(VW::multi_ex& cb_ex)
{
return cb_ex.size() > 1 || (!cb_ex.empty() && !VW::ec_is_example_header_cb(*cb_ex[0]));
}

// This function intentionally does not handle increasing the num_features of the example because
// the output_example function has special logic to ensure the number of features is correctly calculated.
Expand Down Expand Up @@ -547,6 +552,10 @@ void update_stats_ccb(const VW::workspace& /* all */, shared_data& sd, const ccb
num_labeled++;
if (i == 0 || data.all_slots_loss_report)
{
// It is possible for the prediction to be empty if there were no actions available at the time of taking the
// slot decision. In this case it does not contribute to loss.
if (preds[i].empty()) { continue; }

const float l = VW::get_cost_estimate(outcome->probabilities[VW::details::TOP_ACTION_INDEX], outcome->cost,
preds[i][VW::details::TOP_ACTION_INDEX].action);
loss += l * preds[i][VW::details::TOP_ACTION_INDEX].score * ec_seq[VW::details::SHARED_EX_INDEX]->weight;
Expand Down
56 changes: 56 additions & 0 deletions vowpalwabbit/core/tests/ccb_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -145,3 +145,59 @@ TEST(Ccb, InsertInteractionsImplTest)

EXPECT_THAT(result, testing::ContainerEq(expected_after));
}

TEST(Ccb, ExplicitIncludedActionsNonExistentAction)
{
auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--quiet"));
VW::multi_ex examples;
examples.push_back(VW::read_example(*vw, "ccb shared |"));
examples.push_back(VW::read_example(*vw, "ccb action |"));
examples.push_back(VW::read_example(*vw, "ccb slot 0:10:10 10 |"));

vw->learn(examples);

auto& decision_scores = examples[0]->pred.decision_scores;
EXPECT_EQ(decision_scores.size(), 1);
EXPECT_EQ(decision_scores[0].size(), 0);
vw->finish_example(examples);
}

TEST(Ccb, NoAvailableActions)
{
auto vw = VW::initialize(vwtest::make_args("--ccb_explore_adf", "--quiet", "--all_slots_loss"));
{
VW::multi_ex examples;
examples.push_back(VW::read_example(*vw, "ccb shared |"));
examples.push_back(VW::read_example(*vw, "ccb action | a"));
examples.push_back(VW::read_example(*vw, "ccb action | b"));
examples.push_back(VW::read_example(*vw, "ccb slot 0:-1:0.5 0,1 |"));
examples.push_back(VW::read_example(*vw, "ccb slot |"));

vw->learn(examples);

auto& decision_scores = examples[0]->pred.decision_scores;
EXPECT_EQ(decision_scores.size(), 2);
vw->finish_example(examples);
}

{
VW::multi_ex examples;
examples.push_back(VW::read_example(*vw, "ccb shared |"));
examples.push_back(VW::read_example(*vw, "ccb action | a"));
examples.push_back(VW::read_example(*vw, "ccb action | b"));
examples.push_back(VW::read_example(*vw, "ccb slot 0:-1:0.5 0,1 |"));
// This time restrict slot 1 to only have action 0 available
examples.push_back(VW::read_example(*vw, "ccb slot 0:-1:0.5 0 |"));

vw->predict(examples);

auto& decision_scores = examples[0]->pred.decision_scores;
EXPECT_EQ(decision_scores.size(), 2);
EXPECT_EQ(decision_scores[0].size(), 2);
EXPECT_EQ(decision_scores[0][0].action, 0);
EXPECT_EQ(decision_scores[0][1].action, 1);
EXPECT_EQ(decision_scores[1].size(), 0);

vw->finish_example(examples);
}
}

0 comments on commit 56d41da

Please sign in to comment.