This repository has been archived by the owner on Sep 3, 2022. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 79
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Move TensorBoard and TensorFlow Events UI rendering to Python function to deprecate magic. (#163)
* Update feature slice view UI. Added Slices Overview. * Move TensorBoard and TensorFlow Events UI rendering to Python function to deprecate magic. Use matplotlib for tf events plotting so it can display well in static HTML pages (such as github). Improve TensorFlow Events list/get APIs. * Follow up on CR comments.
- Loading branch information
Showing
3 changed files
with
111 additions
and
95 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,107 +1,127 @@ | ||
# Copyright 2016 Google Inc. All rights reserved. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except | ||
# in compliance with the License. You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software distributed under the License | ||
# is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express | ||
# or implied. See the License for the specific language governing permissions and limitations under | ||
# the License. | ||
|
||
"""Implements Cloud ML Summary wrapper.""" | ||
|
||
import datetime | ||
import fnmatch | ||
import glob | ||
import google.cloud.ml as ml | ||
import matplotlib.pyplot as plt | ||
import os | ||
import pandas as pd | ||
from tensorflow.core.util import event_pb2 | ||
from tensorflow.python.lib.io import tf_record | ||
|
||
import datalab.storage as storage | ||
|
||
|
||
class Summary(object): | ||
"""Represents TensorFlow summary events from files under a directory.""" | ||
"""Represents TensorFlow summary events from files under specified directories.""" | ||
|
||
def __init__(self, path): | ||
def __init__(self, paths): | ||
"""Initializes an instance of a Summary. | ||
Args: | ||
path: the path of the directory which holds TensorFlow events files. | ||
Can be local path or GCS path. | ||
paths: a list of paths to directories which hold TensorFlow events files. | ||
Can be local paths or GCS paths. Wildcards allowed. | ||
""" | ||
self._path = path | ||
|
||
def _get_events_files(self): | ||
if self._path.startswith('gs://'): | ||
storage._api.Api.verify_permitted_to_read(self._path) | ||
bucket, prefix = storage._bucket.parse_name(self._path) | ||
items = storage.Items(bucket, prefix, None) | ||
filtered_list = [item.uri for item in items if os.path.basename(item.uri).find('tfevents')] | ||
return filtered_list | ||
else: | ||
path_pattern = os.path.join(self._path, '*tfevents*') | ||
return glob.glob(path_pattern) | ||
self._paths = [paths] if isinstance(paths, basestring) else paths | ||
|
||
def _glob_events_files(self, paths): | ||
event_files = [] | ||
for path in paths: | ||
if path.startswith('gs://'): | ||
event_files += ml.util._file.glob_files(os.path.join(path, '*.tfevents.*')) | ||
else: | ||
dirs = ml.util._file.glob_files(path) | ||
for dir in dirs: | ||
for root, _, filenames in os.walk(dir): | ||
for filename in fnmatch.filter(filenames, '*.tfevents.*'): | ||
event_files.append(os.path.join(root, filename)) | ||
return event_files | ||
|
||
def list_events(self): | ||
"""List all scalar events in the directory. | ||
Returns: | ||
A set of unique event tags. | ||
A dictionary. Key is the name of an event. Value is a set of dirs that contain that event. | ||
""" | ||
event_tags = set() | ||
for event_file in self._get_events_files(): | ||
event_dir_dict = {} | ||
for event_file in self._glob_events_files(self._paths): | ||
dir = os.path.dirname(event_file) | ||
for record in tf_record.tf_record_iterator(event_file): | ||
event = event_pb2.Event.FromString(record) | ||
if event.summary is None or event.summary.value is None: | ||
continue | ||
for value in event.summary.value: | ||
if value.simple_value is None: | ||
if value.simple_value is None or value.tag is None: | ||
continue | ||
if value.tag is not None and value.tag not in event_tags: | ||
event_tags.add(value.tag) | ||
return event_tags | ||
if not value.tag in event_dir_dict: | ||
event_dir_dict[value.tag] = set() | ||
event_dir_dict[value.tag].add(dir) | ||
return event_dir_dict | ||
|
||
|
||
def get_events(self, event_name): | ||
"""Get all events of a certain tag. | ||
def get_events(self, event_names): | ||
"""Get all events as pandas DataFrames given a list of names. | ||
Args: | ||
event_name: the tag of event to look for. | ||
event_names: A list of events to get. | ||
Returns: | ||
A tuple. First is a list of {time_span, event_name}. Second is a list of {step, event_name}. | ||
Raises: | ||
Exception if event start time cannot be found | ||
A list with the same length as event_names. Each element is a dictionary | ||
{dir1: DataFrame1, dir2: DataFrame2, ...}. | ||
Multiple directories may contain events with the same name, but they are different | ||
events (e.g. 'loss' under train_set/, and 'loss' under eval_set/.) | ||
""" | ||
events_time = [] | ||
events_step = [] | ||
event_start_time = None | ||
for event_file in self._get_events_files(): | ||
for record in tf_record.tf_record_iterator(event_file): | ||
event = event_pb2.Event.FromString(record) | ||
if event.file_version is not None: | ||
# first event in the file. | ||
time = datetime.datetime.fromtimestamp(event.wall_time) | ||
if event_start_time is None or event_start_time > time: | ||
event_start_time = time | ||
event_names = [event_names] if isinstance(event_names, basestring) else event_names | ||
|
||
if event.summary is None or event.summary.value is None: | ||
continue | ||
for value in event.summary.value: | ||
if value.simple_value is None or value.tag is None: | ||
all_events = self.list_events() | ||
dirs_to_look = set() | ||
for event, dirs in all_events.iteritems(): | ||
if event in event_names: | ||
dirs_to_look.update(dirs) | ||
|
||
ret_events = [dict() for i in range(len(event_names))] | ||
for dir in dirs_to_look: | ||
for event_file in self._glob_events_files([dir]): | ||
for record in tf_record.tf_record_iterator(event_file): | ||
event = event_pb2.Event.FromString(record) | ||
if event.summary is None or event.wall_time is None or event.summary.value is None: | ||
continue | ||
if value.tag == event_name: | ||
if event.wall_time is not None: | ||
time = datetime.datetime.fromtimestamp(event.wall_time) | ||
events_time.append({'time': time, event_name: value.simple_value}) | ||
if event.step is not None: | ||
events_step.append({'step': event.step, event_name: value.simple_value}) | ||
if event_start_time is None: | ||
raise Exception('Empty or invalid TF events file. Cannot find event start time.') | ||
for event in events_time: | ||
event['time'] = event['time'] - event_start_time # convert time to timespan | ||
events_time = sorted(events_time, key=lambda k: k['time']) | ||
events_step = sorted(events_step, key=lambda k: k['step']) | ||
return events_time, events_step | ||
|
||
event_time = datetime.datetime.fromtimestamp(event.wall_time) | ||
for value in event.summary.value: | ||
if value.tag not in event_names or value.simple_value is None: | ||
continue | ||
|
||
index = event_names.index(value.tag) | ||
dir_event_dict = ret_events[index] | ||
if dir not in dir_event_dict: | ||
dir_event_dict[dir] = pd.DataFrame( | ||
[[event_time, event.step, value.simple_value]], | ||
columns=['time', 'step', 'value']) | ||
else: | ||
df = dir_event_dict[dir] | ||
# Append a row. | ||
df.loc[len(df)] = [event_time, event.step, value.simple_value] | ||
|
||
for dir_event_dict in ret_events: | ||
for df in dir_event_dict.values(): | ||
df.sort_values(by=['time'], inplace=True) | ||
|
||
return ret_events | ||
|
||
def plot(self, event_names, x_axis='step'): | ||
"""Plots a list of events. Each event (a dir+event_name) is represented as a line | ||
in the graph. | ||
Args: | ||
event_names: A list of events to plot. Each event_name may correspond to multiple events, | ||
each in a different directory. | ||
x_axis: whether to use step or time as x axis. | ||
""" | ||
event_names = [event_names] if isinstance(event_names, basestring) else event_names | ||
events_list = self.get_events(event_names) | ||
for event_name, dir_event_dict in zip(event_names, events_list): | ||
for dir, df in dir_event_dict.iteritems(): | ||
label = event_name + ':' + dir | ||
x_column = df['step'] if x_axis == 'step' else df['time'] | ||
plt.plot(x_column, df['value'], label=label) | ||
plt.legend(loc='best') | ||
plt.show() | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters