
fix lint
masahi committed Feb 26, 2020
1 parent 06f6ee6 commit b8d334c
Showing 1 changed file with 18 additions and 4 deletions.
22 changes: 18 additions & 4 deletions python/tvm/relay/frontend/pytorch.py
@@ -21,7 +21,6 @@
 import itertools
 from packaging import version

-import torch
 import numpy as np

 import tvm
@@ -735,6 +734,8 @@ def _convert_elemwise_input(data, input_type):


 def run_jit_passes(graph):
+    """ The inline pass is necessary to unwrap prim::CallMethod """
+    import torch
     if version.parse(torch.__version__) >= version.parse("1.4.0"):
         torch._C._jit_pass_inline(graph)
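
(Not part of the commit.) A minimal, hedged sketch of how this pass would be exercised, assuming it is called from inside this module on a traced model's graph; the toy module and input shape below are made up for illustration.

    import torch
    import torch.nn as nn

    model = torch.jit.trace(nn.Sequential(nn.Linear(3, 4), nn.ReLU()).eval(),
                            torch.randn(1, 3))
    run_jit_passes(model.graph)   # inlines prim::CallMethod nodes on torch >= 1.4.0
    print(model.graph)            # the graph after inlining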

@@ -776,10 +777,11 @@ def update_outputs_from_pairs(name_output_pairs, outputs, output_index_map):

 def get_all_op_names(graph):
     nodes = list(graph.nodes())
-    return set([node.kind() for node in nodes])
+    return set(node.kind() for node in nodes)


 def report_missing_conversion(op_names):
+    """Check if all ops in an input graph are supported by TVM"""
     known_ops = ["prim::Constant", "prim::GetAttr",
                  "prim::ListConstruct", "prim::ListUnpack",
                  "prim::TupleConstruct", "prim::TupleUnpack"]
@@ -795,7 +797,7 @@ def report_missing_conversion(op_names):

 def getattr_attr_name(node):
     attribute_names = node.attributeNames()
-    assert(len(attribute_names) == 1)
+    assert len(attribute_names) == 1
     attr_name = node.s(attribute_names[0])
     return attr_name
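
(Hedged illustration with the same toy model.) Each prim::GetAttr node carries exactly one attribute, the field name this helper returns.

    getattrs = model.graph.findAllNodes("prim::GetAttr", recurse=True)
    if getattrs:
        print(getattr_attr_name(getattrs[0]))   # a field name such as "weight"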

@@ -805,6 +807,10 @@ def get_full_attr_name(getattrs):


 def get_use_chains(root_node, terminate=lambda _: False):
+    """
+    Track a chain of users of this node forward, returning a list of chains.
+    See get_attr_chains below for its usage.
+    """
     def concat_lists(lists):
         return itertools.chain.from_iterable(lists)

@@ -841,6 +847,7 @@ def terminate(users):


 def get_input_types(op_node):
+    """Returns a torch type for each input node"""
     input_list_types = []
     for input_node in op_node.inputs():
         in_ty = input_node.type()
@@ -868,14 +875,15 @@ def get_input_types(op_node):


 def get_constant(node):
+    """ Retrieve a constant associated with this prim::Constant node"""
     attribute_names = node.attributeNames()
     num_attributes = len(attribute_names)

     if num_attributes == 1:
         attr_name = attribute_names[0]
         ty = node.output().type().kind()

-        if ty == "IntType" or ty == "BoolType":
+        if ty in ["IntType", "BoolType"]:
             return node.i(attr_name)
         elif ty in ["FloatType", "LongType"]:
             return node.f(attr_name)
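
(Hedged sketch, not part of the commit.) Pulling a Python value out of the first prim::Constant node of the toy graph above; which value comes back depends on the type of the constant node.

    consts = model.graph.findAllNodes("prim::Constant", recurse=True)
    if consts:
        print(get_constant(consts[0]))   # an int/bool, float, or other value per the branches above
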
@@ -896,6 +904,7 @@ def get_constant(node):


 def parse_inputs(graph_inputs, input_shapes):
+    """ Return Relay vars from torch input vars"""
     ir_inputs = list(graph_inputs)
     input_vars = {}
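
(Hedged usage sketch.) This assumes input_shapes maps an input name to its shape; the name "input0" and the shape are made up, and the toy model from above is reused.

    input_shapes = {"input0": (1, 3)}                                # hypothetical input name and shape
    input_vars = parse_inputs(model.graph.inputs(), input_shapes)    # Relay vars keyed by input name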

@@ -907,6 +916,10 @@ def parse_inputs(graph_inputs, input_shapes):


 def parse_params(graph, state_dict):
+    """
+    Return Relay vars and TVM NDArrays for input parameters.
+    A chain of prim::GetAttr nodes is processed one at a time.
+    """
     getattr_nodes = graph.findAllNodes("prim::GetAttr", recurse=True)
     params = {}
     param_tensors = {}
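
(Hedged sketch.) Assuming, per the docstring, that the function returns the Relay vars and the TVM NDArrays as two dicts keyed by parameter name:

    relay_params, param_tensors = parse_params(model.graph, model.state_dict())
    print(list(param_tensors))   # e.g. ["0.weight", "0.bias"] for the toy nn.Sequential
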
@@ -933,6 +946,7 @@ def parse_params(graph, state_dict):


 def parse_ops(nodes):
+    """ Returns torch IR nodes that need conversion to Relay """
     ops = {}
     # Traverse nodes and add to graph
     for node in nodes:
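
(Last hedged sketch.) Handing the graph's node iterator to this helper collects the ops that still need converting.

    ops = parse_ops(model.graph.nodes())   # torch IR nodes that need conversion to Relay, per the docstring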
