Skip to content

Commit

Permalink
Make multi label prediction scores visible in the data table column.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 570758611
  • Loading branch information
bdu91 authored and LIT team committed Oct 4, 2023
1 parent e7115ab commit 8a3f366
Showing 1 changed file with 35 additions and 2 deletions.
37 changes: 35 additions & 2 deletions lit_nlp/client/services/data_service.ts
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
import {action, computed, observable, reaction} from 'mobx';

import {BINARY_NEG_POS, ColorRange} from '../lib/colors';
import {BooleanLitType, CategoryLabel, GeneratedText, GeneratedTextCandidates, LitType, MulticlassPreds, RegressionScore, Scalar} from '../lib/lit_types';
import {BooleanLitType, CategoryLabel, GeneratedText, GeneratedTextCandidates, LitType, MulticlassPreds, RegressionScore, Scalar, SparseMultilabelPreds} from '../lib/lit_types';
import {ClassificationResults, IndexedInput, RegressionResults} from '../lib/types';
import {createLitType, findSpecKeys, isLitSubtype, mapsContainSame} from '../lib/utils';

Expand Down Expand Up @@ -68,6 +68,8 @@ export const GEN_TEXT_CANDS_SOURCE_PREFIX = 'GeneratedTextCandidates';
export const REGRESSION_SOURCE_PREFIX = 'Regression';
/** Column source prefix for columns from scalar model outputs. */
export const SCALAR_SOURCE_PREFIX = 'Scalar';
/** Column source prefix for columns from multilabel model outputs. */
export const MULTILABEL_SOURCE_PREFIX = 'Multilabel';

/**
* Data service singleton, responsible for maintaining columns of computed data
Expand Down Expand Up @@ -109,7 +111,7 @@ export class DataService extends LitService {
}
}, {fireImmediately: true});

// Run other preiction interpreters when necessary.
// Run other prediction interpreters when necessary.
const getPredictionInputs =
() => [this.appState.currentInputData, this.appState.currentModels];
reaction(getPredictionInputs, () => {
Expand All @@ -124,6 +126,7 @@ export class DataService extends LitService {
this.runGeneratedTextPreds(model, this.appState.currentInputData);
this.runRegression(model, this.appState.currentInputData);
this.runScalarPreds(model, this.appState.currentInputData);
this.runMultiLabelPreds(model, this.appState.currentInputData);
}
}, {fireImmediately: true});

Expand Down Expand Up @@ -301,6 +304,36 @@ export class DataService extends LitService {
}
}

/**
* Run multi label predictions and store results in data service.
*/
private async runMultiLabelPreds(model: string, data: IndexedInput[]) {
const {output} = this.appState.getModelSpec(model);
if (findSpecKeys(output, SparseMultilabelPreds).length === 0) {
return;
}

const multiLabelPredsPromise = this.apiService.getPreds(
data, model, this.appState.currentDataset, [SparseMultilabelPreds]);
const preds = await multiLabelPredsPromise;

// Add multi label prediction results as new column to the data service.
if (preds == null || preds.length === 0) {
return;
}
const multiLabelPredKeys = Object.keys(preds[0]);
for (const key of multiLabelPredKeys) {
const scoreFeatName = this.getColumnName(model, key);
const scores = preds.map(pred => pred[key]);
// TODO(b/303457849): maybe possible to directly use the data type from
// the output spec rather than creating a new one.
const dataType = createLitType(SparseMultilabelPreds);
const source = `${MULTILABEL_SOURCE_PREFIX}:${model}`;
this.addColumnFromList(
scores, data, key, scoreFeatName, dataType, source);
}
}

@action
async setValuesForNewDatapoints(datapoints: IndexedInput[]) {
// When new datapoints are created, set their data values for each
Expand Down

0 comments on commit 8a3f366

Please sign in to comment.