Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Rust] More Rust bindings for Attrs #7082

Merged
merged 6 commits into from
Dec 26, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions include/tvm/relay/attrs/nn.h
Original file line number Diff line number Diff line change
Expand Up @@ -715,7 +715,7 @@ struct AvgPool2DAttrs : public tvm::AttrsNode<AvgPool2DAttrs> {
Array<IndexExpr> pool_size;
Array<IndexExpr> strides;
Array<IndexExpr> padding;
std::string layout;
tvm::String layout;
bool ceil_mode;
bool count_include_pad;

Expand Down Expand Up @@ -959,8 +959,8 @@ struct FIFOBufferAttrs : public tvm::AttrsNode<FIFOBufferAttrs> {
struct UpSamplingAttrs : public tvm::AttrsNode<UpSamplingAttrs> {
double scale_h;
double scale_w;
std::string layout;
std::string method;
tvm::String layout;
tvm::String method;
bool align_corners;

TVM_DECLARE_ATTRS(UpSamplingAttrs, "relay.attrs.UpSamplingAttrs") {
Expand Down
36 changes: 36 additions & 0 deletions rust/tvm/src/ir/relay/attrs/nn.rs
Original file line number Diff line number Diff line change
Expand Up @@ -106,3 +106,39 @@ pub struct BatchNormAttrsNode {
pub center: bool,
pub scale: bool,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "LeakyReluAttrs"]
#[type_key = "relay.attrs.LeakyReluAttrs"]
pub struct LeakyReluAttrsNode {
pub base: BaseAttrsNode,
pub alpha: f64,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "AvgPool2DAttrs"]
#[type_key = "relay.attrs.AvgPool2DAttrs"]
pub struct AvgPool2DAttrsNode {
pub base: BaseAttrsNode,
pub pool_size: Array<IndexExpr>,
pub strides: Array<IndexExpr>,
pub padding: Array<IndexExpr>,
pub layout: TString,
pub ceil_mode: bool,
pub count_include_pad: bool,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "UpSamplingAttrs"]
#[type_key = "relay.attrs.UpSamplingAttrs"]
pub struct UpSamplingAttrsNode {
pub base: BaseAttrsNode,
pub scale_h: f64,
pub scale_w: f64,
pub layout: TString,
pub method: TString,
pub align_corners: bool,
}
52 changes: 52 additions & 0 deletions rust/tvm/src/ir/relay/attrs/transform.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,8 +18,13 @@
*/

use crate::ir::attrs::BaseAttrsNode;
use crate::ir::PrimExpr;
use crate::runtime::array::Array;
use crate::runtime::ObjectRef;
use tvm_macros::Object;

type IndexExpr = PrimExpr;

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ExpandDimsAttrs"]
Expand All @@ -29,3 +34,50 @@ pub struct ExpandDimsAttrsNode {
pub axis: i32,
pub num_newaxis: i32,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ConcatenateAttrs"]
#[type_key = "relay.attrs.ConcatenateAttrs"]
pub struct ConcatenateAttrsNode {
pub base: BaseAttrsNode,
pub axis: i32,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "ReshapeAttrs"]
#[type_key = "relay.attrs.ReshapeAttrs"]
pub struct ReshapeAttrsNode {
pub base: BaseAttrsNode,
pub newshape: Array<IndexExpr>,
pub reverse: bool,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "SplitAttrs"]
#[type_key = "relay.attrs.SplitAttrs"]
pub struct SplitAttrsNode {
pub base: BaseAttrsNode,
pub indices_or_sections: ObjectRef,
pub axis: i32,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "TransposeAttrs"]
#[type_key = "relay.attrs.TransposeAttrs"]
pub struct TransposeAttrsNode {
pub base: BaseAttrsNode,
pub axes: Array<IndexExpr>,
}

#[repr(C)]
#[derive(Object, Debug)]
#[ref_name = "SqueezeAttrs"]
#[type_key = "relay.attrs.SqueezeAttrs"]
pub struct SqueezeAttrsNode {
pub base: BaseAttrsNode,
pub axis: Array<IndexExpr>,
}