Skip to content

Commit

Permalink
Implement quant/dequant partitioning
Browse files Browse the repository at this point in the history
on our way

get clooooooser

clean up (part 1)

clean up (part 2)

clean up (part 3)

clean up (part 4)

clean clean

cleaanaannanaaananaananaananaan

clkjsdflkjlfsjdflkj

revert parser changes

add docs

roll lint

roll lint
  • Loading branch information
weberlo committed Aug 10, 2020
1 parent b29f79e commit 53deebc
Show file tree
Hide file tree
Showing 7 changed files with 640 additions and 2 deletions.
47 changes: 47 additions & 0 deletions python/tvm/relay/analysis/analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,9 @@
from tvm.ir import IRModule
from tvm.relay import transform, build_module
from tvm.runtime.ndarray import cpu
# TODO(weberlo) remove when we port dtype collectors to C++
from tvm.relay.expr_functor import ExprVisitor
from tvm.relay.type_functor import TypeVisitor

from . import _ffi_api
from .feature import Feature
Expand Down Expand Up @@ -236,6 +239,50 @@ def all_type_vars(expr, mod=None):
return _ffi_api.all_type_vars(expr, use_mod)


class TyDtypeCollector(TypeVisitor):
"""Pass that collects data types used in the visited type."""

def __init__(self):
TypeVisitor.__init__(self)
self.dtypes = set()

def visit_tensor_type(self, tt):
self.dtypes.add(tt.dtype)


class ExprDtypeCollector(ExprVisitor):
"""Pass that collects data types used in all types in the visited expression."""

def __init__(self):
ExprVisitor.__init__(self)
self.ty_visitor = TyDtypeCollector()

def visit(self, expr):
if hasattr(expr, 'checked_type'):
self.ty_visitor.visit(expr.checked_type)
elif hasattr(expr, 'type_annotation'):
self.ty_visitor.visit(expr.type_annotation)
ExprVisitor.visit(self, expr)


def all_dtypes(expr):
"""Collect set of all data types used in `expr`.
Parameters
----------
expr : tvm.relay.Expr
The input expression
Returns
-------
ret : Set[String]
Set of data types used in the expression
"""
dtype_collector = ExprDtypeCollector()
dtype_collector.visit(expr)
return dtype_collector.ty_visitor.dtypes


def collect_device_info(expr):
"""Collect the device allocation map for the given expression. The device
ids are propagated from the `device_copy` operators.
Expand Down
Loading

0 comments on commit 53deebc

Please sign in to comment.