Skip to content

Commit

Permalink
Profiling of neural networks (#26)
Browse files Browse the repository at this point in the history
* Add profiling of neural networks

* Tests

* Refactor profiling

* Cosmetic changes

* Tests

* Cosmetic changes

* Cosmetic changes

* Refactoring profiler

* Documentation

* Add profiling test
  • Loading branch information
bedapisl authored and Adam Blažek committed Feb 9, 2019
1 parent e88f5ec commit 6b2e648
Show file tree
Hide file tree
Showing 7 changed files with 122 additions and 13 deletions.
27 changes: 27 additions & 0 deletions docs/advanced.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,27 @@
Profiling networks
------------------
Profiling execution of tensorflow graph can be enabled with following setting:

.. code-block:: yaml
:caption config.yaml
model:
profile: True
keep_profiles: 10
This saves profiles of last 10 runs to the log directory (output directory).
Profiles are in JSON format and can be viewed using Google Chrome.
To view them go to address `chrome://tracing/` and load the json file.

Gradient clipping
-----------------
For gradient clipping use following setting:

.. code-block:: yaml
:caption config.yaml
model:
clip_gradient: 5.0
This clips the absolute value of gradient to 5.0.
Note that the clipping is done to raw gradients before they are multiplied by learning rate or processed in other ways.
3 changes: 2 additions & 1 deletion docs/conf.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
# General information about the project.
project = 'emloop-tensorflow'
copyright = '2018, Iterait a.s.'
author = 'Blazek Adam, Belohlavek Petr, Matzner Filip'
author = 'Blazek Adam, Belohlavek Petr, Matzner Filip, Bedrich Pisl'

# The short X.Y version.
version = '.'.join(pkg_resources.get_distribution("emloop-tensorflow").version.split('.')[:2])
Expand All @@ -37,6 +37,7 @@
("Tutorial", "tutorial"),
("Model Regularization", "regularization"),
("Multi GPU models", "multi_gpu"),
("Advanced", "advanced"),
("API Reference", "emloop_tensorflow/index"),
],
})
Expand Down
21 changes: 17 additions & 4 deletions emloop_tensorflow/frozen_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

from .graph_tower import GraphTower
from .model import BaseModel
from .utils import Profiler


