diff --git a/torch_geometric/nn/glob/attention.py b/torch_geometric/nn/glob/attention.py index b7ea0a48ec43..876c48d97e49 100644 --- a/torch_geometric/nn/glob/attention.py +++ b/torch_geometric/nn/glob/attention.py @@ -38,7 +38,7 @@ class GlobalAttention(torch.nn.Module): node features :math:`(|\mathcal{V}|, F)`, batch vector :math:`(|\mathcal{V}|)` *(optional)* - **output:** - graph features :math:`(|\mathcal{G}|, 2 * F)` where + graph features :math:`(|\mathcal{G}|, F)` where :math:`|\mathcal{G}|` denotes the number of graphs in the batch """ def __init__(self, gate_nn: torch.nn.Module,