
torchtrail: trace the graph of torch functions and modules for visualization, reports, etc


arakhmati/torchtrail


torchtrail


torchtrail provides an external API to trace PyTorch models and extract the graph of torch functions and modules that were executed. The graphs can then be visualized or used for other purposes, such as reports.

Installation Instructions

On macOS

brew install graphviz
pip install torchtrail

On Ubuntu

sudo apt-get install graphviz
pip install torchtrail

Examples

Tracing a function

import torch
import torchtrail

with torchtrail.trace():
    input_tensor = torch.rand(1, 64)
    output_tensor = torch.exp(input_tensor)
torchtrail.visualize(output_tensor, file_name="exp.svg")

The graph can be obtained as a networkx.MultiDiGraph using torchtrail.get_graph:

graph: "networkx.MultiDiGraph" = torchtrail.get_graph(output_tensor)
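Since get_graph returns a plain networkx.MultiDiGraph, generic networkx calls can be used to inspect it. The helper below is a minimal sketch that relies only on standard networkx functions; the toy graph stands in for real torchtrail output and its node names ("x", "add", "output") are hypothetical, not torchtrail's actual node keys or attributes. It models `torch.add(x, x)`, where the same tensor feeds both operands.

```python
import networkx as nx


def describe_graph(graph: nx.MultiDiGraph) -> dict:
    """Summarize a traced graph using only generic networkx calls."""
    is_dag = nx.is_directed_acyclic_graph(graph)
    return {
        "num_nodes": graph.number_of_nodes(),
        # In a MultiDiGraph, parallel edges between the same two nodes are counted separately.
        "num_edges": graph.number_of_edges(),
        "is_dag": is_dag,
        # Producers come before their consumers; only defined for acyclic graphs.
        "order": list(nx.topological_sort(graph)) if is_dag else None,
    }


# Hypothetical stand-in for torchtrail.get_graph(...) output (illustrative node names only).
toy = nx.MultiDiGraph()
toy.add_edge("x", "add", operand=0)
toy.add_edge("x", "add", operand=1)  # a second, parallel edge is kept
toy.add_edge("add", "output")

summary = describe_graph(toy)
print(summary)
```

To use it on a real trace, pass `torchtrail.get_graph(output_tensor)` instead of `toy`. Converting the toy graph to a plain `networkx.DiGraph` would merge the two `x -> add` edges into one, which is plausibly one reason a multigraph suits operations that consume the same tensor more than once.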

Tracing a module

import torch
import transformers

import torchtrail

model_name = "google/bert_uncased_L-4_H-256_A-4"
config = transformers.BertConfig.from_pretrained(model_name)
config.num_hidden_layers = 1
model = transformers.BertModel.from_pretrained(model_name, config=config).eval()

with torchtrail.trace():
    input_tensor = torch.randint(0, model.config.vocab_size, (1, 64))
    output = model(input_tensor).last_hidden_state

torchtrail.visualize(output, max_depth=1, file_name="bert_max_depth_1.svg")

torchtrail.visualize(output, max_depth=2, file_name="bert_max_depth_2.svg")

The graph of the full module can be visualized by omitting the max_depth argument:

torchtrail.visualize(output, file_name="bert.svg")

The graph can be obtained as a networkx.MultiDiGraph using torchtrail.get_graph:

graph: "networkx.MultiDiGraph" = torchtrail.get_graph(output)

Alternatively, visualization of the modules can be turned off completely using show_modules=False:

torchtrail.visualize(output, show_modules=False, file_name="bert_show_modules_False.svg")

The flattened graph can be obtained as a networkx.MultiDiGraph using torchtrail.get_graph with flatten=True:

graph: "networkx.MultiDiGraph" = torchtrail.get_graph(output, flatten=True)
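Because the flattened graph is an ordinary networkx.MultiDiGraph, standard graph algorithms apply to it. As a sketch, the helper below finds the longest chain of dependent nodes, a rough measure of how deep the traced computation is. It assumes the graph is acyclic, and the toy residual-style graph uses hypothetical node names rather than torchtrail's real node keys.

```python
import networkx as nx


def longest_chain(graph: nx.MultiDiGraph) -> list:
    """Return the longest path (counted in nodes) through a DAG."""
    best_len = {}  # node -> length of the longest path ending at that node
    parent = {}    # node -> predecessor on that longest path
    for node in nx.topological_sort(graph):
        best_len[node] = 1
        parent[node] = None
        # predecessors() yields each predecessor once, even if edges are parallel.
        for pred in graph.predecessors(node):
            if best_len[pred] + 1 > best_len[node]:
                best_len[node] = best_len[pred] + 1
                parent[node] = pred
    if not best_len:
        return []
    # Walk back from the node where the longest path ends.
    node = max(best_len, key=best_len.get)
    path = []
    while node is not None:
        path.append(node)
        node = parent[node]
    return path[::-1]


# Hypothetical residual-style block: output = gelu(linear(input)) + input.
toy = nx.MultiDiGraph()
toy.add_edge("input", "linear")
toy.add_edge("linear", "gelu")
toy.add_edge("gelu", "add")
toy.add_edge("input", "add")  # skip connection
toy.add_edge("add", "output")

chain = longest_chain(toy)
print(chain)
```

For a real trace, `longest_chain(torchtrail.get_graph(output, flatten=True))` would apply the same logic; networkx's built-in `nx.dag_longest_path` is an alternative to the hand-written loop.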

Reference

  • torchtrail was inspired by torchview. mert-kurttutan did an amazing job of displaying torch graphs. However, one of the goals of torchtrail was to produce networkx-compatible graphs, so torchtrail was written.
  • The idea of using a persistent MultiDiGraph to trace torch operations was taken from composit.