class FrozenModel(el.AbstractModel):
Expand All @@ -28,17 +29,19 @@ class FrozenModel(el.AbstractModel):
"""

def __init__(self,
inputs: List[str], outputs: List[str], restore_from: str,
session_config: Optional[dict]=None, n_gpus: int=0, **_):
def __init__(self, inputs: List[str], outputs: List[str], restore_from: str, log_dir: Optional[str]=None,
session_config: Optional[dict]=None, n_gpus: int=0, profile: bool=False, keep_profiles: int=5, **_):
"""
Initialize new :py:class:`FrozenModel` instance.
:param log_dir: output directory
:param inputs: model input names
:param outputs: model output names
:param restore_from: restore model path (either a dir or a .pb file)
:param session_config: TF session configuration dict
:param n_gpus: number of GPUs to use (either 0 or 1)
:param profile: if true, profile the speed of model inference and save profiles to the specified log_dir
:param keep_profiles: how many profiles are saved
"""
super().__init__(None, '', restore_from)
assert 0 <= n_gpus <= 1, 'FrozenModel can be used only with n_gpus=0 or n_gpus=1'
Expand All @@ -60,6 +63,13 @@ def __init__(self,
except KeyError:
self._is_training = tf.placeholder(tf.bool, [], BaseModel.TRAINING_FLAG_NAME)

if profile and not log_dir:
raise ValueError('log_dir has to be specified with profile set to True')

self._profile = profile
if profile:
self._profiler = Profiler(log_dir, keep_profiles, self._session)

def run(self, batch: el.Batch, train: bool=False, stream: el.datasets.StreamWrapper=None) -> Mapping[str, object]:
"""
Run the model with the given ``batch``.
Expand All @@ -83,7 +93,10 @@ def run(self, batch: el.Batch, train: bool=False, stream: el.datasets.StreamWrap
for output_name in self.output_names:
fetches.append(self._tower[output_name])

outputs = self._session.run(fetches=fetches, feed_dict=feed_dict)
if self._profile:
outputs = self._profiler.run(fetches=fetches, feed_dict=feed_dict)
else:
outputs = self._session.run(fetches=fetches, feed_dict=feed_dict)

return dict(zip(self.output_names, outputs))

Expand Down
25 changes: 18 additions & 7 deletions emloop_tensorflow/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
import logging
from os import path
from abc import ABCMeta
from typing import List, Mapping, Optional
from typing import List, Mapping, Optional, Dict
from glob import glob

import numpy as np
Expand All @@ -11,7 +11,7 @@

from .third_party.tensorflow.freeze_graph import freeze_graph
from .third_party.tensorflow.average_gradients import average_gradients
from .utils import create_optimizer
from .utils import create_optimizer, Profiler
from .graph_tower import GraphTower

DEFAULT_LOSS_NAME = 'loss'
Expand Down Expand Up @@ -44,8 +44,8 @@ def __init__(self, # pylint: disable=too-many-arguments
dataset: Optional[el.AbstractDataset], log_dir: Optional[str], inputs: List[str], outputs: List[str],
session_config: Optional[dict]=None, n_gpus: int=0, restore_from: Optional[str]=None,
optimizer=None, freeze=False, loss_name: str=DEFAULT_LOSS_NAME, monitor: Optional[str]=None,
restore_fallback: Optional[str]=None, clip_gradient: Optional[float]=None,
**kwargs):
restore_fallback: Optional[str]=None, clip_gradient: Optional[float]=None, profile: bool=False,
keep_profiles: int=5, **kwargs):
"""
Create new emloop trainable TensorFlow model.
Expand Down Expand Up @@ -82,6 +82,8 @@ def __init__(self, # pylint: disable=too-many-arguments
:param monitor: monitor signal mean and variance of the tensors which names contain the specified value
:param restore_fallback: ignored arg. (allows training from configs saved by emloop where it is added)
:param clip_gradient: limit the absolute value of the gradient; set to None for no clipping
:param profile: if true, profile the speed of model inference and save profiles to the specified log_dir
:param keep_profiles: if true, profile the speed of model inference and save profiles to the specified log_dir
:param kwargs: additional kwargs forwarded to :py:meth:`_create_model`
"""
super().__init__(dataset=dataset, log_dir=log_dir, restore_from=restore_from)
Expand All @@ -97,10 +99,17 @@ def __init__(self, # pylint: disable=too-many-arguments
self._towers = [GraphTower(i, inputs, outputs, loss_name) for i in range(n_gpus)]
if n_gpus == 0:
self._towers.append(GraphTower(-1, inputs, outputs, loss_name))

logging.info('\tCreating TF model on %s GPU devices', n_gpus)
self._graph = tf.Graph()
self._session = self._create_session(session_config)

if profile and not log_dir:
raise ValueError('log_dir has to be specified with profile set to True')

self._profile = profile
if profile:
self._profiler = Profiler(log_dir, keep_profiles, self._session)

dependencies = []
with self._graph.as_default():
if restore_from is None:
Expand Down Expand Up @@ -223,12 +232,14 @@ def run(self, batch: el.Batch, train: bool=False, stream: el.datasets.StreamWrap
for output_name in self.output_names:
fetches.append(tower[output_name])

run_fn = self._profiler.run if self._profile else self._session.run

# run the computational graph for one batch and allow buffering in the meanwhile
if stream is not None:
with stream.allow_buffering:
outputs = self._session.run(fetches=fetches, feed_dict=feed_dict)
outputs = run_fn(fetches, feed_dict)
else:
outputs = self._session.run(fetches=fetches, feed_dict=feed_dict)
outputs = run_fn(fetches, feed_dict)

if train:
outputs = outputs[1:]
Expand Down
15 changes: 15 additions & 0 deletions emloop_tensorflow/tests/model_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -419,6 +419,21 @@ def test_regularization():
regularized_model2.run(good_batch, train=True)


def test_profiling(tmpdir):
"""Test whether profile is created."""
model = TrainableModel(dataset=None, log_dir=tmpdir, **_IO, optimizer=_OPTIMIZER, profile=True, keep_profiles=10)
batch = {'input': [[1]*10], 'target': [[0]*10]}

# test if one can train one model while the other remains intact
for _ in range(1000):
model.run(batch, train=True)

for i in range(10):
assert path.exists(f"{tmpdir}/profile_{i}.json")

assert not path.exists(f"{tmpdir}/profile_11.json")


#######################
# TF Base Model Saver #
#######################
Expand Down
3 changes: 2 additions & 1 deletion emloop_tensorflow/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,5 +2,6 @@
Module with TensorFlow util functions.
"""
from .reflection import create_activation, create_optimizer
from .profiler import Profiler

__all__ = ['create_activation', 'create_optimizer']
__all__ = ['create_activation', 'create_optimizer', 'Profiler']
41 changes: 41 additions & 0 deletions emloop_tensorflow/utils/profiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
import tensorflow as tf
from tensorflow.python.client import timeline
from typing import Dict
import os


class Profiler:
"""
Profiles tensorflow graphs and saves the profiles.
"""

def __init__(self, log_dir: str, keep_profiles: int, session: tf.Session):
"""
:param log_dir: directory where profiles will be saved
:param keep_profiles: how many profiles are saved
"""
self._log_dir = log_dir
self._profile_counter = 0
self._keep_profiles = keep_profiles
self._run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
self._session = session

def run(self, fetches: Dict, feed_dict: Dict):
"""
Evaluates the tensorflow graph with profiling, saves profile and returns outputs.
:param session: tensorflow session
:param fetches: names of output tensors
:param feed_dict: input tensors
"""
run_metadata = tf.RunMetadata()
outputs = self._session.run(fetches=fetches, feed_dict=feed_dict,
options=self._run_options, run_metadata=run_metadata)

with open(os.path.join(self._log_dir, f'profile_{self._profile_counter}.json'), 'w') as ofile:
tl = timeline.Timeline(run_metadata.step_stats)
ofile.write(tl.generate_chrome_trace_format())

self._profile_counter = (self._profile_counter + 1) % self._keep_profiles

return outputs

0 comments on commit 6b2e648

Please sign in to comment.