From 69d75533b56fc4338354c7e2e8e5c4d6faabeddb Mon Sep 17 00:00:00 2001 From: Ilya Matiach Date: Thu, 10 Nov 2022 13:27:10 -0500 Subject: [PATCH] add multilabel support to RAI dashboard and multilabel text classification covid events dataset --- .../__mock_data__/covidevents.ts | 684 ++++++++++++++++++ libs/core-ui/src/lib/Interfaces/IDataset.ts | 9 +- .../lib/Interfaces/IModelExplanationData.ts | 2 +- .../Interfaces/IVisionModelExplanationData.ts | 6 +- .../Interfaces/VisionExplanationInterfaces.ts | 4 +- libs/core-ui/src/lib/util/DatasetUtils.ts | 47 +- libs/core-ui/src/lib/util/JointDataset.ts | 133 ++-- .../core-ui/src/lib/util/JointDatasetUtils.ts | 5 +- .../src/util/getOriginalData.ts | 5 +- .../Controls/DataCharacteristics.tsx | 6 +- .../Controls/DataCharacteristicsRow.tsx | 4 +- .../Controls/Flyout.tsx | 12 +- .../Controls/ImageList.tsx | 4 +- .../VisionExplanationDashboardHelper.ts | 21 +- .../utils/getFilteredData.ts | 23 +- .../Context/buildModelAssessmentContext.ts | 1 + 16 files changed, 886 insertions(+), 80 deletions(-) create mode 100644 apps/dashboard/src/model-assessment-text/__mock_data__/covidevents.ts diff --git a/apps/dashboard/src/model-assessment-text/__mock_data__/covidevents.ts b/apps/dashboard/src/model-assessment-text/__mock_data__/covidevents.ts new file mode 100644 index 0000000000..f1fc8f7e7e --- /dev/null +++ b/apps/dashboard/src/model-assessment-text/__mock_data__/covidevents.ts @@ -0,0 +1,684 @@ +// Copyright (c) Microsoft Corporation. +// Licensed under the MIT License. + +import { DatasetTaskType, IDataset } from "@responsible-ai/core-ui"; + +export const covid19events: IDataset = { + categorical_features: [], + class_names: [], + feature_names: [ + "positive_words", + "negative_words", + "negation_words", + "negated_entities", + "named_persons", + "sentence_length" + ], + features: [ + [9, 0, 0, 0, 0, 122], + [3, 0, 0, 0, 0, 64], + [7, 0, 0, 0, 0, 206], + [16, 2, 0, 0, 0, 296], + [16, 0, 0, 0, 0, 350], + [14, 5, 0, 0, 0, 338], + [17, 1, 0, 0, 0, 324], + [26, 4, 0, 0, 0, 524], + [4, 1, 0, 0, 0, 120], + [5, 1, 1, 0, 0, 93], + [11, 2, 0, 0, 0, 231], + [8, 2, 0, 0, 0, 156], + [16, 0, 0, 0, 0, 309], + [6, 1, 1, 0, 0, 167], + [15, 1, 0, 1, 2, 258], + [4, 2, 0, 0, 1, 122], + [7, 2, 0, 0, 0, 158], + [6, 0, 0, 0, 0, 130], + [12, 0, 0, 0, 0, 149], + [24, 2, 0, 0, 0, 385], + [13, 0, 0, 0, 2, 288], + [13, 0, 0, 0, 0, 150], + [24, 7, 1, 0, 0, 378], + [7, 0, 0, 0, 0, 144], + [28, 5, 1, 0, 0, 417], + [11, 2, 1, 0, 0, 287], + [8, 0, 0, 0, 0, 132], + [21, 1, 0, 0, 1, 300], + [23, 2, 0, 0, 0, 282], + [18, 1, 1, 0, 0, 286], + [13, 2, 1, 1, 0, 272], + [11, 0, 0, 0, 0, 145], + [15, 3, 0, 0, 1, 275], + [9, 0, 0, 0, 1, 193], + [4, 0, 0, 0, 0, 106], + [5, 2, 2, 0, 0, 70], + [12, 2, 0, 0, 0, 189], + [5, 1, 0, 0, 0, 317], + [2, 0, 0, 0, 0, 81], + [17, 0, 0, 0, 0, 304], + [5, 1, 0, 0, 0, 142], + [9, 0, 0, 0, 0, 210], + [22, 3, 2, 2, 0, 565], + [8, 2, 0, 0, 0, 237], + [10, 0, 0, 0, 0, 137], + [7, 0, 0, 0, 0, 178], + [11, 0, 0, 0, 0, 226], + [17, 0, 0, 0, 0, 324], + [8, 0, 0, 0, 0, 150], + [5, 0, 0, 0, 0, 124], + [3, 0, 0, 0, 1, 110], + [9, 1, 0, 0, 0, 197], + [15, 0, 0, 0, 1, 246], + [17, 1, 0, 0, 0, 205], + [30, 9, 1, 1, 0, 701], + [9, 3, 0, 1, 0, 230], + [11, 3, 2, 1, 1, 253], + [30, 1, 0, 0, 0, 465], + [18, 2, 0, 0, 1, 382], + [3, 0, 0, 0, 0, 89], + [0, 0, 0, 0, 0, 43], + [2, 0, 0, 0, 0, 135], + [23, 1, 0, 0, 0, 274], + [32, 1, 1, 1, 0, 342], + [12, 1, 0, 0, 0, 222], + [16, 0, 0, 0, 2, 249], + [16, 1, 2, 0, 0, 274], + [5, 0, 0, 0, 0, 105], + [13, 0, 0, 0, 0, 394], + [12, 0, 0, 0, 0, 341], + [5, 0, 0, 0, 0, 179], + [7, 0, 0, 0, 0, 255], + [7, 0, 0, 0, 0, 208], + [2, 1, 1, 2, 0, 101], + [16, 0, 0, 0, 0, 277], + [2, 1, 1, 1, 0, 103], + [10, 3, 0, 0, 0, 281], + [11, 1, 0, 0, 0, 215], + [15, 0, 0, 0, 0, 260], + [13, 1, 0, 0, 0, 308], + [8, 0, 0, 0, 0, 175] + ], + predicted_y: [ + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1], + [0, 0, 0, 1, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 1], + [1, 0, 0, 0, 0, 0, 0, 0] + ], + probability_y: [ + [ + 0.004053614102303982, 0.002606572350487113, 0.03770056739449501, + 0.01188106182962656, 0.003459201194345951, 0.0015037631383165717, + 0.001331277540884912, 0.21803493797779083 + ], + [ + 0.01080469973385334, 0.06398747861385345, 0.9224351644515991, + 0.037625785917043686, 0.06488275527954102, 0.02870604768395424, + 0.021452177315950394, 0.9044150710105896 + ], + [ + 0.016751259565353394, 0.012714984826743603, 0.030967459082603455, + 0.9194455146789551, 0.16859136521816254, 0.03698199242353439, + 0.034791648387908936, 0.9330119490623474 + ], + [ + 0.019421033561229706, 0.003612867556512356, 0.008314190432429314, + 0.004650872200727463, 0.0021483059972524643, 0.001040966366417706, + 0.000974209513515234, 0.059485021978616714 + ], + [ + 0.0030330675654113293, 0.0029852748848497868, 0.026549631729722023, + 0.02181229181587696, 0.005618046037852764, 0.0016587598947808146, + 0.0014594105305150151, 0.2606494724750519 + ], + [ + 0.006474900059401989, 0.0033059900160878897, 0.009029039181768894, + 0.013276848942041397, 0.002718358300626278, 0.0010431399568915367, + 0.001009888481348753, 0.060823600739240646 + ], + [ + 0.00636912090703845, 0.0045023574493825436, 0.013715151697397232, + 0.035727258771657944, 0.0018532754620537162, 0.0008890776662155986, + 0.0010821651667356491, 0.013800323009490967 + ], + [ + 0.03933095559477806, 0.1984478086233139, 0.8884686231613159, + 0.007201211992651224, 0.007626084610819817, 0.00954109150916338, + 0.009848043322563171, 0.013710269704461098 + ], + [ + 0.09486576914787292, 0.8029305934906006, 0.2078256905078888, + 0.037898894399404526, 0.01797395944595337, 0.028242554515600204, + 0.01896023564040661, 0.004749095067381859 + ], + [ + 0.01595112681388855, 0.007182489614933729, 0.009365618228912354, + 0.012222557328641415, 0.0016373330727219582, 0.0009629192645661533, + 0.0009885681793093681, 0.008730828762054443 + ], + [ + 0.8572545647621155, 0.824332594871521, 0.14179246127605438, + 0.0200976375490427, 0.012936611659824848, 0.02951071783900261, + 0.022895444184541702, 0.004638078156858683 + ], + [ + 0.9630569219589233, 0.4214213490486145, 0.048789650201797485, + 0.012131775729358196, 0.007620903663337231, 0.015260433778166771, + 0.013192555867135525, 0.004479472525417805 + ], + [ + 0.6588174700737, 0.89195716381073, 0.31339535117149353, + 0.02704426646232605, 0.018460256978869438, 0.04036659374833107, + 0.031334638595581055, 0.004008607938885689 + ], + [ + 0.07224246114492416, 0.7733100652694702, 0.29067757725715637, + 0.04611900448799133, 0.01312231458723545, 0.018750926479697227, + 0.013437841087579727, 0.003730364143848419 + ], + [ + 0.004964345600455999, 0.015528012067079544, 0.5673052668571472, + 0.014631686732172966, 0.004219789523631334, 0.0019882586784660816, + 0.0022909576073288918, 0.012187712825834751 + ], + [ + 0.014854734763503075, 0.002776058856397867, 0.013288215734064579, + 0.008608302101492882, 0.004614434204995632, 0.0025636397767812014, + 0.0020366390235722065, 0.668488085269928 + ], + [ + 0.005452004261314869, 0.05184265226125717, 0.20518416166305542, + 0.10911096632480621, 0.002271457342430949, 0.002440890297293663, + 0.0029052274767309427, 0.003225722350180149 + ], + [ + 0.00816359743475914, 0.0037822388112545013, 0.03452426195144653, + 0.006227799691259861, 0.006432953290641308, 0.002913115546107292, + 0.0021951044909656048, 0.5802102088928223 + ], + [ + 0.8807785511016846, 0.010355561971664429, 0.003747768234461546, + 0.038442689925432205, 0.0026625001337379217, 0.002585115609690547, + 0.0032945089042186737, 0.014006773941218853 + ], + [ + 0.005408243741840124, 0.0028393203392624855, 0.018540285527706146, + 0.011183302849531174, 0.0018795254873111844, 0.0009552990668453276, + 0.0009099083254113793, 0.03446265682578087 + ], + [ + 0.021801862865686417, 0.004676208831369877, 0.012078095227479935, + 0.01185069140046835, 0.016176607459783554, 0.007096666842699051, + 0.005083728116005659, 0.9373388290405273 + ], + [ + 0.0072930739261209965, 0.002496061846613884, 0.008354422636330128, + 0.008719664998352528, 0.0032808054238557816, 0.001497624907642603, + 0.0012436261167749763, 0.2242753654718399 + ], + [ + 0.0045114136300981045, 0.0024979072622954845, 0.02105729840695858, + 0.00813086237758398, 0.0037036617286503315, 0.0016966235125437379, + 0.0013959328643977642, 0.21806898713111877 + ], + [ + 0.010847983881831169, 0.008680322207510471, 0.022197311744093895, + 0.8772158622741699, 0.006153744645416737, 0.004052757751196623, + 0.004840007517486811, 0.014863566495478153 + ], + [ + 0.006133050192147493, 0.0031470158137381077, 0.008532358333468437, + 0.02629619464278221, 0.001844844315201044, 0.000900165643543005, + 0.0010252236388623714, 0.030216652899980545 + ], + [ + 0.015955885872244835, 0.008804835379123688, 0.008258480578660965, + 0.45750343799591064, 0.002109039342030883, 0.0015894134994596243, + 0.002052376978099346, 0.01139446534216404 + ], + [ + 0.005888455547392368, 0.0037404082249850035, 0.021594777703285217, + 0.004686504602432251, 0.002653959207236767, 0.001325606950558722, + 0.0011151438811793923, 0.06581959873437881 + ], + [ + 0.03390899673104286, 0.013195629231631756, 0.012328003533184528, + 0.48629847168922424, 0.0019344737520441413, 0.0017799859633669257, + 0.0023830749560147524, 0.0038987318985164165 + ], + [ + 0.004476228728890419, 0.0028032732661813498, 0.02048170007765293, + 0.014296414330601692, 0.0020108588505536318, 0.0009604321094229817, + 0.0009910142980515957, 0.028767826035618782 + ], + [ + 0.008351492695510387, 0.0027173233684152365, 0.006832093931734562, + 0.010775466449558735, 0.002076599281281233, 0.0009824051521718502, + 0.000980760552920401, 0.04596298933029175 + ], + [ + 0.0034246055874973536, 0.006964144762605429, 0.06431429088115692, + 0.16197563707828522, 0.002055288292467594, 0.0014627790078520775, + 0.0017995978705585003, 0.006710992194712162 + ], + [ + 0.008549017831683159, 0.0023692194372415543, 0.011592290364205837, + 0.01252521201968193, 0.005836385302245617, 0.0023176344111561775, + 0.0019478322938084602, 0.574681282043457 + ], + [ + 0.0037562220823019743, 0.004723121412098408, 0.28035104274749756, + 0.012803664430975914, 0.002808540826663375, 0.0016030226834118366, + 0.0016077299369499087, 0.026799745857715607 + ], + [ + 0.010778078809380531, 0.0033077062107622623, 0.010019906796514988, + 0.007183510344475508, 0.0016579522052779794, 0.0009568877285346389, + 0.0009443693561479449, 0.026288291439414024 + ], + [ + 0.04733191803097725, 0.018970994278788567, 0.01071301568299532, + 0.9060934782028198, 0.005917802918702364, 0.005035545211285353, + 0.006228873040527105, 0.008966443128883839 + ], + [ + 0.002796222921460867, 0.012405790388584137, 0.5131822824478149, + 0.07165513932704926, 0.0037333376239985228, 0.0023874433245509863, + 0.002634379081428051, 0.009200328961014748 + ], + [ + 0.008300071582198143, 0.003618015442043543, 0.012185129337012768, + 0.004840482491999865, 0.0025029657408595085, 0.0012542438926175237, + 0.0010433616116642952, 0.06829141825437546 + ], + [ + 0.011342223733663559, 0.004387284629046917, 0.007927621714770794, + 0.009962920099496841, 0.001706104027107358, 0.0008959750412032008, + 0.0009249066351912916, 0.017002930864691734 + ], + [ + 0.01529914140701294, 0.003765893168747425, 0.007152734324336052, + 0.007018750999122858, 0.0020147659815847874, 0.000946531246881932, + 0.0009364398429170251, 0.027150370180606842 + ], + [ + 0.9773837327957153, 0.015268188901245594, 0.009549940936267376, + 0.009744508191943169, 0.00498786149546504, 0.005123118869960308, + 0.006575216539204121, 0.020418094471096992 + ], + [ + 0.9767888784408569, 0.019299056380987167, 0.009953146800398827, + 0.008699998259544373, 0.004664226435124874, 0.004985435865819454, + 0.006263888906687498, 0.015779754146933556 + ], + [ + 0.9763739109039307, 0.019300034269690514, 0.010850567370653152, + 0.007862312719225883, 0.004617104772478342, 0.005025084596127272, + 0.006361216306686401, 0.01565021462738514 + ], + [ + 0.09475114941596985, 0.004144375678151846, 0.003442079294472933, + 0.008787229657173157, 0.001502313301898539, 0.0009589873952791095, + 0.0011371158761903644, 0.020854368805885315 + ], + [ + 0.03847936913371086, 0.004373945761471987, 0.004726479761302471, + 0.006207616534084082, 0.0016334792599081993, 0.0009732409962452948, + 0.0010101464577019215, 0.02155589684844017 + ], + [ + 0.9824935793876648, 0.017293009907007217, 0.009633757174015045, + 0.010956677608191967, 0.005564797669649124, 0.005668932106345892, + 0.007087952923029661, 0.025238078087568283 + ], + [ + 0.9751385450363159, 0.01883637346327305, 0.009107490070164204, + 0.00974255707114935, 0.0045736124739050865, 0.004684543237090111, + 0.005823461338877678, 0.015655402094125748 + ], + [ + 0.972524106502533, 0.01674414612352848, 0.009109711274504662, + 0.008387683890759945, 0.004344087094068527, 0.0044397697784006596, + 0.005560746416449547, 0.017042065039277077 + ], + [ + 0.9005312323570251, 0.0060760825872421265, 0.004219627007842064, + 0.007590696215629578, 0.004334719385951757, 0.0033609310630708933, + 0.0037694047205150127, 0.06674294173717499 + ], + [ + 0.20959758758544922, 0.009987308643758297, 0.006036617327481508, + 0.9178768992424011, 0.01012398675084114, 0.00710502453148365, + 0.008404775522649288, 0.05668610706925392 + ], + [ + 0.11886875331401825, 0.010674898512661457, 0.008322533220052719, + 0.9459357857704163, 0.01293387170881033, 0.008369955234229565, + 0.010107753798365593, 0.0824631080031395 + ], + [ + 0.023857099935412407, 0.005824489053338766, 0.006021629553288221, + 0.007940581999719143, 0.0015797278610989451, 0.0009298609802499413, + 0.0009641240467317402, 0.011920041404664516 + ], + [ + 0.022636914625763893, 0.06863640248775482, 0.03288957476615906, + 0.9349145889282227, 0.007641573436558247, 0.007351301610469818, + 0.009204176254570484, 0.006709030829370022 + ], + [ + 0.004267312120646238, 0.0023081921972334385, 0.010186562314629555, + 0.10128005594015121, 0.004626833833754063, 0.0020957940723747015, + 0.0018772734329104424, 0.40996336936950684 + ], + [ + 0.8485772013664246, 0.005407100543379784, 0.003526245476678014, + 0.026944028213620186, 0.005907413549721241, 0.004630755167454481, + 0.006144389975816011, 0.39433664083480835 + ], + [ + 0.008236004039645195, 0.015242429450154305, 0.0284435898065567, + 0.9290887117385864, 0.007697242312133312, 0.004852084908634424, + 0.005398884881287813, 0.015870744362473488 + ], + [ + 0.9775904417037964, 0.8073777556419373, 0.0689464583992958, + 0.04831147938966751, 0.021140847355127335, 0.044133368879556656, + 0.035297691822052, 0.006722880993038416 + ], + [ + 0.006988622713834047, 0.003685500705614686, 0.028757771477103233, + 0.008969159796833992, 0.0015175525331869721, 0.0008602043963037431, + 0.0009046993218362331, 0.01761375181376934 + ], + [ + 0.007170221768319607, 0.0028206217102706432, 0.008279886096715927, + 0.00733738299459219, 0.0029358824249356985, 0.0013125627301633358, + 0.0011354121379554272, 0.12913256883621216 + ], + [ + 0.05700590834021568, 0.002624162705615163, 0.00503351679071784, + 0.004811878316104412, 0.001987897092476487, 0.0010967416455969214, + 0.00103966414462775, 0.05905309319496155 + ], + [ + 0.008587258867919445, 0.004960223101079464, 0.011483282782137394, + 0.007672236766666174, 0.0021473723463714123, 0.0010993445757776499, + 0.0010156948119401932, 0.022881396114826202 + ], + [ + 0.0048813181929290295, 0.004922970663756132, 0.02174382470548153, + 0.012835115194320679, 0.0018289090367034078, 0.0009726088028401136, + 0.0009479201398789883, 0.014979247935116291 + ], + [ + 0.010713278315961361, 0.004812181461602449, 0.01098900567740202, + 0.0074848695658147335, 0.0017532043857499957, 0.0008975921082310379, + 0.0009350860491394997, 0.01609320007264614 + ], + [ + 0.0189472958445549, 0.0022920204792171717, 0.0052499533630907536, + 0.012288747355341911, 0.004157240968197584, 0.0018124673515558243, + 0.0017559187253937125, 0.45695117115974426 + ], + [ + 0.005030542612075806, 0.008470925502479076, 0.7023911476135254, + 0.009625325910747051, 0.0053747063502669334, 0.0033622460905462503, + 0.0032059194054454565, 0.08403091132640839 + ], + [ + 0.01117105595767498, 0.002374058123677969, 0.007107346318662167, + 0.007442558649927378, 0.0027034305967390537, 0.001224913401529193, + 0.0010362836765125394, 0.12144136428833008 + ], + [ + 0.011302579194307327, 0.007459386717528105, 0.01954708807170391, + 0.006509932689368725, 0.001633140491321683, 0.0009798504179343581, + 0.0010501056676730514, 0.008953449316322803 + ], + [ + 0.010446273721754551, 0.005010900087654591, 0.011206495575606823, + 0.005570677574723959, 0.001929469290189445, 0.00108762935269624, + 0.0010165354469791055, 0.018636176362633705 + ], + [ + 0.0333007276058197, 0.005092688836157322, 0.00455334410071373, + 0.010138079524040222, 0.0016301381401717663, 0.0009418704430572689, + 0.0010034292936325073, 0.013947085477411747 + ], + [ + 0.9821280241012573, 0.017449883744120598, 0.00992298312485218, + 0.008777274750173092, 0.0055213021114468575, 0.005468924529850483, + 0.007189790718257427, 0.03100375458598137 + ], + [ + 0.08579986542463303, 0.45395219326019287, 0.04066558554768562, + 0.01531267911195755, 0.005848083645105362, 0.007405467331409454, + 0.006066540721803904, 0.004517185967415571 + ], + [ + 0.13787204027175903, 0.8431532979011536, 0.1284618079662323, + 0.055224061012268066, 0.01815311796963215, 0.030115319415926933, + 0.020605934783816338, 0.006293881684541702 + ], + [ + 0.009955096989870071, 0.00963742844760418, 0.009823950007557869, + 0.01756494678556919, 0.0015651412541046739, 0.000988426967523992, + 0.0010414187563583255, 0.00985727272927761 + ], + [ + 0.008801661431789398, 0.012020917609333992, 0.013056713156402111, + 0.13104365766048431, 0.0016467644600197673, 0.001227130414918065, + 0.0014970960328355432, 0.0053172544576227665 + ], + [ + 0.003562038531526923, 0.010120008140802383, 0.04526464268565178, + 0.024078993126749992, 0.0020359810441732407, 0.001278961542993784, + 0.001246760250069201, 0.008677229285240173 + ], + [ + 0.13742423057556152, 0.7908906936645508, 0.12005621194839478, + 0.041446276009082794, 0.01247428823262453, 0.019188880920410156, + 0.014208252541720867, 0.004945479333400726 + ], + [ + 0.003901640186086297, 0.009056136943399906, 0.04611537605524063, + 0.018624842166900635, 0.001964427763596177, 0.0012003777083009481, + 0.0011739704059436917, 0.008579351007938385 + ], + [ + 0.9783485531806946, 0.014377448707818985, 0.006595549173653126, + 0.016069328412413597, 0.0049976990558207035, 0.004660132806748152, + 0.006281949579715729, 0.03527028486132622 + ], + [ + 0.264612078666687, 0.004771120380610228, 0.008079776540398598, + 0.012648681178689003, 0.012639557011425495, 0.007618315052241087, + 0.006590433418750763, 0.9206035733222961 + ], + [ + 0.08223392814397812, 0.004600840155035257, 0.00430525466799736, + 0.008834690786898136, 0.0014234376139938831, 0.0008598905988037586, + 0.0010183859849348664, 0.012251179665327072 + ], + [ + 0.013059987686574459, 0.008842402137815952, 0.030529087409377098, + 0.9055529832839966, 0.05444278568029404, 0.020777687430381775, + 0.019403981044888496, 0.8245904445648193 + ], + [ + 0.9881116151809692, 0.033392533659935, 0.01161368377506733, + 0.012724206782877445, 0.0062301279976964, 0.006983584724366665, + 0.008607178926467896, 0.02086910419166088 + ] + ], + target_column: [ + "event1", + "event2", + "event3", + "event4", + "event5", + "event6", + "event7", + "event8" + ], + task_type: DatasetTaskType.TextClassification, + true_y: [ + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 1, 0, 0, 0, 0, 1], + [0, 0, 0, 1, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0], + [1, 1, 1, 0, 0, 0, 0, 0], + [0, 1, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [1, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 1, 0, 0, 0, 0], + [1, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 1, 0, 0, 0, 0, 0, 0], + [0, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 0], + [1, 0, 0, 0, 0, 0, 0, 1], + [0, 0, 0, 0, 0, 0, 0, 0], + [0, 0, 1, 0, 0, 0, 0, 1], + [1, 0, 0, 0, 0, 0, 0, 0] + ] +}; diff --git a/libs/core-ui/src/lib/Interfaces/IDataset.ts b/libs/core-ui/src/lib/Interfaces/IDataset.ts index 04d75674ed..eaddc4c56c 100644 --- a/libs/core-ui/src/lib/Interfaces/IDataset.ts +++ b/libs/core-ui/src/lib/Interfaces/IDataset.ts @@ -8,13 +8,14 @@ export enum DatasetTaskType { Regression = "regression", Classification = "classification", ImageClassification = "image_classification", - TextClassification = "text_classification" + TextClassification = "text_classification", + MultilabelTextClassification = "multilabel_text_classification" } export interface IDataset { task_type: DatasetTaskType; - true_y: number[]; - predicted_y?: number[]; + true_y: number[] | number[][]; + predicted_y?: number[] | number[][]; probability_y?: number[][]; features: unknown[][]; feature_names: string[]; @@ -22,7 +23,7 @@ export interface IDataset { is_large_data_scenario?: boolean; use_entire_test_data?: boolean; class_names?: string[]; - target_column?: string; + target_column?: string | string[]; data_balance_measures?: IDataBalanceMeasures; feature_metadata?: IFeatureMetaData; images?: string[]; diff --git a/libs/core-ui/src/lib/Interfaces/IModelExplanationData.ts b/libs/core-ui/src/lib/Interfaces/IModelExplanationData.ts index 276317ce3c..dc6021f810 100644 --- a/libs/core-ui/src/lib/Interfaces/IModelExplanationData.ts +++ b/libs/core-ui/src/lib/Interfaces/IModelExplanationData.ts @@ -6,7 +6,7 @@ import { IPrecomputedExplanations } from "./ExplanationInterfaces"; export interface IModelExplanationData { modelClass?: ModelClass; method?: Method; - predictedY?: number[]; + predictedY?: number[] | number[][]; probabilityY?: number[][]; explanationMethod?: string; precomputedExplanations?: IPrecomputedExplanations; diff --git a/libs/core-ui/src/lib/Interfaces/IVisionModelExplanationData.ts b/libs/core-ui/src/lib/Interfaces/IVisionModelExplanationData.ts index a227f387b6..ccae251ec0 100644 --- a/libs/core-ui/src/lib/Interfaces/IVisionModelExplanationData.ts +++ b/libs/core-ui/src/lib/Interfaces/IVisionModelExplanationData.ts @@ -2,9 +2,9 @@ // Licensed under the MIT License. export interface IVisionListItem { - [key: string]: string | number | boolean; + [key: string]: string | number | boolean | string[]; image: string; - predictedY: string; - trueY: string; + predictedY: string | string[]; + trueY: string | string[]; index: number; } diff --git a/libs/core-ui/src/lib/Interfaces/VisionExplanationInterfaces.ts b/libs/core-ui/src/lib/Interfaces/VisionExplanationInterfaces.ts index 3605fbff6c..67ed3540e8 100644 --- a/libs/core-ui/src/lib/Interfaces/VisionExplanationInterfaces.ts +++ b/libs/core-ui/src/lib/Interfaces/VisionExplanationInterfaces.ts @@ -4,8 +4,8 @@ import { DatasetTaskType } from "./IDataset"; export interface IVisionExplanationDashboardData { task_type: DatasetTaskType; - true_y: number[]; - predicted_y: number[]; + true_y: number[] | number[][]; + predicted_y: number[] | number[][]; features?: unknown[][]; feature_names?: string[]; class_names: string[]; diff --git a/libs/core-ui/src/lib/util/DatasetUtils.ts b/libs/core-ui/src/lib/util/DatasetUtils.ts index 5d02f02124..ca8e1d42d1 100644 --- a/libs/core-ui/src/lib/util/DatasetUtils.ts +++ b/libs/core-ui/src/lib/util/DatasetUtils.ts @@ -41,10 +41,28 @@ export function constructRows( tableRow.push(colors[i]); } if (jointDataset.hasTrueY) { - pushRowData(tableRow, JointDataset.TrueYLabel, jointDataset, row); + if (jointDataset.numLabels > 1) { + pushMultilabelRowData( + tableRow, + JointDataset.TrueYLabel, + jointDataset, + row + ); + } else { + pushRowData(tableRow, JointDataset.TrueYLabel, jointDataset, row); + } } if (jointDataset.hasPredictedY) { - pushRowData(tableRow, JointDataset.PredictedYLabel, jointDataset, row); + if (jointDataset.numLabels > 1) { + pushMultilabelRowData( + tableRow, + JointDataset.PredictedYLabel, + jointDataset, + row + ); + } else { + pushRowData(tableRow, JointDataset.PredictedYLabel, jointDataset, row); + } } tableRow.push(...data); rows.push(tableRow); @@ -152,3 +170,28 @@ function pushRowData( tableRow.push(row[property]); } } + +function pushMultilabelRowData( + tableRow: any[], + property: string, + jointDataset: JointDataset, + row: { [key: string]: number } +): void { + const values = []; + for (let i = 0; i < jointDataset.numLabels; i++) { + const labelProp = property + i.toString(); + const categories = jointDataset.metaDict[labelProp].sortedCategoricalValues; + if (jointDataset.metaDict[labelProp].isCategorical && categories) { + const value = categories[row[labelProp]]; + if (value) { + values.push(value); + } + } else { + const value = row[labelProp]; + if (value) { + values.push(value); + } + } + } + tableRow.push(values.join(",")); +} diff --git a/libs/core-ui/src/lib/util/JointDataset.ts b/libs/core-ui/src/lib/util/JointDataset.ts index 17b26ad1dd..f206080ced 100644 --- a/libs/core-ui/src/lib/util/JointDataset.ts +++ b/libs/core-ui/src/lib/util/JointDataset.ts @@ -64,6 +64,7 @@ export class JointDataset { public predictionClassCount = 0; public datasetRowCount = 0; public localExplanationFeatureCount = 0; + public numLabels = 1; // these properties should only be accessed by Cohort class, // which enables independent filtered views of this data @@ -137,30 +138,14 @@ export class JointDataset { this.hasDataset = true; } if (args.predictedY) { - this.initializeDataDictIfNeeded(args.predictedY); - args.predictedY.forEach((val, index) => { - if (this.dataDict) { - this.dataDict[index][JointDataset.PredictedYLabel] = val; - } - }); - this.metaDict[JointDataset.PredictedYLabel] = { - abbridgedLabel: localization.Interpret.ExplanationScatter.predictedY, - category: ColumnCategories.Outcome, - isCategorical: args.metadata.modelType !== ModelTypes.Regression, - label: localization.Interpret.ExplanationScatter.predictedY, - sortedCategoricalValues: - args.metadata.modelType !== ModelTypes.Regression - ? args.metadata.classNames - : undefined, - treatAsCategorical: args.metadata.modelType !== ModelTypes.Regression - }; - if (args.metadata.modelType === ModelTypes.Regression) { - this.metaDict[JointDataset.PredictedYLabel].featureRange = { - max: _.max(args.predictedY) || 0, - min: _.min(args.predictedY) || 0, - rangeType: RangeTypes.Numeric - }; - } + this.updateMetaDataDict( + args.predictedY, + args.metadata, + JointDataset.PredictedYLabel, + localization.Interpret.ExplanationScatter.predictedY, + localization.Interpret.ExplanationScatter.predictedY, + args.targetColumn + ); this.hasPredictedY = true; } if (args.predictedProbabilities) { @@ -204,30 +189,14 @@ export class JointDataset { } } if (args.trueY) { - this.initializeDataDictIfNeeded(args.trueY); - args.trueY.forEach((val, index) => { - if (this.dataDict) { - this.dataDict[index][JointDataset.TrueYLabel] = val; - } - }); - this.metaDict[JointDataset.TrueYLabel] = { - abbridgedLabel: localization.Interpret.ExplanationScatter.trueY, - category: ColumnCategories.Outcome, - isCategorical: args.metadata.modelType !== ModelTypes.Regression, - label: localization.Interpret.ExplanationScatter.trueY, - sortedCategoricalValues: - args.metadata.modelType !== ModelTypes.Regression - ? args.metadata.classNames - : undefined, - treatAsCategorical: args.metadata.modelType !== ModelTypes.Regression - }; - if (args.metadata.modelType === ModelTypes.Regression) { - this.metaDict[JointDataset.TrueYLabel].featureRange = { - max: _.max(args.trueY) || 0, - min: _.min(args.trueY) || 0, - rangeType: RangeTypes.Numeric - }; - } + this.updateMetaDataDict( + args.trueY, + args.metadata, + JointDataset.TrueYLabel, + localization.Interpret.ExplanationScatter.trueY, + localization.Interpret.ExplanationScatter.trueY, + args.targetColumn + ); this.hasTrueY = true; } // include error columns if applicable @@ -677,6 +646,74 @@ export class JointDataset { return undefined; } + private updateMetaDataDict( + values: number[] | number[][], + metadata: IExplanationModelMetadata, + labelColName: string, + abbridgedLabel: string, + label: string, + targetColumn?: string | string[] + ): void { + this.initializeDataDictIfNeeded(values); + values.forEach((val, index) => { + if (Array.isArray(val)) { + this.numLabels = val.length; + val.forEach((subVal, subIndex) => { + if (this.dataDict) { + this.dataDict[index][labelColName + subIndex.toString()] = subVal; + } + }); + } else { + if (this.dataDict) { + this.dataDict[index][labelColName] = val; + } + } + }); + for (let i = 0; i < this.numLabels; i++) { + let labelColNameKey = labelColName; + let abbridgedLabelValue = abbridgedLabel; + let labelValue = label; + let singleLabelValues: number[] = []; + if (this.numLabels > 1) { + const labelIdxStr = i.toString(); + labelColNameKey += labelIdxStr; + abbridgedLabelValue += labelIdxStr; + labelValue += labelIdxStr; + // check if values is a 2d array + const indexedValues = values[i]; + if (Array.isArray(indexedValues)) { + singleLabelValues = indexedValues; + } + } else { + if (!Array.isArray(values)) { + singleLabelValues = values; + } + } + let categoricalValues = + metadata.modelType !== ModelTypes.Regression + ? metadata.classNames + : undefined; + if (this.numLabels > 1 && Array.isArray(targetColumn)) { + categoricalValues = ["", targetColumn[i]]; + } + this.metaDict[labelColNameKey] = { + abbridgedLabel: abbridgedLabelValue, + category: ColumnCategories.Outcome, + isCategorical: metadata.modelType !== ModelTypes.Regression, + label: labelValue, + sortedCategoricalValues: categoricalValues, + treatAsCategorical: metadata.modelType !== ModelTypes.Regression + }; + if (metadata.modelType === ModelTypes.Regression) { + this.metaDict[labelColNameKey].featureRange = { + max: _.max(singleLabelValues) || 0, + min: _.min(singleLabelValues) || 0, + rangeType: RangeTypes.Numeric + }; + } + } + } + private initializeDataDictIfNeeded(arr: any[]): void { if (arr === undefined) { return; diff --git a/libs/core-ui/src/lib/util/JointDatasetUtils.ts b/libs/core-ui/src/lib/util/JointDatasetUtils.ts index 7a1bc8d19e..2b69c7307c 100644 --- a/libs/core-ui/src/lib/util/JointDatasetUtils.ts +++ b/libs/core-ui/src/lib/util/JointDatasetUtils.ts @@ -14,14 +14,15 @@ import { AxisTypes } from "./IGenericChartProps"; export interface IJointDatasetArgs { dataset?: any[][]; - predictedY?: number[]; + predictedY?: number[] | number[][]; predictedProbabilities?: number[][]; - trueY?: number[]; + trueY?: number[] | number[][]; localExplanations?: | IMultiClassLocalFeatureImportance | ISingleClassLocalFeatureImportance; metadata: IExplanationModelMetadata; featureMetaData?: IFeatureMetaData; + targetColumn?: string | string[]; } export enum ColumnCategories { diff --git a/libs/counterfactuals/src/util/getOriginalData.ts b/libs/counterfactuals/src/util/getOriginalData.ts index 70889a8b89..50bb0cea27 100644 --- a/libs/counterfactuals/src/util/getOriginalData.ts +++ b/libs/counterfactuals/src/util/getOriginalData.ts @@ -25,7 +25,10 @@ export function getOriginalData( featureNames.forEach((f, index) => { data[f] = dataPoint[index]; }); - const targetLabel = dataset.target_column || "y"; + const targetColumn = Array.isArray(dataset.target_column) + ? dataset.target_column?.[0] + : dataset.target_column; + const targetLabel = targetColumn || "y"; data[targetLabel] = row[JointDataset.TrueYLabel]; return data; } diff --git a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/DataCharacteristics.tsx b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/DataCharacteristics.tsx index 99390b394d..5e0f02c1d3 100644 --- a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/DataCharacteristics.tsx +++ b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/DataCharacteristics.tsx @@ -176,20 +176,22 @@ export class DataCharacteristics extends React.Component< ): React.ReactElement => { const imageDim = this.props.imageDim; const classNames = dataCharacteristicsStyles(); + const predictedY = item?.predictedY; const indicatorStyle = mergeStyles( classNames.indicator, { width: imageDim }, - item?.predictedY === item?.trueY + predictedY === item?.trueY ? classNames.successIndicator : classNames.errorIndicator ); + const alt = Array.isArray(predictedY) ? predictedY.join(",") : predictedY; return !item ? (
) : ( {item?.predictedY} @@ -93,7 +95,7 @@ export class DataCharacteristicsRow extends React.Component { const fieldNames = this.props.otherMetadataFieldNames; const metadata: Array> = []; fieldNames.forEach((fieldName) => { + const itemField = item[fieldName]; + const itemValue = Array.isArray(itemField) + ? itemField.join(",") + : itemField; if (item[fieldName]) { - metadata.push([fieldName, item[fieldName]]); + metadata.push([fieldName, itemValue]); } }); this.setState({ item, metadata }); @@ -72,8 +76,12 @@ export class Flyout extends React.Component { const fieldNames = this.props.otherMetadataFieldNames; const metadata: Array> = []; fieldNames.forEach((fieldName) => { + const itemField = item[fieldName]; + const itemValue = Array.isArray(itemField) + ? itemField.join(",") + : itemField; if (item[fieldName]) { - metadata.push([fieldName, item[fieldName]]); + metadata.push([fieldName, itemValue]); } }); this.setState({ diff --git a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/ImageList.tsx b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/ImageList.tsx index 1b026a180c..260bcd3718 100644 --- a/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/ImageList.tsx +++ b/libs/interpret-vision/src/lib/VisionExplanationDashboard/Controls/ImageList.tsx @@ -98,6 +98,8 @@ export class ImageList extends React.Component< if (!item) { return; } + const predictedY = item?.predictedY; + const alt = Array.isArray(predictedY) ? predictedY.join(",") : predictedY; return ( {item?.predictedY} + row.map((index) => classNames[index]) + ); + } + return (labels as number[]).map((index) => classNames[index]); +} + export function preprocessData( props: IVisionExplanationDashboardProps ): @@ -47,13 +60,9 @@ export function preprocessData( const successInstances: IVisionListItem[] = []; const classNames = props.dataSummary.class_names; - const predictedY = dataSummary.predicted_y.map((index) => { - return classNames[index]; - }); + const predictedY = mapClassNames(dataSummary.predicted_y, classNames); - const trueY = dataSummary.true_y.map((index) => { - return classNames[index]; - }); + const trueY = mapClassNames(dataSummary.true_y, classNames); const features = dataSummary.features?.map((featuresArr) => { return featuresArr[0] as number; diff --git a/libs/interpret-vision/src/lib/VisionExplanationDashboard/utils/getFilteredData.ts b/libs/interpret-vision/src/lib/VisionExplanationDashboard/utils/getFilteredData.ts index 1ef388f8ee..6989fa05df 100644 --- a/libs/interpret-vision/src/lib/VisionExplanationDashboard/utils/getFilteredData.ts +++ b/libs/interpret-vision/src/lib/VisionExplanationDashboard/utils/getFilteredData.ts @@ -7,9 +7,22 @@ export function getFilteredDataFromSearch( searchVal: string, items: IVisionListItem[] ): IVisionListItem[] { - return items.filter( - (item) => - item.predictedY.toLowerCase().includes(searchVal) || - item.trueY.toLowerCase().includes(searchVal) - ); + return items.filter((item) => { + const predYIncludesSearchVal = includesSearchVal( + item.predictedY, + searchVal + ); + const trueYIncludesSearchVal = includesSearchVal(item.trueY, searchVal); + return predYIncludesSearchVal || trueYIncludesSearchVal; + }); +} + +export function includesSearchVal( + labels: string | string[], + searchVal: string +): boolean { + if (Array.isArray(labels)) { + return labels.some((label) => label.toLowerCase().includes(searchVal)); + } + return labels.toLowerCase().includes(searchVal); } diff --git a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts index 935af6f1a8..7f7f0e4618 100644 --- a/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts +++ b/libs/model-assessment/src/lib/ModelAssessmentDashboard/Context/buildModelAssessmentContext.ts @@ -56,6 +56,7 @@ export function buildInitialModelAssessmentContext( metadata: modelMetadata, predictedProbabilities: props.dataset.probability_y, predictedY: props.dataset.predicted_y, + targetColumn: props.dataset.target_column, trueY: props.dataset.true_y }); const globalProps = buildGlobalProperties(