[Unity][Frontend][NN] Add Timesteps layer to NN Module API #15603

Merged 12 commits on Aug 25, 2023
20 changes: 19 additions & 1 deletion include/tvm/relax/attrs/nn.h
@@ -375,7 +375,7 @@ struct DropoutAttrs : public tvm::AttrsNode<DropoutAttrs> {
}
}; // struct DropoutAttrs

/*! \brief Attributes used in dropout operator */
/*! \brief Attributes used in Attention operator */
struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
Optional<FloatImm> scale;
Optional<String> causal_mask;
@@ -388,6 +388,24 @@ struct AttentionAttrs : public tvm::AttrsNode<AttentionAttrs> {
}
}; // struct AttentionAttrs

/*! \brief Attributes used for the padding operator */
struct PadAttrs : public tvm::AttrsNode<PadAttrs> {
Array<Integer> pad_width;
tvm::String pad_mode;

TVM_DECLARE_ATTRS(PadAttrs, "relay.attrs.PadAttrs") {
TVM_ATTR_FIELD(pad_width).describe(
"Number of values padded to the edges of each axis, "
"in the format of (before_1, after_1, ..., before_N, after_N)");
TVM_ATTR_FIELD(pad_mode)
.set_default("constant")
.describe(
"Padding type to use. \"constant\" pads with constant_value, "
"\"edge\" pads using the edge values of the input array, "
"\"reflect\" pads by reflecting values with respect to the edges.");
}
};

} // namespace relax
} // namespace tvm

97 changes: 97 additions & 0 deletions python/tvm/relax/frontend/nn/modules.py
@@ -61,6 +61,15 @@ def _print(_, array: NDArray) -> None:
print(f"effect.print: shape = {array.shape}, dtype = {array.dtype}, data =\n{array}")


class SiLU(Module):
"""
Module for SiLU activation layer.
"""

def forward(self, x: Tensor):
return op.silu(x)
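
For reference, SiLU computes x * sigmoid(x); the module above simply forwards to the existing op.silu helper so the activation can be used as a layer inside nn.Module graphs.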


class Linear(Module):
"""
Module for linear layer.
@@ -363,3 +372,91 @@ def forward(self, x: Tensor):
),
shape=[*x.shape, self.dim], # TODO(@junrushao): revisit and remove self.dim
)


class TimestepEmbedding(Module):
"""
Module for the Hugging Face diffusers TimestepEmbedding layer.
"""

def __init__(
self,
in_channels: int,
time_embed_dim: int,
act_fn: str = "silu",
out_dim: Optional[int] = None,
post_act_fn: Optional[str] = None,
cond_proj_dim: Optional[int] = None,
):
self.linear_1 = Linear(in_channels, time_embed_dim)

if cond_proj_dim is not None:
self.cond_proj = Linear(cond_proj_dim, in_channels, bias=False)
else:
self.cond_proj = None

assert act_fn == "silu", "Only SiLU activations are supported."
self.act = SiLU()

if out_dim is not None:
time_embed_dim_out = out_dim
else:
time_embed_dim_out = time_embed_dim

self.linear_2 = Linear(time_embed_dim, time_embed_dim_out)

if post_act_fn is None:
self.post_act = None
else:
assert post_act_fn == "silu", "Only SiLU post-activation is supported."
self.post_act = SiLU()

def forward(self, sample: Tensor, condition: Optional[Tensor] = None):
"""
Forward method for TimestepEmbedding layer.

Parameters
----------
sample : Tensor
The input timestep that should be looked up.
condition : Optional[Tensor]
Optional conditioning tensor that is projected by cond_proj and added to the sample.

Returns
-------
ret : Tensor
The resulting embedding lookup for the input sample.
"""
if condition is not None:
sample = sample + self.cond_proj(condition)
sample = self.linear_1(sample)

if self.act is not None:
sample = self.act(sample)

sample = self.linear_2(sample)

if self.post_act is not None:
sample = self.post_act(sample)
return sample


class Timesteps(Module):
"""
Module for the Hugging Face diffusers Timesteps layer.
"""

def __init__(
self, num_channels: int, flip_sin_to_cos: bool = False, downscale_freq_shift: float = 1
):
self.num_channels = num_channels
self.flip_sin_to_cos = flip_sin_to_cos
self.downscale_freq_shift = downscale_freq_shift

def forward(self, x: Tensor):
return op.get_timestep_embedding(
x,
embedding_dim=self.num_channels,
flip_sin_to_cos=self.flip_sin_to_cos,
downscale_freq_shift=self.downscale_freq_shift,
)
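
For context, a minimal sketch of how these two modules compose for diffusion-style time conditioning. The channel sizes and the timestep tensor are assumptions for illustration (timestep is taken to be a 1-D nn.Tensor of timestep indices), not part of this PR:

time_proj = Timesteps(num_channels=320, flip_sin_to_cos=True, downscale_freq_shift=0)
time_embedding = TimestepEmbedding(in_channels=320, time_embed_dim=1280)

t_emb = time_proj.forward(timestep)       # [N, 320] sinusoidal features
emb = time_embedding.forward(t_emb)       # [N, 1280] learned embedding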
67 changes: 63 additions & 4 deletions python/tvm/relax/frontend/nn/op.py
@@ -16,6 +16,7 @@
# under the License.
# pylint: disable=too-many-lines,invalid-name,protected-access
"""nn.Tensor operators."""
import math
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union

from tvm import tir as _tir
@@ -662,12 +663,10 @@ def full(
result : Tensor
The result tensor.
"""
from tvm import relax # pylint: disable=import-outside-toplevel

if isinstance(fill_value, (_tir.FloatImm, _tir.IntImm)):
fill_value = relax.const(fill_value.value, dtype=dtype)
fill_value = rx.const(fill_value.value, dtype=dtype)
elif isinstance(fill_value, (int, float)):
fill_value = relax.const(fill_value, dtype=dtype)
fill_value = rx.const(fill_value, dtype=dtype)
else:
fill_value = fill_value._expr
return _wrap_nested(_op.full(shape, fill_value, dtype), name)
@@ -699,6 +698,66 @@ def zeros(
return _wrap_nested(_op.zeros(shape, dtype), name)


def get_timestep_embedding(
x: Tensor,
embedding_dim: int,
flip_sin_to_cos: bool = False,
downscale_freq_shift: float = 1,
scale: float = 1,
max_period: int = 10000,
name: str = "get_timestep_embedding",
) -> Tensor:
"""
Timestep calculation as described in Denoising Diffusion Probabilistic Models.

Parameters
----------
x : Tensor
A 1-D Tensor of N indices.
embedding_dim : int
The dimension of the output.
flip_sin_to_cos : bool
If True, change the order of sine and cosine embeddings.
downscale_freq_shift : float
Adjusts the frequency of the sinusoidal sampling.
scale : float
Weight adjustment for embedding magnitude.
max_period : int
Controls the minimum frequency of the embeddings.
name : str
The name to label this operator with.

Returns
-------
result : Tensor
[N x dim] Tensor of positional embeddings.
"""
timesteps = _op.astype(x._expr, "float32")

half_dim = embedding_dim // 2
exponent = rx.const(-math.log(max_period), "float32") * _op.arange(
start=0, end=half_dim, dtype="float32"
)
exponent = exponent / (rx.const(half_dim - downscale_freq_shift, "float32"))

emb = _op.exp(exponent)
emb = _op.expand_dims(timesteps, 1) * _op.expand_dims(emb, 0)
# Scale embeddings
if scale != 1:
emb = rx.const(scale, "float32") * emb

# Concat sine and cosine embeddings.
if flip_sin_to_cos:
emb = _op.concat([_op.cos(emb), _op.sin(emb)], axis=-1)
else:
emb = _op.concat([_op.sin(emb), _op.cos(emb)], axis=-1)

# Zero-pad the last dimension so the output width matches embedding_dim.
if embedding_dim % 2 == 1:
    emb = _op.nn.pad(emb, (0, 0, 0, 1))
return _wrap_nested(emb, name)
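
For reference, the computation above corresponds to this NumPy sketch (illustrative only, not part of the diff):

import numpy as np

def timestep_embedding_ref(timesteps, embedding_dim, flip_sin_to_cos=False,
                           downscale_freq_shift=1.0, scale=1.0, max_period=10000):
    # timesteps: 1-D array of N indices; returns an [N, embedding_dim] array.
    half_dim = embedding_dim // 2
    exponent = -np.log(max_period) * np.arange(half_dim, dtype="float32")
    exponent = exponent / (half_dim - downscale_freq_shift)
    emb = timesteps.astype("float32")[:, None] * np.exp(exponent)[None, :]
    emb = scale * emb
    if flip_sin_to_cos:
        emb = np.concatenate([np.cos(emb), np.sin(emb)], axis=-1)
    else:
        emb = np.concatenate([np.sin(emb), np.cos(emb)], axis=-1)
    if embedding_dim % 2 == 1:
        emb = np.pad(emb, ((0, 0), (0, 1)))  # zero-pad the last column
    return emb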


def tensor_expr_op(
tensor_expr_func: Callable,
name_hint: str,
31 changes: 30 additions & 1 deletion python/tvm/relax/op/nn/nn.py
@@ -20,7 +20,7 @@
from tvm import DataType
from tvm.tir import FloatImm

from ...expr import Expr
from ...expr import Expr, const
from . import _ffi_api


@@ -413,6 +413,35 @@ def conv2d_transpose(
)


def pad(data, pad_width, pad_value=0, pad_mode="constant"):
r"""Padding

This operator takes in a tensor and pads each axis by the specified
widths using the specified value.

Parameters
----------
data: relax.Expr
The input data to the operator
pad_width: tuple of int, required
Number of values padded to the edges of each axis, in the flattened
format (before_1, after_1, ..., before_N, after_N)
pad_value: float
The value used for padding
pad_mode: 'constant', 'edge', 'reflect'
'constant' pads with the constant pad_value,
'edge' pads using the edge values of the input array
'reflect' pads by reflecting values with respect to the edge
Returns
-------
result : relax.Expr
The computed result.
"""
if not isinstance(pad_value, Expr):
pad_value = const(pad_value)
return _ffi_api.pad(data, pad_width, pad_value, pad_mode)
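
As a quick illustration (a sketch, not part of the PR; assumes the flattened pad_width layout described above):

from tvm import relax
x = relax.Var("x", relax.TensorStructInfo([2, 3], "float32"))
# Pad one element on each side of the last axis; struct-info inference gives shape [2, 5].
y = relax.op.nn.pad(x, pad_width=(0, 0, 1, 1), pad_value=0.0)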


def max_pool2d(
data: Expr,
pool_size: Union[int, Tuple[int, int]] = (1, 1),
16 changes: 16 additions & 0 deletions python/tvm/relax/transform/legalize_ops/nn.py
@@ -184,6 +184,22 @@ def _nn_conv2d_transpose(bb: BlockBuilder, call: Call) -> Expr:
)


@register_legalize("relax.nn.pad")
def _nn_pad(bb: BlockBuilder, call: Call) -> Expr:
# Unpack pad_width into two separate lists for topi.
pad_widths = call.attrs.pad_width
pad_before = pad_widths[::2]
pad_after = pad_widths[1::2]
return bb.call_te(
topi.nn.pad,
call.args[0],
pad_before=pad_before,
pad_after=pad_after,
pad_value=float(call.args[1].data.numpy()),
primfunc_name_hint="pad",
)
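
For example (illustrative values): pad_width = (0, 0, 0, 1) splits into pad_before = (0, 0) and pad_after = (0, 1), so topi.nn.pad only pads the last axis by one element at its end.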


@register_legalize("relax.nn.max_pool2d")
def _nn_max_pool2d(bb: BlockBuilder, call: Call) -> Expr:
if call.attrs.out_layout != call.attrs.layout:
46 changes: 46 additions & 0 deletions src/relax/op/nn/nn.cc
@@ -123,6 +123,52 @@ TVM_REGISTER_OP("relax.nn.log_softmax")
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoSoftmax)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.pad */
TVM_REGISTER_NODE_TYPE(PadAttrs);

Expr pad(Expr data, Array<Integer> pad_width, Expr pad_value, String pad_mode) {
auto attrs = make_object<PadAttrs>();
attrs->pad_width = std::move(pad_width);
attrs->pad_mode = std::move(pad_mode);
static const Op& op = Op::Get("relax.nn.pad");
return Call(op, {data, pad_value}, Attrs(attrs), {});
}

TVM_REGISTER_GLOBAL("relax.op.nn.pad").set_body_typed(pad);

StructInfo InferStructInfoPad(const Call& call, const BlockBuilder& ctx) {
Array<TensorStructInfo> input_sinfo = GetInputTensorStructInfo(call, ctx);
const auto* attrs = call->attrs.as<PadAttrs>();
int ndim = input_sinfo[0]->ndim;
Array<Integer> pad_width = attrs->pad_width;
ICHECK(static_cast<int>(pad_width.size()) == 2 * ndim) << "Illegal pad_width";

Array<PrimExpr> out_shape;
if (input_sinfo[0]->shape.defined()) {
// Compute output shape by adding corresponding pad width to each axis.
const auto* data_shape = input_sinfo[0]->shape.as<ShapeExprNode>();
for (int i = 0; i < ndim; i++) {
// Sum pad width for this axis.
PrimExpr added_width = pad_width[2 * i] + pad_width[(2 * i) + 1];
const PrimExpr current_width = data_shape->values[i];
out_shape.push_back(current_width + added_width);
}
} else {
// Shape isn't defined; the best we can do is return ndim and dtype.
return TensorStructInfo(input_sinfo[0]->dtype, ndim);
}
return TensorStructInfo(ShapeExpr(out_shape), input_sinfo[0]->dtype);
}
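
For instance (illustrative): a [N, D] input with pad_width (0, 0, 0, 1) is inferred to have shape [N, D + 1].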

TVM_REGISTER_OP("relax.nn.pad")
.set_num_inputs(2)
.add_argument("data", "Tensor", "The input tensor.")
.add_argument("pad_value", "Tensor", "The value to fill in padded area with.")
.set_attrs_type<PadAttrs>()
.set_attr<FInferStructInfo>("FInferStructInfo", InferStructInfoPad)
.set_attr<Bool>("FPurity", Bool(true));

/* relax.nn.batchnorm */
bool NormCheckDtypeAndShape(const Call& call, const BlockBuilder& ctx,
const Array<TensorStructInfo>& input_sinfo, Array<Integer> axes) {
Op op = Downcast<Op>(call->op);