
Missing bias terms in TransformerConv equations? #10130

Description

@pauvilasoler

📚 Describe the documentation issue

Hello!

Maybe I am missing something here (that's partially why I am writing this), but shouldn't there be bias terms in the equations for the attention mechanism in TransformerConv? (https://pytorch-geometric.readthedocs.io/en/latest/generated/torch_geometric.nn.conv.TransformerConv.html)

As far as I understand, the bias terms are included both in the paper and in the code implementation:

Paper:

[Screenshot of the paper's multi-head attention equations, in which the query, key and value projections each carry a bias term, e.g. q_{c,i} = W_{c,q} h_i + b_{c,q}]

Code implementation (from what I see in torch_geometric.nn.dense.linear, the default for Linear() is to include a bias term):

[Screenshot of the TransformerConv source, where lin_query, lin_key and lin_value are created with the default bias and lin_edge is created with bias=False]
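
For reference, the relevant lines of `__init__` look roughly like this (paraphrased from the source as I read it, so the exact signatures may differ between versions):

    # Paraphrased from TransformerConv.__init__; in torch_geometric.nn.dense.linear,
    # Linear() defaults to bias=True, just like torch.nn.Linear:
    self.lin_key = Linear(in_channels[0], heads * out_channels)    # bias=True (default)
    self.lin_query = Linear(in_channels[1], heads * out_channels)  # bias=True (default)
    self.lin_value = Linear(in_channels[0], heads * out_channels)  # bias=True (default)
    if edge_dim is not None:
        self.lin_edge = Linear(edge_dim, heads * out_channels, bias=False)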

(A related but different question would be why bias=False is used for lin_edge while it is True for lin_query, lin_key and lin_value, if that is not how it is presented in the original paper.)

Suggest a potential alternative/fix

If it is the case that bias terms should be included, a possible fix would be the following (where I have added bias terms, with subscript 1 for the query and 2 for the key, keeping in line with the subscripts of the weight matrices):

r"""The graph transformer operator from the `"Masked Label Prediction:
Unified Message Passing Model for Semi-Supervised Classification"
<https://arxiv.org/abs/2009.03509>`_ paper

.. math::
    \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
    \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_{j},

where the attention coefficients :math:`\alpha_{i,j}` are computed via
multi-head dot product attention:

.. math::
    \alpha_{i,j} = \textrm{softmax} \left(
    \frac{(\mathbf{W}_3\mathbf{x}_i + \mathbf{b}_1)^{\top}
    (\mathbf{W}_4\mathbf{x}_j + \mathbf{b}_2)}
    {\sqrt{d}} \right)

Args:
    in_channels (int or tuple): Size of each input sample, or :obj:`-1` to
        derive the size from the first input(s) to the forward method.
        A tuple corresponds to the sizes of source and target
        dimensionalities.
    out_channels (int): Size of each output sample.
    heads (int, optional): Number of multi-head-attentions.
        (default: :obj:`1`)
    concat (bool, optional): If set to :obj:`False`, the multi-head
        attentions are averaged instead of concatenated.
        (default: :obj:`True`)
    beta (bool, optional): If set, will combine aggregation and
        skip information via

        .. math::
            \mathbf{x}^{\prime}_i = \beta_i \mathbf{W}_1 \mathbf{x}_i +
            (1 - \beta_i) \underbrace{\left(\sum_{j \in \mathcal{N}(i)}
            \alpha_{i,j} \mathbf{W}_2 \mathbf{x}_j \right)}_{=\mathbf{m}_i}

        with :math:`\beta_i = \textrm{sigmoid}(\mathbf{w}_5^{\top}
        [ \mathbf{x}_i, \mathbf{m}_i, \mathbf{x}_i - \mathbf{m}_i ])`
        (default: :obj:`False`)
    dropout (float, optional): Dropout probability of the normalized
        attention coefficients which exposes each node to a stochastically
        sampled neighborhood during training. (default: :obj:`0`)
    edge_dim (int, optional): Edge feature dimensionality (in case
        there are any). Edge features are added to the keys after
        linear transformation, that is, prior to computing the
        attention dot product. They are also added to final values
        after the same linear transformation. The model is:

        .. math::
            \mathbf{x}^{\prime}_i = \mathbf{W}_1 \mathbf{x}_i +
            \sum_{j \in \mathcal{N}(i)} \alpha_{i,j} \left(
            \mathbf{W}_2 \mathbf{x}_{j} + \mathbf{W}_6 \mathbf{e}_{ij}
            \right),

        where the attention coefficients :math:`\alpha_{i,j}` are now
        computed via:

        .. math::
            \alpha_{i,j} = \textrm{softmax} \left(
            \frac{(\mathbf{W}_3\mathbf{x}_i)^{\top}
            (\mathbf{W}_4\mathbf{x}_j + \mathbf{W}_6 \mathbf{e}_{ij})}
            {\sqrt{d}} \right)

        (default :obj:`None`)
    bias (bool, optional): If set to :obj:`False`, the layer will not learn
        an additive bias. (default: :obj:`True`)
    root_weight (bool, optional): If set to :obj:`False`, the layer will
        not add the transformed root node features to the output and the
        option :attr:`beta` is set to :obj:`False`. (default: :obj:`True`)
    **kwargs (optional): Additional arguments of
        :class:`torch_geometric.nn.conv.MessagePassing`.
"""

Thanks!
