Skip to content

Commit 8fbebca

Browse files
committed
Multiplex scalar fetch: one tag, many runs
1 parent e7140fe commit 8fbebca

File tree

4 files changed

+84
-38
lines changed

4 files changed

+84
-38
lines changed

tensorboard/components/tf_dashboard_common/data-loader-behavior.ts

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -81,13 +81,22 @@ namespace tf_dashboard_common {
8181
requestData: {
8282
type: Function,
8383
value: function() {
84-
return (datum) =>
85-
this.requestManager.request(this.getDataLoadUrl(datum));
84+
return (datum) => {
85+
const dataLoadUrl = this.getDataLoadUrl(datum);
86+
var url;
87+
var postdata = {};
88+
if (Array.isArray(dataLoadUrl)) {
89+
[url, postdata] = dataLoadUrl;
90+
} else {
91+
url = dataLoadUrl;
92+
}
93+
return this.requestManager.request(url, postdata);
94+
};
8695
},
8796
},
8897

8998
// A function that takes a datum and returns a string URL for fetching
90-
// data.
99+
// data. Optionally, returns a tuple of (URL, postdata).
91100
getDataLoadUrl: Function,
92101

93102
dataLoading: {

tensorboard/plugins/scalar/scalars_plugin.py

Lines changed: 40 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,7 @@
2424

2525
import collections
2626
import csv
27+
import json
2728

2829
import six
2930
from six import StringIO
@@ -95,29 +96,37 @@ def index_impl(self, ctx, experiment=None):
9596
}
9697
return result
9798

98-
def scalars_impl(self, ctx, tag, run, experiment, output_format):
99+
def scalars_impl(self, ctx, tag, runs, experiment, output_format):
99100
"""Result of the form `(body, mime_type)`."""
100101
all_scalars = self._data_provider.read_scalars(
101102
ctx,
102103
experiment_id=experiment,
103104
plugin_name=metadata.PLUGIN_NAME,
104105
downsample=self._downsample_to,
105-
run_tag_filter=provider.RunTagFilter(runs=[run], tags=[tag]),
106+
run_tag_filter=provider.RunTagFilter(runs=runs, tags=[tag]),
106107
)
107-
scalars = all_scalars.get(run, {}).get(tag, None)
108-
if scalars is None:
109-
raise errors.NotFoundError(
110-
"No scalar data for run=%r, tag=%r" % (run, tag)
111-
)
112-
values = [(x.wall_time, x.step, x.value) for x in scalars]
108+
result = {}
109+
for run in all_scalars:
110+
scalars = all_scalars.get(run, {}).get(tag, [])
111+
# if scalars is None:
112+
# raise errors.NotFoundError(
113+
# "No scalar data for run=%r, tag=%r" % (run, tag)
114+
# )
115+
values = [(x.wall_time, x.step, x.value) for x in scalars]
116+
result[run] = values
117+
113118
if output_format == OutputFormat.CSV:
119+
if(runs.length > 1):
120+
raise errors.InvalidArgumentError(
121+
"Not implemented: Return CSV data for more than one run "
122+
"at a time.")
114123
string_io = StringIO()
115124
writer = csv.writer(string_io)
116125
writer.writerow(["Wall time", "Step", "Value"])
117126
writer.writerows(values)
118127
return (string_io.getvalue(), "text/csv")
119128
else:
120-
return (values, "application/json")
129+
return (json.dumps(result), "application/json")
121130

122131
@wrappers.Request.application
123132
def tags_route(self, request):
@@ -129,12 +138,31 @@ def tags_route(self, request):
129138
@wrappers.Request.application
130139
def scalars_route(self, request):
131140
"""Given a tag and single run, return array of ScalarEvents."""
132-
tag = request.args.get("tag")
133-
run = request.args.get("run")
141+
if request.method == "GET":
142+
tag = request.args.get("tag")
143+
run = request.args.get("run")
144+
runs = [run]
145+
else:
146+
tag = request.form["tag"]
147+
json_runs = request.form["runs"]
148+
try:
149+
runs = json.loads(json_runs)
150+
except Exception as e: # pylint: disable=broad-except
151+
# Different JSON libs raise different exceptions, so we just do a
152+
# catch-all here. This problem is complicated by how Tensorboard might be
153+
# run in many different environments, as it is open-source.
154+
# TODO(@caisq, @chihuahua): Create platform-dependent adapter to catch
155+
# specific types of exceptions, instead of the broad catching here.
156+
response = ("Could not decode runs JSON string %r: %s") % (
157+
json_runs,
158+
e,
159+
)
160+
return http_util.Respond(request, response, "text/plain", code=400)
161+
134162
ctx = plugin_util.context(request.environ)
135163
experiment = plugin_util.experiment_id(request.environ)
136164
output_format = request.args.get("format")
137165
(body, mime_type) = self.scalars_impl(
138-
ctx, tag, run, experiment, output_format
166+
ctx, tag, runs, experiment, output_format
139167
)
140168
return http_util.Respond(request, body, mime_type)

tensorboard/plugins/scalar/tf_scalar_dashboard/tf-scalar-card.html

Lines changed: 30 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -45,8 +45,8 @@
4545
<tf-line-chart-data-loader
4646
active="[[active]]"
4747
color-scale="[[_getColorScale(colorScale)]]"
48-
data-series="[[_getDataSeries(dataToLoad.*)]]"
49-
data-to-load="[[dataToLoad]]"
48+
data-series="[[_getDataSeries(tagAndRuns.series)]]"
49+
data-to-load="[[_asSingletonArray(tagAndRuns)]]"
5050
get-data-load-name="[[_getDataLoadName]]"
5151
get-data-load-url="[[getDataLoadUrl]]"
5252
request-data="[[requestData]]"
@@ -102,7 +102,7 @@
102102
<template is="dom-if" if="[[showDownloadLinks]]">
103103
<div class="download-links">
104104
<tf-downloader
105-
runs="[[_runsFromData(dataToLoad)]]"
105+
runs="[[_runsFromData(tagAndRuns.series)]]"
106106
tag="[[tag]]"
107107
url-fn="[[_downloadUrlFn]]"
108108
></tf-downloader>
@@ -191,7 +191,9 @@
191191
/**
192192
* @type {Array<Object>}
193193
*/
194-
dataToLoad: Array,
194+
// dataToLoad: Array,
195+
196+
tagAndRuns: Object,
195197

196198
/**
197199
* @type {vz_chart_helpers.XType}
@@ -218,14 +220,17 @@
218220
type: Object,
219221
value: function() {
220222
return (scalarChart, datum, data) => {
221-
const formattedData = data.map((datum) => ({
222-
wall_time: new Date(datum[0] * 1000),
223-
step: datum[1],
224-
scalar: datum[2],
225-
}));
226-
const name = this._getSeriesNameFromDatum(datum);
227-
scalarChart.setSeriesMetadata(name, datum);
228-
scalarChart.setSeriesData(name, formattedData);
223+
for (const run in data) {
224+
const points = data[run];
225+
const formattedData = points.map((point) => ({
226+
wall_time: new Date(point[0] * 1000),
227+
step: point[1],
228+
scalar: point[2],
229+
}));
230+
const name = this._getSeriesNameForRun(run, datum.experiment);
231+
scalarChart.setSeriesMetadata(name, {run, tag: datum.tag});
232+
scalarChart.setSeriesData(name, formattedData);
233+
}
229234
scalarChart.commitChanges();
230235
};
231236
},
@@ -235,14 +240,12 @@
235240
getDataLoadUrl: {
236241
type: Function,
237242
value: function() {
238-
return ({tag, run}) => {
239-
return tf_backend
243+
return ({tag, runs}) => {
244+
const url = tf_backend
240245
.getRouter()
241-
.pluginRoute(
242-
'scalars',
243-
'/scalars',
244-
new URLSearchParams({tag, run})
245-
);
246+
.pluginRoute('scalars', '/scalars');
247+
const postdata = {tag, runs: JSON.stringify(runs)};
248+
return [url, postdata];
246249
};
247250
},
248251
},
@@ -266,7 +269,7 @@
266269
_getDataLoadName: {
267270
type: Function,
268271
value: function() {
269-
return (datum) => this._getSeriesNameFromDatum(datum);
272+
return (datum) => datum.tag;
270273
},
271274
},
272275

@@ -325,10 +328,12 @@
325328
return data.map((datum) => datum.run);
326329
},
327330
_getDataSeries() {
328-
return this.dataToLoad.map((d) => this._getSeriesNameFromDatum(d));
331+
return this.tagAndRuns.series.map((d) =>
332+
this._getSeriesNameForRun(d.run, d.experiment)
333+
);
329334
},
330335
// name is a stable identifier for a series.
331-
_getSeriesNameFromDatum({run, experiment = {name: '_default'}}) {
336+
_getSeriesNameForRun(run, experiment = {name: '_default'}) {
332337
return JSON.stringify([experiment.name, run]);
333338
},
334339
// title is a visible string of a series for the UI.
@@ -348,6 +353,9 @@
348353
},
349354
};
350355
},
356+
_asSingletonArray(x) {
357+
return [x];
358+
},
351359
});
352360
</script>
353361
</dom-module>

tensorboard/plugins/scalar/tf_scalar_dashboard/tf-scalar-dashboard.html

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ <h3>No scalar data was found.</h3>
162162
<template>
163163
<tf-scalar-card
164164
active="[[active]]"
165-
data-to-load="[[item.series]]"
165+
tag-and-runs="[[item]]"
166166
ignore-y-outliers="[[_ignoreYOutliers]]"
167167
multi-experiments="[[_getMultiExperiments(dataSelection)]]"
168168
request-manager="[[_requestManager]]"
@@ -361,6 +361,7 @@ <h3>No scalar data was found.</h3>
361361
categories.forEach((category) => {
362362
category.items = category.items.map((item) => ({
363363
tag: item.tag,
364+
runs: item.runs,
364365
series: item.runs.map((run) => ({run, tag: item.tag})),
365366
}));
366367
});

0 commit comments

Comments
 (0)