This repository has been archived by the owner on May 22, 2023. It is now read-only.

[Discuss] Extract Op Info from Primfunc #278

Open
sunggg opened this issue Nov 16, 2022 · 8 comments

@sunggg
Collaborator

sunggg commented Nov 16, 2022

1. Motivation

Problem

Currently, once an op is lowered to its PrimFunc implementation, it is hard to exploit op-level info (e.g., op name, op kind, op attributes), even though the PrimFunc is supposed to encode them. This has been fine in Relay because its pipeline lowers abstractions strictly in one direction, handling one abstraction level at a time.

In Relax, we are unlocking interaction between different abstraction levels. The new design of TIR-level layout planning is a good example: by manipulating the graph level and the TIR level at the same time, we could eliminate the need for InferCorrectLayout, which has been a source of complexity and issues. However, this requires lowering to happen before layout planning, and the loss of convenient op-level information during lowering makes the BYOC mechanism difficult. For instance, the following snippet shows how TensorRT BYOC converts a Relay/Relax conv2d op to its TensorRT equivalent by using op-level info (e.g., the op name and its attributes, such as data_layout, strides, etc.). This info may not be easily accessible in the current PrimFunc design.

class Conv2DOpConverter : public TensorRTOpConverter {
 public:
  // ....
  void Convert(TensorRTOpConverterParams* params) const {
    auto input_tensor = params->inputs.at(0).tensor;
    auto input_dims = TrtDimsToVector(input_tensor->getDimensions());
    auto weight_shape = params->inputs.at(1).weight_shape;
    ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("data_layout")[0], "NCHW");
    ICHECK(params->node.GetAttr<std::vector<std::string>>("out_layout")[0] == "" ||
           params->node.GetAttr<std::vector<std::string>>("out_layout")[0] == "NCHW");
    ICHECK_EQ(params->node.GetAttr<std::vector<std::string>>("kernel_layout")[0], "OIHW");
    auto str_strides = params->node.GetAttr<std::vector<std::string>>("strides");
    auto str_dilation = params->node.GetAttr<std::vector<std::string>>("dilation");
    auto str_padding = params->node.GetAttr<std::vector<std::string>>("padding");
    int groups = std::stoi(params->node.GetAttr<std::vector<std::string>>("groups")[0]);
    int channels = weight_shape[0];
    if (params->node.HasAttr("channels") &&
        !params->node.GetAttr<std::vector<std::string>>("channels")[0].empty()) {
      channels = std::stoi(params->node.GetAttr<std::vector<std::string>>("channels")[0]);
    }
    // ...
    const auto kernel_size = nvinfer1::DimsHW(weight_shape[2], weight_shape[3]);
    const nvinfer1::DataType weight_type = params->inputs.at(1).weight.type;
    nvinfer1::Weights bias{weight_type, nullptr, 0};
    auto conv_layer = params->network->addConvolution(*input_tensor, channels, kernel_size,
                                                      params->inputs.at(1).weight, bias);
    //...
  }
};

Goal

To solve such problems (e.g., to benefit from TIR-level planning while still supporting BYOC), this doc investigates whether op-level info can be accessed at the TIR level in a convenient form. Specifically, this op-level info includes:

  • operator name (e.g., conv2d)
  • operator kind (e.g., kElemWise)
  • attributes (e.g., axis, padding, …)

Note that tir::PatternKindAnalyzer in Relax can already deduce the operator kind from a TIR PrimFunc. This doc examines whether a similar approach works for the other info.

Ultimately, we may provide a convenient interface to access this info. Although this doc does not discuss the best design, a couple of options are:

  • O1: embed op info in the primfunc

    @T.prim_func
    def lowered_primfunc(...):
      T.func_attr({
        "op_info": {
          # `op name`
          "name": "conv2d",
          # `op kind` is omitted since it can be deduced by tir::PatternKindAnalyzer
          # `op attributes`
          "strides": [1, 1],
          "padding": [1, 1],
          "groups": 1,
        }
      })
  • O2: provide an analysis API, similar to tir::PatternKindAnalyzer, that deduces the info from the PrimFunc body
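To illustrate how a BYOC converter might consume option O1, here is a minimal sketch. It assumes the `op_info` dict from the O1 example has already been read out of the PrimFunc attributes into a plain Python dict; the helper `read_conv2d_op_info` and the fallback to the Conv2DAttrs defaults are hypothetical, not an existing TVM API.

```python
from typing import Any, Dict

# Conv2DAttrs defaults, used when a field is absent from `op_info` (assumption).
CONV2D_DEFAULTS: Dict[str, Any] = {
    "strides": [1, 1],
    "padding": [0, 0],
    "dilation": [1, 1],
    "groups": 1,
    "data_layout": "NCHW",
    "kernel_layout": "OIHW",
    "out_layout": "",
}


def read_conv2d_op_info(op_info: Dict[str, Any]) -> Dict[str, Any]:
    """Hypothetical BYOC-side helper: collect conv2d attributes from an
    `op_info` dict (option O1) instead of querying a Relay/Relax call node."""
    if op_info.get("name") != "conv2d":
        raise ValueError(f"expected conv2d, got {op_info.get('name')!r}")
    attrs = {key: op_info.get(key, default) for key, default in CONV2D_DEFAULTS.items()}
    # Mirror the layout checks performed by the TensorRT Conv2DOpConverter.
    if attrs["data_layout"] != "NCHW" or attrs["kernel_layout"] != "OIHW":
        raise ValueError("only NCHW data / OIHW kernel layouts are handled here")
    if attrs["out_layout"] not in ("", "NCHW"):
        raise ValueError("only NCHW output layout is handled here")
    return attrs


# The `op_info` payload from the O1 example.
op_info = {"name": "conv2d", "strides": [1, 1], "padding": [1, 1], "groups": 1}
print(read_conv2d_op_info(op_info))
```

Compared with the GetAttr calls in Conv2DOpConverter, the converter only needs a dict lookup, and missing fields fall back to the same defaults the attribute struct declares.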

2. Findings

Operator Name

This can be obtained during lowering and easily annotated in the PrimFunc.

Operator Kinds

Already supported by tir::PatternKindAnalyzer in Relax.

Operator Attributes

TVM uses an operator's attributes to lower it into a valid implementation. Therefore, this section assumes that the PrimFunc implementation embeds the attribute information in some form and examines whether we can extract it. Since a layout transformation at the TIR level might affect some attributes (we call these layout-sensitive attributes), we also look into which attributes must be updated accordingly when the layout is transformed.

Case Study

Representative Ops w/o Attributes

  • nn.relu
  • add, subtract, maximum, minimum
  • etc.

Representative Ops w/ Attributes

  • Reduction family: sum

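    struct ReduceAttrs : public tvm::AttrsNode<ReduceAttrs> {
      Array<Integer> axis;
      bool keepdims;
      bool exclude;
    
      TVM_DECLARE_ATTRS(ReduceAttrs, "relay.attrs.ReduceAttrs") {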
        TVM_ATTR_FIELD(axis)
            .set_default(NullValue<Array<Integer>>())
            .describe(R"code(The axis or axes along which to perform the reduction.
    
          The default, `axis=()`, will compute over all elements into a
          scalar array with shape `(1,)`.
    
          If `axis` is int, a reduction is performed on a particular axis.
    
          If `axis` is a tuple of ints, a reduction is performed on all the axes
          specified in the tuple.
    
          If `exclude` is true, reduction will be performed on the axes that are
          NOT in axis instead.)code");
    
        TVM_ATTR_FIELD(keepdims).set_default(false).describe(
            "If this is set to `True`, the reduced axes are left "
            "in the result as dimension with size one.");
        TVM_ATTR_FIELD(exclude).set_default(false).describe(
            "Whether to perform reduction on axis that are NOT in axis instead.");
      }
    };
    @T.prim_func
    def sum(rxplaceholder: T.Buffer[(T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"], rxplaceholder_red: T.Buffer[(T.int64(1), T.int64(64), T.int64(56)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "sum", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(64), T.int64(56), T.int64(56)):
            with T.block("rxplaceholder_red"):
                ax0 = T.axis.spatial(T.int64(1), i0)
                ax1 = T.axis.spatial(T.int64(64), i1)
                ax2 = T.axis.spatial(T.int64(56), i2)
                k2 = T.axis.reduce(T.int64(56), i3)
                T.reads(rxplaceholder[ax0, ax1, k2, ax2])
                T.writes(rxplaceholder_red[ax0, ax1, ax2])
                with T.init():
                    rxplaceholder_red[ax0, ax1, ax2] = T.float32(0)
                rxplaceholder_red[ax0, ax1, ax2] = rxplaceholder_red[ax0, ax1, ax2] + rxplaceholder[ax0, ax1, k2, ax2]
    • Layout-sensitive attributes
      • axis: the position of the reduction block var (k2) in the T.reads index of the input buffer. In this example, axis=2
    • Layout-insensitive attributes: keepdims, exclude
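The axis deduction for sum can be sketched in plain Python. This is a minimal model, not a TVM analysis: it assumes each input dimension is indexed by a single block var, and that the index expressions from T.reads and the set of vars declared with T.axis.reduce have already been extracted as strings. The helper names are hypothetical, and `infer_keepdims` additionally assumes exclude=False.

```python
from typing import List, Sequence, Set


def infer_reduce_axes(read_indices: Sequence[str], reduce_vars: Set[str]) -> List[int]:
    """Return the input dimensions whose index is a reduction block var."""
    return [dim for dim, index in enumerate(read_indices) if index in reduce_vars]


def infer_keepdims(in_rank: int, out_rank: int) -> bool:
    """Assumes exclude=False: reduced dims are kept iff the rank is unchanged."""
    return in_rank == out_rank


# From the `sum` PrimFunc: T.reads(rxplaceholder[ax0, ax1, k2, ax2]),
# where k2 is the only var declared with T.axis.reduce.
axes = infer_reduce_axes(["ax0", "ax1", "k2", "ax2"], {"k2"})
keepdims = infer_keepdims(in_rank=4, out_rank=3)
print(axes, keepdims)  # [2] False
```

If a layout transformation permutes the input buffer, re-running the same lookup on the transformed T.reads index yields the updated axis, which is what makes axis layout-sensitive.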
  • nn.bias_add

    struct BiasAddAttrs : public tvm::AttrsNode<BiasAddAttrs> {
      int axis;
    
      TVM_DECLARE_ATTRS(BiasAddAttrs, "relay.attrs.BiasAddAttrs") {
        TVM_ATTR_FIELD(axis).describe("The axis to add the bias").set_default(1);
      }
    };
    @T.prim_func
    def expand_dims(rxplaceholder: T.Buffer[T.int64(64), "float32"], T_expand_dims: T.Buffer[(T.int64(64), 1, 1), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "expand_dims", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2 in T.grid(T.int64(64), 1, 1):
            with T.block("T_expand_dims"):
                ax0 = T.axis.spatial(T.int64(64), i0)
                ax1, ax2 = T.axis.remap("SS", [i1, i2])
                T.reads(rxplaceholder[ax0])
                T.writes(T_expand_dims[ax0, ax1, ax2])
                T_expand_dims[ax0, ax1, ax2] = rxplaceholder[ax0]
    
    @T.prim_func
    def add(rxplaceholder: T.Buffer[(T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"], rxplaceholder_1: T.Buffer[(T.int64(64), 1, 1), "float32"], T_add: T.Buffer[(T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "add", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(64), T.int64(56), T.int64(56)):
            with T.block("T_add"):
                ax0 = T.axis.spatial(T.int64(1), i0)
                ax1 = T.axis.spatial(T.int64(64), i1)
                ax2 = T.axis.spatial(T.int64(56), i2)
                ax3 = T.axis.spatial(T.int64(56), i3)
                T.reads(rxplaceholder[ax0, ax1, ax2, ax3], rxplaceholder_1[ax1, T.int64(0), T.int64(0)])
                T.writes(T_add[ax0, ax1, ax2, ax3])
                T_add[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3] + rxplaceholder_1[ax1, T.int64(0), T.int64(0)]
    
    @R.function
    def main(x: Tensor((1, 64, 56, 56), "float32"), bias: Tensor((64,), "float32")) -> Tensor(None, "float32", ndim = 4):
        # block 0
        with R.dataflow():
            lv = R.call_tir(expand_dims, (bias,), (64, 1, 1), dtype="float32")
            lv1 = R.call_tir(add, (x, lv), (1, 64, 56, 56), dtype="float32")
            gv: Tensor((1, 64, 56, 56), "float32") = lv1
            R.output(gv)
        return gv
    • Layout-sensitive attributes
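      • axis: the data dimension that the bias is broadcast along. In `add`, the expanded bias is read as rxplaceholder_1[ax1, T.int64(0), T.int64(0)], i.e., its first (non-unit) dim is indexed by ax1, so axis=1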
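A minimal sketch of that lookup, assuming the read indices of the data tensor and of the expanded bias in `add` are available as lists of strings (constant indices written as "0"). The helper name is hypothetical.

```python
from typing import Sequence


def infer_bias_add_axis(data_indices: Sequence[str], bias_indices: Sequence[str]) -> int:
    """Return the data dimension aligned with the bias's leading (non-unit) dim."""
    leading = bias_indices[0]
    if leading not in data_indices:
        raise ValueError(f"{leading!r} does not index the data tensor")
    return list(data_indices).index(leading)


# From the `add` PrimFunc: data read as rxplaceholder[ax0, ax1, ax2, ax3],
# expanded bias read as rxplaceholder_1[ax1, T.int64(0), T.int64(0)].
print(infer_bias_add_axis(["ax0", "ax1", "ax2", "ax3"], ["ax1", "0", "0"]))  # 1
```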
  • nn.upsampling

    struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
      double scale_h;
      double scale_w;
      tvm::String layout;
      tvm::String method;
      bool align_corners;
    
      TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
        TVM_ATTR_FIELD(scale_h).describe("The upsampling factor for height");
        TVM_ATTR_FIELD(scale_w).describe("The upsampling factor for width");
        TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
            "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
            "dimensions respectively. Upsampling is applied on the 'H' and"
            "'W' dimensions.");
        TVM_ATTR_FIELD(method)
            .set_default("nearest_neighbor")
            .describe(
                "Specify the mode to use for scaling."
                "nearest_neighbor -  Nearest Neighbor"
                "bilinear - Bilinear Interpolation"
                "bicubic - Bicubic Interpolation");
        TVM_ATTR_FIELD(align_corners)
            .set_default(false)
            .describe("Should be true to preserve the values at the corner pixels");
      }
    };
    @T.prim_func
    def upsampling(rxplaceholder: T.Buffer[(T.int64(64), T.int64(512), T.int64(10), T.int64(10)), "float32"], resize: T.Buffer[(T.int64(64), T.int64(512), T.int64(20), T.int64(20)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "upsampling", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2, i3 in T.grid(T.int64(64), T.int64(512), T.int64(20), T.int64(20)):
            with T.block("resize"):
                i0_1 = T.axis.spatial(T.int64(64), i0)
                i1_1 = T.axis.spatial(T.int64(512), i1)
                i2_1 = T.axis.spatial(T.int64(20), i2)
                i3_1 = T.axis.spatial(T.int64(20), i3)
                T.reads(rxplaceholder[i0_1, i1_1, T.max(T.min(T.cast(T.cast(i2_1 / T.int64(2), "int32"), "int64"), T.int64(9)), T.int64(0)), T.max(T.min(T.cast(T.cast(i3_1 / T.int64(2), "int32"), "int64"), T.int64(9)), T.int64(0))])
                T.writes(resize[i0_1, i1_1, i2_1, i3_1])
                resize[i0_1, i1_1, i2_1, i3_1] = rxplaceholder[i0_1, i1_1, T.max(T.min(T.cast(T.cast(i2_1 / T.int64(2), "int32"), "int64"), T.int64(9)), T.int64(0)), T.max(T.min(T.cast(T.cast(i3_1 / T.int64(2), "int32"), "int64"), T.int64(9)), T.int64(0))]
    • Layout-sensitive attributes
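      • layout: the dimensions read through a scaled index (i2_1 / 2 and i3_1 / 2, i.e., dims 2 and 3) are H and W, which implies NCHW in this example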
    • Layout-insensitive attributes: scale_h, scale_w, method, align_corners
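A rough sketch of recovering the resized dims and the scale factors by comparing buffer shapes. This is an assumption-laden shortcut: a dimension with scale 1 is indistinguishable from an untouched one, so a real analysis would more likely inspect the scaled read index (i2_1 / 2) instead. The helper name is hypothetical, and method/align_corners are not handled.

```python
from typing import List, Sequence, Tuple


def infer_upsampling(in_shape: Sequence[int], out_shape: Sequence[int]) -> Tuple[List[int], List[float]]:
    """Return (resized dims, scale factors) by comparing input and output extents."""
    if len(in_shape) != len(out_shape):
        raise ValueError("rank mismatch")
    dims = [d for d, (i, o) in enumerate(zip(in_shape, out_shape)) if i != o]
    scales = [out_shape[d] / in_shape[d] for d in dims]
    return dims, scales


# Shapes from the `upsampling` PrimFunc.
dims, scales = infer_upsampling((64, 512, 10, 10), (64, 512, 20, 20))
print(dims, scales)  # [2, 3] [2.0, 2.0]
```

Resized dims 2 and 3 map to H and W, which is consistent with layout="NCHW", and the two scales give scale_h and scale_w.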
  • nn.conv2d

    struct Conv2DAttrs : public tvm::AttrsNode<Conv2DAttrs> {
      Array<IndexExpr> strides;
      Array<IndexExpr> padding;
      Array<IndexExpr> dilation;
      int groups;
      IndexExpr channels;
      Array<IndexExpr> kernel_size;
      tvm::String data_layout;
      tvm::String kernel_layout;
      tvm::String out_layout;
      tvm::String auto_scheduler_rewritten_layout;   // The layout after auto-scheduler's layout rewrite
      Array<PrimExpr> meta_schedule_original_shape;  // The original shape of the weights
      DataType out_dtype;
    
      TVM_DECLARE_ATTRS(Conv2DAttrs, "relay.attrs.Conv2DAttrs") {
        TVM_ATTR_FIELD(strides)
            .set_default(Array<IndexExpr>({1, 1}))
            .describe("Specifies the strides of the convolution.");
        TVM_ATTR_FIELD(padding)
            .set_default(Array<IndexExpr>({0, 0}))
            .describe(
                "If padding is non-zero, then the input is implicitly zero-padded"
                "Padding support both symmetric and asymmetric as"
                "one int : same padding used on all sides"
                "two int : bottom, right will use same padding as top, left"
                "four int : padding width in the order of (top, left, bottom, right)");
        TVM_ATTR_FIELD(dilation)
            .set_default(Array<IndexExpr>({1, 1}))
            .describe("Specifies the dilation rate to use for dilated convolution.");
        TVM_ATTR_FIELD(groups).set_default(1).describe(
            "Controls the connections between inputs and outputs."
            "At groups=1, all inputs are convolved to all outputs."
            "At groups=2, the operation becomes equivalent to having two convolution"
            "layers side by side, each seeing half the input channels, and producing"
            "half the output channels, and both subsequently concatenated.");
        TVM_ATTR_FIELD(channels)
            .describe(
                "The number of output channels in the convolution."
                " If it is not set, inferred by shape of the weight.")
            .set_default(NullValue<IndexExpr>());
        TVM_ATTR_FIELD(kernel_size)
            .describe("Specifies the dimensions of the convolution window.")
            .set_default(NullValue<Array<IndexExpr>>());
        TVM_ATTR_FIELD(data_layout)
            .set_default("NCHW")
            .describe(
                "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc."
                "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
                "dimensions respectively. Convolution is applied on the 'H' and"
                "'W' dimensions.");
        TVM_ATTR_FIELD(kernel_layout)
            .set_default("OIHW")
            .describe(
                "Dimension ordering of weight. Can be 'OIHW', 'OIHW16o16i', etc."
                "'O', 'I', 'H', 'W' stands for num_filter, input_channel, height, and width"
                "dimensions respectively.");
        TVM_ATTR_FIELD(out_layout)
            .set_default("")
            .describe(
                "Dimension ordering of output. Can be 'NCHW', 'NHWC', etc."
                "'N', 'C', 'H', 'W' stands for batch, channel, height, and width"
                "dimensions respectively. Default to be same as input layout.");
    
        // use 0 bits to indicate none.
        TVM_ATTR_FIELD(out_dtype)
            .set_default(NullValue<DataType>())
            .describe("Output data type, set to explicit type under mixed precision setting");
      }
    };
    @T.prim_func
    def conv2d(rxplaceholder: T.Buffer[(T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"], rxplaceholder_1: T.Buffer[(T.int64(256), T.int64(64), T.int64(5), T.int64(5)), "float32"], conv2d_nchw: T.Buffer[(1, 256, 54, 54), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "conv2d", "tir.noalias": True})
        # body
        # with T.block("root")
        pad_temp = T.alloc_buffer([T.int64(1), T.int64(64), T.int64(58), T.int64(58)], dtype="float32")
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(64), T.int64(58), T.int64(58)):
            with T.block("pad_temp"):
                i0_1 = T.axis.spatial(T.int64(1), i0)
                i1_1 = T.axis.spatial(T.int64(64), i1)
                i2_1 = T.axis.spatial(T.int64(58), i2)
                i3_1 = T.axis.spatial(T.int64(58), i3)
                T.reads(rxplaceholder[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1)])
                T.writes(pad_temp[i0_1, i1_1, i2_1, i3_1])
                pad_temp[i0_1, i1_1, i2_1, i3_1] = T.if_then_else(T.int64(1) <= i2_1 and i2_1 < T.int64(57) and T.int64(1) <= i3_1 and i3_1 < T.int64(57), rxplaceholder[i0_1, i1_1, i2_1 - T.int64(1), i3_1 - T.int64(1)], T.float32(0), dtype="float32")
        for i0, i1, i2, i3, i4, i5, i6 in T.grid(1, 256, 54, 54, 64, 5, 5):
            with T.block("conv2d_nchw"):
                nn, ff, yy, xx, rc, ry, rx = T.axis.remap("SSSSRRR", [i0, i1, i2, i3, i4, i5, i6])
                T.reads(pad_temp[nn, rc, yy + ry, xx + rx], rxplaceholder_1[ff, rc, ry, rx])
                T.writes(conv2d_nchw[nn, ff, yy, xx])
                T.block_attr({"workload":["conv2d_nchw.cuda", ["TENSOR", [1, 64, 56, 56], "float32"], ["TENSOR", [256, 64, 5, 5], "float32"], [1, 1], [1, 1, 1, 1], [1, 1], "float32"]})
                with T.init():
                    conv2d_nchw[nn, ff, yy, xx] = T.float32(0)
                conv2d_nchw[nn, ff, yy, xx] = conv2d_nchw[nn, ff, yy, xx] + pad_temp[nn, rc, yy + ry, xx + rx] * rxplaceholder_1[ff, rc, ry, rx]
    • Layout-sensitive attributes
      • data_layout, kernel_layout, out_layout
      • strides, padding, dilation, channels: may be affected when a layout transformation tiles (splits) the corresponding dimensions
    • Layout-insensitive attributes: groups
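A sketch of how strides, dilation and padding could be read off the conv2d PrimFunc for one spatial dimension. It assumes the affine coefficients of the input index (`yy + ry` means coefficient 1 for both vars) and the constant lower bound of the pad_temp if_then_else guard have already been extracted; all helper names are hypothetical. The output-size formula is used only as a consistency check against the output buffer shape.

```python
from typing import Dict, Tuple


def stride_and_dilation(coeffs: Dict[str, int], spatial_var: str, reduce_var: str) -> Tuple[int, int]:
    """For an input index `spatial_var * stride + reduce_var * dilation`,
    return (stride, dilation) from the coefficients of the two block vars."""
    return coeffs[spatial_var], coeffs[reduce_var]


def padding_from_bounds(in_extent: int, padded_extent: int, lower_bound: int) -> Tuple[int, int]:
    """`lower_bound` is the constant in the `lower_bound <= i` guard of pad_temp;
    it equals the padding before the data, the rest is padding after it."""
    return lower_bound, padded_extent - in_extent - lower_bound


def conv_out_extent(in_extent: int, pad: Tuple[int, int], kernel: int, stride: int, dilation: int) -> int:
    """Standard convolution output-size formula."""
    return (in_extent + pad[0] + pad[1] - dilation * (kernel - 1) - 1) // stride + 1


# H dimension of the conv2d PrimFunc: pad_temp[nn, rc, yy + ry, xx + rx],
# guard `T.int64(1) <= i2_1`, input H = 56, pad_temp H = 58, kernel H = 5.
stride, dilation = stride_and_dilation({"yy": 1, "ry": 1}, "yy", "ry")
pad = padding_from_bounds(56, 58, lower_bound=1)
print(stride, dilation, pad)  # 1 1 (1, 1)
assert conv_out_extent(56, pad, kernel=5, stride=stride, dilation=dilation) == 54
```

The recovered values agree with the block's workload attribute (strides [1, 1], padding [1, 1, 1, 1], dilation [1, 1]).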
  • nn.dense

    /*! \brief Attributes for dense operator */
    struct DenseAttrs : public tvm::AttrsNode<DenseAttrs> {
      IndexExpr units;
      tvm::String auto_scheduler_rewritten_layout;   // The layout after auto-scheduler's layout rewrite
      Array<PrimExpr> meta_schedule_original_shape;  // The original shape of the weights
      DataType out_dtype;
    
      TVM_DECLARE_ATTRS(DenseAttrs, "relay.attrs.DenseAttrs") {
        TVM_ATTR_FIELD(units).describe("Number of hidden units of the dense transformation.");
    
        // use 0 bits to indicate none.
        TVM_ATTR_FIELD(out_dtype)
            .set_default(NullValue<DataType>())
            .describe("Output data type, set to explicit type under mixed precision setting");
      }
    };
    @T.prim_func
    def dense(rxplaceholder: T.Buffer[(T.int64(1), T.int64(8)), "float32"], rxplaceholder_1: T.Buffer[(T.int64(16), T.int64(8)), "float32"], T_matmul_NT: T.Buffer[(T.int64(1), T.int64(16)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "dense", "tir.noalias": True, "layout_free_buffers": [1]})
        # body
        # with T.block("root")
        for i0, i1, i2 in T.grid(T.int64(1), T.int64(16), T.int64(8)):
            with T.block("T_matmul_NT"):
                i = T.axis.spatial(T.int64(1), i0)
                j = T.axis.spatial(T.int64(16), i1)
                k = T.axis.reduce(T.int64(8), i2)
                T.reads(rxplaceholder[i, k], rxplaceholder_1[j, k])
                T.writes(T_matmul_NT[i, j])
                T.block_attr({"layout_free_placeholders":[], "workload":["dense_small_batch.gpu", ["TENSOR", [1, 8], "float32"], ["TENSOR", [16, 8], "float32"], None, "float32"]})
                with T.init():
                    T_matmul_NT[i, j] = T.float32(0)
                T_matmul_NT[i, j] = T_matmul_NT[i, j] + rxplaceholder[i, k] * rxplaceholder_1[j, k]
    • [TODO] units : ???
  • strided_slice

    /*! \brief Attributes for StridedSlice operator */
    struct StridedSliceAttrs : public tvm::AttrsNode<StridedSliceAttrs> {
      Optional<Array<Integer>> begin;
      Optional<Array<Integer>> end;
      Optional<Array<Integer>> strides;
      tvm::String slice_mode;
      Optional<Array<Integer>> axes;
    
      TVM_DECLARE_ATTRS(StridedSliceAttrs, "relay.attrs.StridedSliceAttrs") {
        TVM_ATTR_FIELD(begin).describe("Indices for begin of slice, begin index is also inclusive");
        TVM_ATTR_FIELD(end).describe("Indices for end of slice, end index is exclusive");
        TVM_ATTR_FIELD(strides).describe(
            "Stride values of the slice, a stride can be negative, which causes a reverse slice.");
        TVM_ATTR_FIELD(slice_mode)
            .set_default("end")
            .describe(
                "The slice mode [end, size]."
                "end - The default slice mode, ending indices for the slice."
                "size - The input strides will be ignored, input end in this mode indicates the size"
                "of a slice starting at the location specified by begin. If end[i] is -1,"
                "all remaining elements in that dimension are included in the slice");
        TVM_ATTR_FIELD(axes).describe(
            "Axes along which slicing is applied. When it is specified, the length of begin, end, "
            "strides, and axes must be equal.");
      }
    };
    @T.prim_func
    def strided_slice(rxplaceholder: T.Buffer[(T.int64(1), T.int64(2), T.int64(2), T.int64(4)), "int8"], T_strided_slice: T.Buffer[(T.int64(1), T.int64(2), T.int64(2), T.int64(2)), "int8"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "strided_slice", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(2), T.int64(2), T.int64(2)):
            with T.block("T_strided_slice"):
                ax0 = T.axis.spatial(T.int64(1), i0)
                ax1 = T.axis.spatial(T.int64(2), i1)
                ax2 = T.axis.spatial(T.int64(2), i2)
                ax3 = T.axis.spatial(T.int64(2), i3)
                T.reads(rxplaceholder[ax0, ax1, ax2, ax3])
                T.writes(T_strided_slice[ax0, ax1, ax2, ax3])
                T_strided_slice[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3]
    • Layout-sensitive attributes
      • begin, end, strides, axes: these four params decide how each axis is sliced, so they must be updated whenever a layout transformation changes the axes.
    • Layout-insensitive attributes: slice_mode
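A sketch of recovering the slice parameters, assuming each input index has the affine form `offset + coeff * ax_d` with a positive coeff and that offsets and coefficients have already been extracted per dimension. Negative strides and slice_mode="size" are not handled. The helper name is hypothetical, and `end` is reported as the tightest exclusive bound, which may differ from the value originally written in the model.

```python
from typing import List, Sequence, Tuple


def infer_strided_slice(
    in_shape: Sequence[int],
    out_shape: Sequence[int],
    offsets: Sequence[int],
    coeffs: Sequence[int],
) -> Tuple[List[int], List[int], List[int], List[int]]:
    """Return (axes, begin, end, strides) for the dims that are actually sliced."""
    axes: List[int] = []
    begin: List[int] = []
    end: List[int] = []
    strides: List[int] = []
    for d, (n_in, n_out, off, c) in enumerate(zip(in_shape, out_shape, offsets, coeffs)):
        if off == 0 and c == 1 and n_in == n_out:
            continue  # full, unstrided axis: not sliced
        axes.append(d)
        begin.append(off)
        end.append(off + c * (n_out - 1) + 1)  # tightest exclusive end
        strides.append(c)
    return axes, begin, end, strides


# strided_slice PrimFunc: reads rxplaceholder[ax0, ax1, ax2, ax3],
# input (1, 2, 2, 4) -> output (1, 2, 2, 2).
print(infer_strided_slice((1, 2, 2, 4), (1, 2, 2, 2), (0, 0, 0, 0), (1, 1, 1, 1)))
# ([3], [0], [2], [1])
```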
  • nn.batch_norm

    /*! \brief Attributes used in batch_norm operator */
    struct BatchNormAttrs : public tvm::AttrsNode<BatchNormAttrs> {
      int axis;
      double epsilon;
      bool center;
      bool scale;
    
      TVM_DECLARE_ATTRS(BatchNormAttrs, "relay.attrs.BatchNormAttrs") {
        TVM_ATTR_FIELD(axis).describe("Specify which shape axis denotes the channel.").set_default(1);
        TVM_ATTR_FIELD(epsilon)
            .describe("Small float added to variance to avoid dividing by zero")
            .set_default(1e-5);
        TVM_ATTR_FIELD(center)
            .describe("If True, add offset of beta to normalized tensor. If False, beta is ignored")
            .set_default(true);
        TVM_ATTR_FIELD(scale)
            .describe(
                "If True, multiply by gamma. If False, gamma is not used. "
                "When the next layer is piecewise linear (also, e.g., nn.relu), "
                "this can be disabled since the scaling will be done by the next layer.")
            .set_default(true);
      }
    };  // struct BatchNormAttrs
    @tvm.script.ir_module
    class Module:
        @T.prim_func
        def negative(rxplaceholder: T.Buffer[T.int64(8), "float32"], T_negative: T.Buffer[T.int64(8), "float32"]) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "negative", "tir.noalias": True})
            # body
            # with T.block("root")
            for i0 in T.serial(T.int64(8)):
                with T.block("T_negative"):
                    ax0 = T.axis.spatial(T.int64(8), i0)
                    T.reads(rxplaceholder[ax0])
                    T.writes(T_negative[ax0])
                    T_negative[ax0] = T.float32(0) - rxplaceholder[ax0]
        
        @T.prim_func
        def add2(rxplaceholder: T.Buffer[(T.int64(1), T.int64(8)), "float32"], rxplaceholder_1: T.Buffer[T.int64(8), "float32"], T_add: T.Buffer[(T.int64(1), T.int64(8)), "float32"]) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "add2", "tir.noalias": True})
            # body
            # with T.block("root")
            for i0, i1 in T.grid(T.int64(1), T.int64(8)):
                with T.block("T_add"):
                    ax0 = T.axis.spatial(T.int64(1), i0)
                    ax1 = T.axis.spatial(T.int64(8), i1)
                    T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1])
                    T.writes(T_add[ax0, ax1])
                    T_add[ax0, ax1] = rxplaceholder[ax0, ax1] + rxplaceholder_1[ax1]
        
        @T.prim_func
        def sqrt(rxplaceholder: T.Buffer[T.int64(8), "float32"], T_sqrt: T.Buffer[T.int64(8), "float32"]) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "sqrt", "tir.noalias": True})
            # body
            # with T.block("root")
            for i0 in T.serial(T.int64(8)):
                with T.block("T_sqrt"):
                    ax0 = T.axis.spatial(T.int64(8), i0)
                    T.reads(rxplaceholder[ax0])
                    T.writes(T_sqrt[ax0])
                    T_sqrt[ax0] = T.sqrt(rxplaceholder[ax0], dtype="float32")
        
        @T.prim_func
        def divide(rxplaceholder: T.Buffer[(), "float32"], rxplaceholder_1: T.Buffer[T.int64(8), "float32"], T_divide: T.Buffer[T.int64(8), "float32"]) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "divide", "tir.noalias": True})
            # body
            # with T.block("root")
            for i0 in T.serial(T.int64(8)):
                with T.block("T_divide"):
                    ax0 = T.axis.spatial(T.int64(8), i0)
                    T.reads(rxplaceholder[()], rxplaceholder_1[ax0])
                    T.writes(T_divide[ax0])
                    T_divide[ax0] = rxplaceholder[()] / rxplaceholder_1[ax0]
        
        @T.prim_func
        def add(rxplaceholder: T.Buffer[T.int64(8), "float32"], rxplaceholder_1: T.Buffer[(), "float32"], T_add: T.Buffer[T.int64(8), "float32"]) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "add", "tir.noalias": True})
            # body
            # with T.block("root")
            for i0 in T.serial(T.int64(8)):
                with T.block("T_add"):
                    ax0 = T.axis.spatial(T.int64(8), i0)
                    T.reads(rxplaceholder[ax0], rxplaceholder_1[()])
                    T.writes(T_add[ax0])
                    T_add[ax0] = rxplaceholder[ax0] + rxplaceholder_1[()]
        
        @T.prim_func
        def multiply1(rxplaceholder: T.Buffer[(T.int64(1), T.int64(8)), "float32"], rxplaceholder_1: T.Buffer[T.int64(8), "float32"], T_multiply: T.Buffer[(T.int64(1), T.int64(8)), "float32"]) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "multiply1", "tir.noalias": True})
            # body
            # with T.block("root")
            for i0, i1 in T.grid(T.int64(1), T.int64(8)):
                with T.block("T_multiply"):
                    ax0 = T.axis.spatial(T.int64(1), i0)
                    ax1 = T.axis.spatial(T.int64(8), i1)
                    T.reads(rxplaceholder[ax0, ax1], rxplaceholder_1[ax1])
                    T.writes(T_multiply[ax0, ax1])
                    T_multiply[ax0, ax1] = rxplaceholder[ax0, ax1] * rxplaceholder_1[ax1]
        
        @T.prim_func
        def add1(rxplaceholder: T.Buffer[T.int64(8), "float32"], rxplaceholder_1: T.Buffer[T.int64(8), "float32"], T_add: T.Buffer[T.int64(8), "float32"]) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "add1", "tir.noalias": True})
            # body
            # with T.block("root")
            for i0 in T.serial(T.int64(8)):
                with T.block("T_add"):
                    ax0 = T.axis.spatial(T.int64(8), i0)
                    T.reads(rxplaceholder[ax0], rxplaceholder_1[ax0])
                    T.writes(T_add[ax0])
                    T_add[ax0] = rxplaceholder[ax0] + rxplaceholder_1[ax0]
        
        @R.function
        def main(x: Tensor((1, 8), "float32"), gamma: Tensor((8,), "float32"), beta: Tensor((8,), "float32"), moving_mean: Tensor((8,), "float32"), moving_var: Tensor((8,), "float32")) -> Tensor(None, "float32", ndim = 2):
            # block 0
            with R.dataflow():
                lv = R.call_tir(add, (moving_var, 1e-05), (8,), dtype="float32")
                lv1 = R.call_tir(sqrt, (lv,), (8,), dtype="float32")
                lv2 = R.call_tir(divide, (1, lv1), (8,), dtype="float32")
                lv3 = R.call_tir(multiply, (lv2, gamma), (8,), dtype="float32")
                lv4 = R.call_tir(multiply1, (x, lv3), (1, 8), dtype="float32")
                lv5 = R.call_tir(negative, (moving_mean,), (8,), dtype="float32")
                lv6 = R.call_tir(multiply, (lv5, lv3), (8,), dtype="float32")
                lv7 = R.call_tir(add1, (lv6, beta), (8,), dtype="float32")
                lv8 = R.call_tir(add2, (lv4, lv7), (1, 8), dtype="float32")
                gv: Tensor((1, 8), "float32") = lv8
                R.output(gv)
            return gv
        
        @T.prim_func
        def multiply(rxplaceholder: T.Buffer[T.int64(8), "float32"], rxplaceholder_1: T.Buffer[T.int64(8), "float32"], T_multiply: T.Buffer[T.int64(8), "float32"]) -> None:
            # function attr dict
            T.func_attr({"global_symbol": "multiply", "tir.noalias": True})
            # body
            # with T.block("root")
            for i0 in T.serial(T.int64(8)):
                with T.block("T_multiply"):
                    ax0 = T.axis.spatial(T.int64(8), i0)
                    T.reads(rxplaceholder[ax0], rxplaceholder_1[ax0])
                    T.writes(T_multiply[ax0])
                    T_multiply[ax0] = rxplaceholder[ax0] * rxplaceholder_1[ax0]
    • Layout-sensitive attributes
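      • axis: the channel dimension, i.e., the data dimension along which the per-channel parameters (gamma, beta, moving_mean, moving_var) are broadcast. Here multiply1 and add2 read them as rxplaceholder_1[ax1], so axis=1. Note that the lowered batch_norm spans several PrimFuncs, so this must be recovered across multiple call_tir sites rather than from a single PrimFunc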
    • Layout-insensitive attributes: epsilon, center, scale
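Because batch_norm is split across several PrimFuncs, a sketch of the channel-axis deduction has to combine evidence from more than one block. The version below assumes each relevant block has been reduced to a pair (data-tensor indices, indices of the 1-D per-channel operand) and requires all pairs to agree; the helper name is hypothetical.

```python
from typing import Iterable, Sequence, Tuple


def infer_channel_axis(reads: Iterable[Tuple[Sequence[str], Sequence[str]]]) -> int:
    """Each entry pairs the data-tensor indices with the indices of a 1-D
    per-channel operand read in the same block. All entries must agree."""
    axes = {list(data_idx).index(param_idx[0]) for data_idx, param_idx in reads}
    if len(axes) != 1:
        raise ValueError(f"expected exactly one channel axis, got {sorted(axes)}")
    return axes.pop()


# multiply1: rxplaceholder[ax0, ax1] * rxplaceholder_1[ax1]
# add2:      rxplaceholder[ax0, ax1] + rxplaceholder_1[ax1]
print(infer_channel_axis([(["ax0", "ax1"], ["ax1"]), (["ax0", "ax1"], ["ax1"])]))  # 1
```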
  • nn.max_pool2d

    /*! \brief Attributes for max pool operator */
    struct MaxPool2DAttrs : public tvm::AttrsNode<MaxPool2DAttrs> {
      Array<IndexExpr> pool_size;
      Array<IndexExpr> strides;
      Array<IndexExpr> padding;
      Array<IndexExpr> dilation;
      tvm::String layout;
      tvm::String out_layout;
      bool ceil_mode;
    
      TVM_DECLARE_ATTRS(MaxPool2DAttrs, "relay.attrs.MaxPool2DAttrs") {
        TVM_ATTR_FIELD(pool_size).describe("Size of the pooling windows.");
        TVM_ATTR_FIELD(strides)
            .set_default(Array<IndexExpr>({1, 1}))
            .describe("Specifies the strides of the pooling window.");
        TVM_ATTR_FIELD(dilation)
            .set_default(Array<IndexExpr>({1, 1}))
            .describe("Specifies the dilation of the pooling window.");
        TVM_ATTR_FIELD(padding)
            .set_default(Array<IndexExpr>({0, 0}))
            .describe(
                "If padding is non-zero, then the input is implicitly zero-padded. "
                "Padding supports both symmetric and asymmetric forms: "
                "one int: same padding used on all sides; "
                "two int: bottom, right will use same padding as top, left; "
                "four int: padding width in the order of (top, left, bottom, right).");
        TVM_ATTR_FIELD(layout).set_default("NCHW").describe(
            "Dimension ordering of input data. Can be 'NCHW', 'NHWC', etc. "
            "'N', 'C', 'H', 'W' stands for batch, channel, height, and width "
            "dimensions respectively. Pooling is applied on the 'H' and "
            "'W' dimensions.");
        TVM_ATTR_FIELD(out_layout)
            .set_default("")
            .describe(
                "Dimension ordering of output data. Can be 'NCHW', 'NHWC', etc. "
                "'N', 'C', 'H', 'W' stands for batch, channel, height, and width "
                "dimensions respectively. Pooling is applied on the 'H' and "
                "'W' dimensions.");
        TVM_ATTR_FIELD(ceil_mode).set_default(false).describe(
            "When true, will use ceil instead of floor to compute the output shape.");
      }
    };
    @T.prim_func
    def max_pool2d(rxplaceholder: T.Buffer[(T.int64(1), T.int64(2), T.int64(2), T.int64(4)), "float32"], tensor: T.Buffer[(1, 1, 1, 4), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "max_pool2d", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2, i3, i4, i5 in T.grid(1, 1, 1, 4, 2, 2):
            with T.block("tensor"):
                ax0, ax1, ax2, ax3, rv0, rv1 = T.axis.remap("SSSSRR", [i0, i1, i2, i3, i4, i5])
                T.reads(rxplaceholder[ax0, ax1 + rv0, ax2 + rv1, ax3])
                T.writes(tensor[ax0, ax1, ax2, ax3])
                with T.init():
                    tensor[ax0, ax1, ax2, ax3] = T.float32(-3.4028234663852886e+38)
                tensor[ax0, ax1, ax2, ax3] = T.max(tensor[ax0, ax1, ax2, ax3], rxplaceholder[ax0, ax1 + rv0, ax2 + rv1, ax3])
    • Layout-sensitive attributes
      • layout, out_layout
    • Layout-insensitive attributes: pool_size, strides, dilation, padding, ceil_mode
  • transpose

    struct TransposeAttrs : public tvm::AttrsNode<TransposeAttrs> {
      Array<Integer> axes;
      TVM_DECLARE_ATTRS(TransposeAttrs, "relay.attrs.TransposeAttrs") {
        TVM_ATTR_FIELD(axes).describe("The target axes order, reverse order if not specified.");
      }
    };  // struct TransposeAttrs
    @T.prim_func
    def transpose(rxplaceholder: T.Buffer[(T.int64(14), T.int64(15), T.int64(16), T.int64(17)), "float32"], T_transpose: T.Buffer[(T.int64(14), T.int64(16), T.int64(17), T.int64(15)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "transpose", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2, i3 in T.grid(T.int64(14), T.int64(16), T.int64(17), T.int64(15)):
            with T.block("T_transpose"):
                ax0 = T.axis.spatial(T.int64(14), i0)
                ax1 = T.axis.spatial(T.int64(16), i1)
                ax2 = T.axis.spatial(T.int64(17), i2)
                ax3 = T.axis.spatial(T.int64(15), i3)
                T.reads(rxplaceholder[ax0, ax3, ax1, ax2])
                T.writes(T_transpose[ax0, ax1, ax2, ax3])
                T_transpose[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax3, ax1, ax2]
    • Layout-sensitive attributes
      • axes: compare the indices of the input buffer in T.reads with the output block axes. Here rxplaceholder[ax0, ax3, ax1, ax2] gives axes=(0, 2, 3, 1)
  • nn.pad

    /*! \brief Attributes used for the padding operator */
    struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
      Array<Array<Integer>> pad_width;
      tvm::String pad_mode;
    
      TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
        TVM_ATTR_FIELD(pad_width).describe(
            "Number of values padded to the edges of each axis, "
            "in the format of ((before_1, after_1), ..., (before_N, after_N))");
        TVM_ATTR_FIELD(pad_mode)
            .set_default("constant")
            .describe(
                "Padding type to use. \"constant\" pads with constant_value, "
                "\"edge\" pads using the edge values of the input array, "
                "\"reflect\" pads by reflecting values with respect to the edges.");
      }
    };
    @T.prim_func
    def pad(rxplaceholder: T.Buffer[(T.int64(1), T.int64(64), T.int64(56), T.int64(56)), "float32"], rxplaceholder_1: T.Buffer[(), "int32"], T_pad: T.Buffer[(T.int64(3), T.int64(68), T.int64(62), T.int64(64)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "pad", "tir.noalias": True})
        # body
        # with T.block("root")
        T_cast = T.alloc_buffer([], dtype="float32")
        with T.block("T_cast"):
            vi = T.axis.spatial(1, 0)
            T.reads(rxplaceholder_1[()])
            T.writes(T_cast[()])
            T_cast[()] = T.cast(rxplaceholder_1[()], "float32")
        for i0, i1, i2, i3 in T.grid(T.int64(3), T.int64(68), T.int64(62), T.int64(64)):
            with T.block("T_pad"):
                ax0 = T.axis.spatial(T.int64(3), i0)
                ax1 = T.axis.spatial(T.int64(68), i1)
                ax2 = T.axis.spatial(T.int64(62), i2)
                ax3 = T.axis.spatial(T.int64(64), i3)
                T.reads(rxplaceholder[ax0 - T.int64(1), ax1 - T.int64(2), ax2 - T.int64(3), ax3 - T.int64(4)], T_cast[()])
                T.writes(T_pad[ax0, ax1, ax2, ax3])
                T_pad[ax0, ax1, ax2, ax3] = T.if_then_else(T.int64(1) <= ax0 and ax0 < T.int64(2) and T.int64(2) <= ax1 and ax1 < T.int64(66) and T.int64(3) <= ax2 and ax2 < T.int64(59) and T.int64(4) <= ax3 and ax3 < T.int64(60), rxplaceholder[ax0 - T.int64(1), ax1 - T.int64(2), ax2 - T.int64(3), ax3 - T.int64(4)], T_cast[()], dtype="float32")
    • Layout-sensitive attributes
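      • pad_width: the constant subtracted from each axis in T.reads gives the "before" padding, and the T.grid extents give the padded shape, from which the "after" padding follows. In this example, pad_width=((1, 1), (2, 2), (3, 3), (4, 4))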
    • Layout-insensitive attributes: pad_mode
  • reshape

    struct ReshapeAttrs : public tvm::AttrsNode<ReshapeAttrs> {
      Array<Integer> newshape;
      bool allowzero;
      TVM_DECLARE_ATTRS(ReshapeAttrs, "relay.attrs.ReshapeAttrs") {
        TVM_ATTR_FIELD(newshape).describe(
            "The new shape. Should be compatible with the original shape.");
        TVM_ATTR_FIELD(allowzero).set_default(0).describe(
            "Whether to honor the value of zero in newshape.");
      }
    };  // struct ReshapeAttrs
    @T.prim_func
    def reshape(rxplaceholder: T.Buffer[(T.int64(1), T.int64(15), T.int64(4), T.int64(1)), "float32"], T_reshape: T.Buffer[(1, 30, 2), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "reshape", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2 in T.grid(1, 30, 2):
            with T.block("T_reshape"):
                ax0, ax1, ax2 = T.axis.remap("SSS", [i0, i1, i2])
                T.reads(rxplaceholder[T.int64(0), (T.cast(ax1, "int64") * T.int64(2) + T.cast(ax2, "int64")) % T.int64(60) // T.int64(4), (T.cast(ax1, "int64") * T.int64(2) + T.cast(ax2, "int64")) % T.int64(4), T.int64(0)])
                T.writes(T_reshape[ax0, ax1, ax2])
                T_reshape[ax0, ax1, ax2] = rxplaceholder[T.int64(0), (T.cast(ax1, "int64") * T.int64(2) + T.cast(ax2, "int64")) % T.int64(60) // T.int64(4), (T.cast(ax1, "int64") * T.int64(2) + T.cast(ax2, "int64")) % T.int64(4), T.int64(0)]
    • Layout-sensitive attributes
      • newshape: read it from the T.grid extents (i.e., the output shape), here (1, 30, 2)
    • Layout-insensitive attributes: allowzero
  • split

    struct SplitAttrs : public tvm::AttrsNode<SplitAttrs> {
      ObjectRef indices_or_sections;
      int axis;
    
      TVM_DECLARE_ATTRS(SplitAttrs, "relay.attrs.SplitAttrs") {
        TVM_ATTR_FIELD(indices_or_sections)
            .describe(
                "Indices or sections to split into. Accepts an int or a tuple. "
                "If indices_or_sections is an integer, the input will be divided equally "
                "along given axis. If such a split is not possible, an error is raised. "
                "If indices_or_sections is a tuple of sorted integers, "
                "the entries indicate where along axis the array is split.");
        TVM_ATTR_FIELD(axis).set_default(0).describe("the axis to be split.");
      }
    };
    @T.prim_func
    def split(rxplaceholder: T.Buffer[(T.int64(1), T.int64(50), T.int64(50), T.int64(3)), "float32"], T_split: T.Buffer[(T.int64(1), 5, T.int64(50), T.int64(3)), "float32"], T_split_1: T.Buffer[(T.int64(1), 15, T.int64(50), T.int64(3)), "float32"], T_split_2: T.Buffer[(T.int64(1), 25, T.int64(50), T.int64(3)), "float32"], T_split_3: T.Buffer[(T.int64(1), T.int64(5), T.int64(50), T.int64(3)), "float32"]) -> None:
        # function attr dict
        T.func_attr({"global_symbol": "split", "tir.noalias": True})
        # body
        # with T.block("root")
        for i0, i1, i2, i3 in T.grid(T.int64(1), 5, T.int64(50), T.int64(3)):
            with T.block("T_split"):
                ax0 = T.axis.spatial(T.int64(1), i0)
                ax1 = T.axis.spatial(5, i1)
                ax2 = T.axis.spatial(T.int64(50), i2)
                ax3 = T.axis.spatial(T.int64(3), i3)
                T.reads(rxplaceholder[ax0, ax1, ax2, ax3])
                T.writes(T_split[ax0, ax1, ax2, ax3])
                T_split[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1, ax2, ax3]
        for i0, i1, i2, i3 in T.grid(T.int64(1), 15, T.int64(50), T.int64(3)):
            with T.block("T_split_1"):
                ax0 = T.axis.spatial(T.int64(1), i0)
                ax1 = T.axis.spatial(15, i1)
                ax2 = T.axis.spatial(T.int64(50), i2)
                ax3 = T.axis.spatial(T.int64(3), i3)
                T.reads(rxplaceholder[ax0, ax1 + 5, ax2, ax3])
                T.writes(T_split_1[ax0, ax1, ax2, ax3])
                T_split_1[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1 + 5, ax2, ax3]
        for i0, i1, i2, i3 in T.grid(T.int64(1), 25, T.int64(50), T.int64(3)):
            with T.block("T_split_2"):
                ax0 = T.axis.spatial(T.int64(1), i0)
                ax1 = T.axis.spatial(25, i1)
                ax2 = T.axis.spatial(T.int64(50), i2)
                ax3 = T.axis.spatial(T.int64(3), i3)
                T.reads(rxplaceholder[ax0, ax1 + 20, ax2, ax3])
                T.writes(T_split_2[ax0, ax1, ax2, ax3])
                T_split_2[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1 + 20, ax2, ax3]
        for i0, i1, i2, i3 in T.grid(T.int64(1), T.int64(5), T.int64(50), T.int64(3)):
            with T.block("T_split_3"):
                ax0 = T.axis.spatial(T.int64(1), i0)
                ax1 = T.axis.spatial(T.int64(5), i1)
                ax2 = T.axis.spatial(T.int64(50), i2)
                ax3 = T.axis.spatial(T.int64(3), i3)
                T.reads(rxplaceholder[ax0, ax1 + T.int64(45), ax2, ax3])
                T.writes(T_split_3[ax0, ax1, ax2, ax3])
                T_split_3[ax0, ax1, ax2, ax3] = rxplaceholder[ax0, ax1 + T.int64(45), ax2, ax3]
    • Layout-sensitive attributes
      • axis: look at T.reads and find the axis whose index carries an extra offset; the offsets (5, 20, 45) match indices_or_sections. In this example, axis=1
    • Layout-insensitive attributes: indices_or_sections
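The per-op inference rules listed above can be prototyped without TVM. Below is a minimal sketch, assuming the relevant facts have already been pulled out of each PrimFunc by hand (which T.reads index uses a reduction variable, the constant offsets subtracted from each axis, the T.grid extents) and represented as plain Python values. The function names are hypothetical and are not TVM APIs. The pooling rule further assumes a 4-D input whose first non-pooled dimension is the batch.

```python
# Hypothetical sketches: PrimFunc access patterns are modeled as plain Python
# data transcribed from the TVMScript above, not as TVM objects.
from math import prod


def infer_pool_layout(read_index_vars, reduce_vars):
    """Label each input dim of a 2-D pool from the vars in its T.reads index."""
    pooled = [d for d, vs in enumerate(read_index_vars) if vs & reduce_vars]
    others = [d for d in range(len(read_index_vars)) if d not in pooled]
    if len(pooled) != 2 or len(others) != 2:
        raise ValueError("expected a 4-D input with two pooled dims")
    # Assumption: the first non-pooled dim is the batch, the other the channel.
    labels = {pooled[0]: "H", pooled[1]: "W", others[0]: "N", others[1]: "C"}
    return "".join(labels[d] for d in range(len(read_index_vars)))


def infer_pad_width(in_shape, out_shape, read_offsets):
    """read_offsets[d] is the constant subtracted from ax_d in T.reads."""
    pad_width = []
    for n_in, n_out, before in zip(in_shape, out_shape, read_offsets):
        after = n_out - n_in - before
        if before < 0 or after < 0:
            raise ValueError("negative padding")
        pad_width.append((before, after))
    return tuple(pad_width)


def infer_reshape_newshape(in_shape, grid_extents):
    """newshape is the T.grid extents, provided the element counts agree."""
    if prod(in_shape) != prod(grid_extents):
        raise ValueError("element count mismatch")
    return tuple(grid_extents)


def infer_split(read_offsets_per_output):
    """Each entry holds the per-dim constant offsets in one output's T.reads."""
    n_dims = len(read_offsets_per_output[0])
    axes = {d for offs in read_offsets_per_output
            for d in range(n_dims) if offs[d] != 0}
    if len(axes) != 1:
        raise ValueError("expected offsets along exactly one axis")
    axis = axes.pop()
    # The first output starts at offset 0; the remaining offsets are split points.
    indices = tuple(offs[axis] for offs in read_offsets_per_output[1:])
    return axis, indices


print(infer_pool_layout([{"ax0"}, {"ax1", "rv0"}, {"ax2", "rv1"}, {"ax3"}],
                        {"rv0", "rv1"}))  # NHWC
print(infer_pad_width((1, 64, 56, 56), (3, 68, 62, 64), (1, 2, 3, 4)))
# ((1, 1), (2, 2), (3, 3), (4, 4))
print(infer_reshape_newshape((1, 15, 4, 1), (1, 30, 2)))  # (1, 30, 2)
print(infer_split([(0, 0, 0, 0), (0, 5, 0, 0), (0, 20, 0, 0), (0, 45, 0, 0)]))
# (1, (5, 20, 45))
```

The inputs in the usage lines are copied from the max_pool2d, pad, reshape, and split PrimFuncs above. A real implementation would presumably walk the TIR index expressions instead of relying on hand-transcribed values.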

Summary

  • Operator name seems trivial to recover.
  • Operator kind is already supported.
  • Operator attributes seem recoverable.
    • Layout-sensitive attributes (e.g., layout, axis, padding) would require extending the layout transformation to update them properly. In most cases, it seems quite straightforward how to update them.
    • Layout-insensitive attributes (e.g., epsilon, slice_mode) are not affected by layout transformation. We can simply keep the information obtained during lowering.
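To illustrate why updating layout-sensitive attributes is usually straightforward, here is a minimal sketch. It assumes op attributes are plain dicts and that a layout change is a pure dimension permutation `perm`, where new dim i holds old dim perm[i]. `permute_layout` and `update_attrs` are hypothetical helpers, not TVM APIs.

```python
# Hypothetical sketch: attributes are plain dicts and a layout change is a
# permutation `perm` where new dim i holds old dim perm[i].


def permute_layout(layout, perm):
    """Reorder a layout string such as "NCHW"; an empty string means 'unset'."""
    return "".join(layout[p] for p in perm) if layout else layout


def update_attrs(attrs, perm, sensitive):
    """Rewrite only the layout-sensitive attributes; copy the rest unchanged."""
    new_attrs = dict(attrs)
    for key in sensitive:
        value = attrs[key]
        if isinstance(value, str):  # layout strings, e.g. layout/out_layout
            new_attrs[key] = permute_layout(value, perm)
        else:  # assumed to be a non-negative axis index, e.g. axis
            new_attrs[key] = perm.index(value)
    return new_attrs


nchw_to_nhwc = [0, 2, 3, 1]
bn = {"axis": 1, "epsilon": 1e-5, "center": True, "scale": True}
pool = {"layout": "NCHW", "out_layout": "", "pool_size": (2, 2), "ceil_mode": False}
print(update_attrs(bn, nchw_to_nhwc, {"axis"}))
# {'axis': 3, 'epsilon': 1e-05, 'center': True, 'scale': True}
print(update_attrs(pool, nchw_to_nhwc, {"layout", "out_layout"}))
# {'layout': 'NHWC', 'out_layout': '', 'pool_size': (2, 2), 'ceil_mode': False}
```

Only the keys passed as layout-sensitive are touched, mirroring the split in the summary. Negative axes and packed layouts such as NCHW4c would need extra handling that this sketch omits.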

3. Suggestion for Relax Layout Planner

With access to op-level info in the PrimFunc, there are two options to make the Relax layout planner work with BYOC:

  • O1: Implement operator-level layout inference (the equivalent of InferCorrectLayout) by peeking into the PrimFunc and performing TIR-based analysis
    • Steps:
      1. Load the simplest PrimFunc implementation for an operator
      2. Transform the layout as desired at the TIR level
      3. Update the operator based on the transformed PrimFunc
    • Pros:
      • Pure graph-level approach that can be applied before lowering
      • Manageable complexity
    • Cons: less generic than O2
  • O2: Fully support raising PrimFuncs back to operators
    • Pros: Very generic
    • Cons: There can be many sticky situations that require users to specify rules when we manipulate PrimFuncs (e.g., merging/splitting PrimFuncs)
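The three O1 steps can be illustrated on the transpose PrimFunc from earlier. This is a hypothetical sketch, not a proposed implementation: the PrimFunc is reduced to the names of its output block axes and the index of its single input read, and `infer_transpose_axes` and `transform_input_layout` are illustrative helpers rather than TVM APIs.

```python
# Hypothetical sketch of option O1 for `transpose`. The PrimFunc is reduced to
# the names of its output block axes and the index of its single input read.


def infer_transpose_axes(out_axes, read_index):
    """axes[i] is the input dim indexed by output axis i (out.shape[i] = in.shape[axes[i]])."""
    dim_of = {name: dim for dim, name in enumerate(read_index)}
    return [dim_of[name] for name in out_axes]


def transform_input_layout(read_index, perm):
    """Rewrite the read after permuting the input buffer (new dim j = old dim perm[j])."""
    return [read_index[p] for p in perm]


# Step 1: access pattern of the simplest transpose PrimFunc (its T.reads line).
out_axes = ["ax0", "ax1", "ax2", "ax3"]
read_index = ["ax0", "ax3", "ax1", "ax2"]
print(infer_transpose_axes(out_axes, read_index))  # [0, 2, 3, 1]

# Step 2: transform the input buffer layout at the TIR level.
new_read = transform_input_layout(read_index, [0, 2, 3, 1])

# Step 3: update the operator attribute from the transformed access pattern.
print(infer_transpose_axes(out_axes, new_read))  # [0, 1, 2, 3]
```

In this case the input already arrives in the transposed order after the layout change, so the re-inferred axes are the identity. The sketch only shows that O1 needs per-op access-pattern analysis rather than full raising of arbitrary PrimFuncs.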
@masahi
Contributor

masahi commented Feb 2, 2023

Some thought on this problem:

To guarantee the soundness of graph-level layout transformation, we need to be able to infer the new layout-sensitive attributes for all ops, 100% reliably. That might be difficult, especially at the beginning of development.

To guarantee soundness while making gradual development possible, the TIR-level transformation pass can materialize a transform_layout back to the original layout when it encounters a TIR PrimFunc for which we cannot infer the attributes of the corresponding graph-level op. We need this behavior only when the TIR-level pass is invoked for the purpose of graph-level transformation ("graph mode").

This way, we don't have to worry about tricky ops like strided_slice until we commit to implementing the attribute inference rule for them.
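The fallback described above can be sketched on a simplified linear chain of ops. This is a hypothetical illustration, not the TIR-level pass itself: `plan_with_fallback`, the `rules` table, and the `("layout_transform", perm)` tuples are assumptions made for the example, and `perm` uses the convention that new dim i holds old dim perm[i].

```python
# Hypothetical sketch of the fallback: ops are (name, attrs) pairs in a linear
# chain, and `rules` maps an op name to a function returning the attributes
# for the new layout, or None when inference is not (yet) supported.


def plan_with_fallback(ops, rules, perm):
    inverse = [perm.index(i) for i in range(len(perm))]
    planned, in_new_layout = [], False
    for name, attrs in ops:
        rule = rules.get(name)
        new_attrs = rule(attrs, perm) if rule else None
        if new_attrs is None:
            if in_new_layout:  # materialize a transform back to the original layout
                planned.append(("layout_transform", inverse))
                in_new_layout = False
            planned.append((name, attrs))
        else:
            if not in_new_layout:  # enter the new layout before this op
                planned.append(("layout_transform", perm))
                in_new_layout = True
            planned.append((name, new_attrs))
    if in_new_layout:  # graph outputs keep their original layout
        planned.append(("layout_transform", inverse))
    return planned


rules = {
    "nn.relu": lambda attrs, perm: attrs,  # elementwise: nothing to update
    "nn.max_pool2d": lambda attrs, perm: {
        **attrs, "layout": "".join(attrs["layout"][p] for p in perm)},
}
ops = [("nn.max_pool2d", {"layout": "NCHW"}), ("strided_slice", {}), ("nn.relu", {})]
for step in plan_with_fallback(ops, rules, [0, 2, 3, 1]):
    print(step)
```

Here strided_slice has no rule, so it runs in the original layout between two materialized transforms, which keeps the result sound. The sketch does not cancel redundant transform pairs (e.g., around relu), which a real planner would presumably do.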

@psrivas2
Contributor

psrivas2 commented Feb 2, 2023

That could be an interesting direction @masahi!
Clarification question: If we materialize the layout in such cases, but are not able to raise the transformed PrimFunc back to operator level, would BYOC backends be able to pattern match such operators?

@masahi
Contributor

masahi commented Feb 2, 2023

Interesting question. For backends that can do their own layout transform internally (DNNL), the TVM-side layout transform is always optional (it only improves performance), so pattern matching is agnostic to layouts. Other backends (CUTLASS), however, expect the right layout for pattern matching to succeed, so we need to break the graph there.

But I expect there would be no need to "infer" the new attributes for most compute-intensive ops that we want to offload to BYOC, since their layouts are typically fixed by users. We only need to worry about the layout-sensitive ops in between, like reductions and other shape-changing ops, which might not be offloaded to BYOC anyway.

@quic-sanirudh
Contributor

Are there any plans to add support for extracting op info as mentioned here at some point? Was there a final decision on how this is going to be supported?

@psrivas2
Contributor

Hello @quic-sanirudh!

Are there any plans to add support for extracting op info as mentioned here at some point?

Yes, extracting op info will be supported. As mentioned in the comments above, @masahi and @sunggg have also laid out some of the possible approaches. A lot of details still need to be figured out.

Was there a final decision on how this is going to be supported?

It is going to be supported, but the design of how exactly this would work has not been decided yet. If you are interested in this problem, please feel free to start a discussion on the design here or in a separate thread.

@quic-sanirudh
Contributor

quic-sanirudh commented Feb 28, 2023

Thanks @psrivas2 for the quick reply. I was curious about how this would work in the presence of fusion. Basically, if we extract the op info before fusion, we either have to assume that it will only be valid until fusion is performed, or need some way to iterate through the attributes of each individual op that is part of a fused PrimFunc.

I'll think a bit more about this and explain it with an example.

@psrivas2
Copy link
Contributor

Looking forward to the example.
However, our use case (graph-operator-level layout transformation) does not need to preserve this information in the presence of fusion, because both the graph-operator-level layout transformation for BYOC and the pattern matching for BYOC happen before fusion.

@quic-sanirudh
Contributor

@psrivas2, thanks for the reply. Actually, my question is not limited to layout transformation. Please correct me if I'm mistaken, but I thought the point of this op info extraction is to make these op-specific attributes available during other transformations.

For example, if we have the strides or padding information for conv2d/pooling ops, that might be useful for writing passes that target those specific ops. If the information is available during scheduling, that would help improve the scheduling as well (automated based on rules or manual).

Say, for example, a user would like to write a new schedule_rule that targets a particular type of op, such that the number of tiles is decided based on the op's attributes; that might turn out to be really useful (just a random thought, I don't have a concrete example yet). My idea was that, if we need something like that, we might need to extract op attributes through a pass before fusion and retain them in some way after fusion, perhaps as attributes of the fused PrimFunc.
