[Discuss] Extract Op Info from Primfunc #278
Comments
Some thoughts on this problem: To guarantee the soundness of graph-level layout transformation, we need to be able to infer the new layout-sensitive attributes for all ops, 100% reliably. That might be difficult, especially at the beginning of development. To guarantee the soundness while making gradual development possible, the TIR-level transformation pass can materialize This way, we don't have to worry about tricky ops like
That could be an interesting direction @masahi!
Interesting question. For backends that can do their own layout transform internally (DNNL), the TVM-side layout transform is always optional (it only improves performance), so pattern matching is agnostic to layouts. Other backends (CUTLASS) expect the right layout for pattern matching to succeed, so we need to break the graph there. But I expect there would be no need to "infer" the new attributes for most compute-intensive ops that we want to offload to BYOC, since their layouts are typically fixed by users. We only need to worry about layout-sensitive ops in between, like reduction and other shape-changing ops, that might not be offloaded to BYOC anyway.
Are there any plans to add support for extracting op info as mentioned here at some point? Was there a final decision on how this is going to be supported?
Hello @quic-sanirudh!
Yes, extracting op info would be supported. As mentioned in the comments above, @masahi and @sunggg have also laid out some of the possible approaches. A lot of details need to be figured out still.
It is going to be supported, but the design of how exactly this would work has not been decided yet. If you are interested in this problem, please feel free to start a discussion on the design here or in a separate thread.
Thanks @psrivas2 for the quick reply. I was curious about how this would work in the presence of fusion. Basically, if we extract the op info before fusion, we have to assume that it will only be valid until fusion is performed, or have some way to iterate through the attributes of each individual op that is part of a fused prim_func. I'll think a bit more about this and explain it with an example.
Looking forward to the example.
@psrivas2, thanks for the reply. Actually, my question is not related to just layout transformation. Please correct me if I'm mistaken here, but I thought the point of this op info extraction is to make these op-specific attributes available during other transformations. For example, if we have the Say, for example, a user would like to write a new schedule_rule that targets a particular type of op, so that the number of tiles can be decided based on its attributes; that might turn out to be really useful (just a random thought, I don't have a concrete example yet). My idea was, if we need something like that, we might need to extract op attributes through a pass before fusion, and retain them in some way after fusion, perhaps in the form of attributes to that fused
1. Motivation
Problem
Currently, once we lower an op to its primfunc implementation, it is hard to exploit op-level info (e.g., op name, op kind, op attributes), although the primfunc is supposed to contain it. This has been fine in Relay, since the pipeline lowers abstractions strictly in one direction, allowing one abstraction at a time.
In Relax, we are unlocking interaction between different abstraction levels. The new design of TIR-level layout planning is a good example: by manipulating the graph level and the TIR level at the same time, we could eliminate the need for InferCorrectLayout, which has been a source of complexity and issues. However, this requires lowering to happen before layout planning, and the loss of convenient op-level information during lowering makes the BYOC mechanism difficult. For instance, the following snippet shows how TensorRT BYOC converts the Relay/Relax conv2d op to its TensorRT equivalent by using op-level info (e.g., the op name and its attributes, such as data_layout, strides, etc.). This info may not be easily accessible in the current primfunc design.
Goal
To solve such problems, that is, to benefit from TIR-level planning while still supporting BYOC, this doc investigates whether the op-level info can be accessed at the TIR level in a convenient form. Specifically, this op-level info includes:
- Operator name (e.g., conv2d)
- Operator kind (e.g., kElemWise)
- Operator attributes (e.g., axis, padding, …)
Please note that tir::PatternKindAnalyzer in Relax is already able to deduce the operator kind from the TIR primfunc. This doc examines whether a similar approach is achievable for the other info.
At the end of the day, we may provide a convenient interface to access this info. Although this doc does not discuss its best design, a couple of options are:
- O1: embed op info in the primfunc
- O2: provide an API like tir::PatternKindAnalyzer
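To make the difference between O1 and O2 concrete, here is a minimal sketch in plain Python. It does not use TVM: ToyPrimFunc, lower_with_op_info, and analyze_kind are hypothetical names invented for illustration, and the "body" is reduced to tuples of loop-variable names. O1 stores the info when the function is created; O2 recomputes it from the body on demand.

```python
from dataclasses import dataclass, field
from typing import Any, Dict, List, Tuple

Indices = Tuple[str, ...]  # loop-variable names used to index a buffer


@dataclass
class ToyPrimFunc:
    """Hypothetical stand-in for a lowered function (not a TVM class)."""
    write_indices: Indices
    read_indices: List[Indices]
    attrs: Dict[str, Any] = field(default_factory=dict)


def lower_with_op_info(op_name, op_attrs, write_indices, read_indices):
    """O1: embed the op name and attributes while lowering."""
    attrs = {"op_name": op_name, "op_attrs": dict(op_attrs)}
    return ToyPrimFunc(tuple(write_indices), [tuple(r) for r in read_indices], attrs)


def analyze_kind(func):
    """O2: derive info from the body instead of reading stored attributes.

    Toy rule: if every read uses exactly the write indices, call it
    element-wise; anything else is reported as "other".
    """
    if all(read == func.write_indices for read in func.read_indices):
        return "kElemWise"
    return "other"


# out[i, j] = a[i, j] + b[i, j]
add_func = lower_with_op_info("add", {}, ("i", "j"), [("i", "j"), ("i", "j")])
# out[i, j] += data[i, j, k]
sum_func = lower_with_op_info("sum", {"axis": 2}, ("i", "j"), [("i", "j", "k")])
```

With O1 the info is only as fresh as the last pass that maintained it, while O2 stays correct after the body is rewritten but needs one analysis per kind of info.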
2. Findings
Operator Name
This can be obtained during the lowering and easily annotated in primfunc.
Operator Kinds
Already supported by tir::PatternKindAnalyzer in Relax.
Operator Attributes
TVM uses the attributes to lower each operator into a valid implementation. Therefore, this section assumes that the primfunc implementation embeds the attribute information in some way and examines whether we can extract it. Since a layout transformation at the TIR level might affect some attributes (we call these layout-sensitive attributes), we also look into which attributes should be updated after a layout transformation.
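As a rough illustration of what "updating a layout-sensitive attribute" could mean, the sketch below remaps an axis index and a layout string under a dimension permutation and leaves everything else alone. It is plain Python, not TVM code: update_attrs and the key sets axis_keys and layout_keys are assumptions made for this example, and only pure permutations (no tiling or splitting of dimensions) are handled.

```python
def remap_axis(axis, perm):
    """Map an old axis index to its new position.

    `perm` lists, for each new dimension, the old dimension it comes from
    (e.g. NCHW -> NHWC is perm = (0, 2, 3, 1)).
    """
    axis %= len(perm)  # normalize negative axes
    return perm.index(axis)


def remap_layout(layout, perm):
    """Permute a one-letter-per-dimension layout string such as "NCHW"."""
    return "".join(layout[old_dim] for old_dim in perm)


def update_attrs(attrs, perm, axis_keys=("axis",), layout_keys=("layout",)):
    """Return a copy of `attrs` with layout-sensitive entries updated.

    Which keys count as layout-sensitive is an assumption of this sketch;
    all other entries (e.g. epsilon) are copied through unchanged.
    """
    updated = {}
    for key, value in attrs.items():
        if key in axis_keys:
            updated[key] = remap_axis(value, perm)
        elif key in layout_keys:
            updated[key] = remap_layout(value, perm)
        else:
            updated[key] = value
    return updated


nchw_to_nhwc = (0, 2, 3, 1)
new_attrs = update_attrs(
    {"axis": 1, "layout": "NCHW", "epsilon": 1e-5}, nchw_to_nhwc
)
```

Here the channel axis moves from position 1 to position 3 and the layout string becomes "NHWC", while epsilon is untouched.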
Case Study
Representative Ops w/o Attributes
- nn.relu
- add, subtract, maximum, minimum
Representative Ops w/ Attributes
- Reduction family: sum
  - axis: find the reduction axis in T.reads. In this example, axis=2
  - keepdims, exclude
- nn.bias_add
  - axis: find the first dim for expand_dims. In this example, axis=1
- nn.upsampling
  - layout
  - scale_h, scale_w, method, align_corners
- nn.conv2d
  - data_layout, kernel_layout, out_layout
  - strides, padding, dilation, channels: may be affected by the tiling
  - groups
- nn.dense
  - units: ???
- strided_slice
  - begin, end, strides, axes: these four params decide how to slice each axis. They need to be updated when the layout changes the axis info.
  - slice_mode
- nn.batch_norm
  - axis: find the channel dimension
  - epsilon, center, scale
- nn.max_pool2d
  - layout, out_layout
  - pool_size, strides, dilation, padding, ceil_mode
- transpose
  - axes: compare the input buffer and the new axis mapping
- nn.pad
  - pad_width: compare T.grid and the new axis mapping
  - pad_mode
- reshape
  - newshape: see T.grid
  - allowzero
- nn.split
  - axis: look at T.reads and find an axis with extra offset in indices_or_sections. In this example, axis=1
  - indices_or_sections
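Two of the recovery rules in the list above (the reduction axis for sum and axes for transpose) can be sketched on a toy representation where a buffer access is just a tuple of loop-variable names. This is a hedged illustration only: reduction_axes and transpose_axes are hypothetical helpers, real T.reads regions carry ranges and arbitrary index expressions, and the reduction rule assumes keepdims=False.

```python
def reduction_axes(read_indices, write_indices):
    """Input positions whose loop variable never reaches the output.

    For out[i, j] += data[i, j, k], the variable k is absent from the
    write, so position 2 of the read is the reduced axis.
    """
    written = set(write_indices)
    return [pos for pos, var in enumerate(read_indices) if var not in written]


def transpose_axes(read_indices, write_indices):
    """Recover `axes` such that output dim n comes from input dim axes[n].

    For out[i0, i1, i2] = data[i2, i0, i1], output dim 0 uses i0, which
    sits at input position 1, and so on.
    """
    return [read_indices.index(var) for var in write_indices]


# sum over the last axis of a 3-D tensor
sum_axes = reduction_axes(("i", "j", "k"), ("i", "j"))

# transpose whose body reads data[i2, i0, i1] to write out[i0, i1, i2]
perm_axes = transpose_axes(("i2", "i0", "i1"), ("i0", "i1", "i2"))
```

In this toy form both analyses are simple comparisons between the read and write index tuples; attributes such as keepdims or exclude would need additional information.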
Summary
- Layout-sensitive attributes (e.g., layout, axis, padding): these would require an extension in the layout transformation to update them properly. In most cases, it seems quite straightforward how to update them.
- Attributes that are not layout-sensitive (e.g., epsilon, slice_mode): these are not affected by layout transformation. We can simply keep the information obtained during lowering.
3. Suggestion for Relax Layout Planner
With access to op-level info in the primfunc, there can be two options to make the Relax layout planner work with BYOC:
- … InferCorrectLayout) by peeking into the primfunc to perform TIR-based analysis
- … PrimFunc implementation for an operator