diff --git a/tests/test_nn_layers.py b/tests/test_nn_layers.py
new file mode 100644
index 0000000..e7f3620
--- /dev/null
+++ b/tests/test_nn_layers.py
@@ -0,0 +1,26 @@
+#
+# Copyright 2020 NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import torch
+
+from variantworks.layers.attention import Attention
+
+
+def test_attention_layer():
+    input_tensor = torch.zeros((10, 10, 5), dtype=torch.float32)
+    attn_layer = Attention(5)
+    out, _ = attn_layer(input_tensor, input_tensor)
+    assert(torch.all(input_tensor.eq(out)))
diff --git a/variantworks/layers/attention.py b/variantworks/layers/attention.py
new file mode 100644
index 0000000..3efc523
--- /dev/null
+++ b/variantworks/layers/attention.py
@@ -0,0 +1,126 @@
+#
+# Copyright 2020 NVIDIA CORPORATION.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+#
+# The implementation in this file is adopted from a 3rd party repository with BSD 3-Clause License.
+# BSD 3-Clause License
+#
+# Copyright (c) James Bradbury and Soumith Chintala 2016,
+# All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# * Redistributions of source code must retain the above copyright notice, this
+#   list of conditions and the following disclaimer.
+#
+# * Redistributions in binary form must reproduce the above copyright notice,
+#   this list of conditions and the following disclaimer in the documentation
+#   and/or other materials provided with the distribution.
+#
+# * Neither the name of the copyright holder nor the names of its
+#   contributors may be used to endorse or promote products derived from
+#   this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+"""Attention related layers."""
+
+import torch
+import torch.nn as nn
+
+
+class Attention(nn.Module):
+    """Applies attention mechanism on the `context` using the `query`.
+
+    Implementation from: https://pytorchnlp.readthedocs.io/en/latest/_modules/torchnlp/nn/attention.html
+    """
+
+    def __init__(self, dimensions, attention_type='general'):
+        """Construct an Attention layer.
+
+        Args:
+            dimensions (int): Dimensionality of the query and context.
+            attention_type (str, optional): How to compute the attention score:
+
+                * dot: :math:`score(H_j,q) = H_j^T q`
+                * general: :math:`score(H_j, q) = H_j^T W_a q`
+        """
+        super(Attention, self).__init__()
+
+        if attention_type not in ['dot', 'general']:
+            raise ValueError('Invalid attention type selected.')
+
+        self.attention_type = attention_type
+        if self.attention_type == 'general':
+            self.linear_in = nn.Linear(dimensions, dimensions, bias=False)
+
+        self.linear_out = nn.Linear(dimensions * 2, dimensions, bias=False)
+        self.softmax = nn.Softmax(dim=-1)
+        self.tanh = nn.Tanh()
+
+    def forward(self, query, context):
+        """Forward method.
+
+        Args:
+            query : Sequence of queries to query the \
+                context [batch size, output length, dimensions].
+            context : Data over which to apply the attention \
+                mechanism [batch size, query length, dimensions].
+
+        Returns:
+            Tuple with output and weights:
+            * output : Tensor containing the attended features [batch size, output length, dimensions].
+            * weights : Tensor containing attention weights [batch size, output length, query length].
+        """
+        batch_size, output_len, dimensions = query.size()
+        query_len = context.size(1)
+
+        if self.attention_type == "general":
+            query = query.reshape(batch_size * output_len, dimensions)
+            query = self.linear_in(query)
+            query = query.reshape(batch_size, output_len, dimensions)
+
+        # (batch_size, output_len, dimensions) * (batch_size, query_len, dimensions) ->
+        # (batch_size, output_len, query_len)
+        attention_scores = torch.bmm(query, context.transpose(1, 2).contiguous())
+
+        # Compute weights across every context sequence
+        attention_scores = attention_scores.view(batch_size * output_len, query_len)
+        attention_weights = self.softmax(attention_scores)
+        attention_weights = attention_weights.view(batch_size, output_len, query_len)
+
+        # (batch_size, output_len, query_len) * (batch_size, query_len, dimensions) ->
+        # (batch_size, output_len, dimensions)
+        mix = torch.bmm(attention_weights, context)
+
+        # concat -> (batch_size * output_len, 2*dimensions)
+        combined = torch.cat((mix, query), dim=2)
+        combined = combined.view(batch_size * output_len, 2 * dimensions)
+
+        # Apply linear_out on every 2nd dimension of concat
+        # output -> (batch_size, output_len, dimensions)
+        output = self.linear_out(combined).view(batch_size, output_len, dimensions)
+        output = self.tanh(output)
+
+        return output, attention_weights
diff --git a/variantworks/networks.py b/variantworks/networks.py
index 5b138d7..53e7fc4 100644
--- a/variantworks/networks.py
+++ b/variantworks/networks.py
@@ -24,6 +24,8 @@
 from nemo.core.neural_types import NeuralType, ChannelType, LogitsType
 from nemo.core.neural_factory import DeviceType
 
+from variantworks.layers.attention import Attention
+
 
 class AlexNet(TrainableNM):
     """A Neural Module for AlexNet."""
@@ -254,3 +256,68 @@ def forward(self, encoding):
         encoding = self.classifier(encoding)
         encoding = F.softmax(encoding, dim=2)
         return encoding
+
+
+class ConsensusAttention(TrainableNM):
+    """A Neural Module for training a Consensus Attention Model."""
+
+    @property
+    @add_port_docs()
+    def input_ports(self):
+        """Return definitions of module input ports.
+
+        Returns:
+            Module input ports.
+        """
+        return {
+            "encoding": NeuralType(('B', 'W', 'C'), ChannelType()),
+        }
+
+    @property
+    @add_port_docs()
+    def output_ports(self):
+        """Return definitions of module output ports.
+
+        Returns:
+            Module output ports.
+        """
+        return {
+            # Variant type
+            'output_logit': NeuralType(('B', 'W', 'D'), LogitsType()),
+        }
+
+    def __init__(self, sequence_length, input_feature_size, num_output_logits):
+        """Construct a ConsensusAttention NeMo instance.
+
+        Args:
+            sequence_length : Length of sequence to feed into RNN.
+            input_feature_size : Length of input feature set.
+            num_output_logits : Number of output classes of classifier.
+
+        Returns:
+            Instance of class.
+        """
+        super().__init__()
+        self.num_output_logits = num_output_logits
+
+        self.attn = Attention(input_feature_size)
+        self.gru = nn.GRU(input_feature_size, 16, 1, batch_first=True, bidirectional=True)
+        self.classifier = nn.Linear(32, self.num_output_logits)
+
+        self._device = torch.device(
+            "cuda" if self.placement == DeviceType.GPU else "cpu")
+        self.to(self._device)
+
+    def forward(self, encoding):
+        """Abstract function to run the network.
+
+        Args:
+            encoding : Input sequence to run network on.
+
+        Returns:
+            Output of forward pass.
+        """
+        encoding, _ = self.attn(encoding, encoding)
+        encoding, _ = self.gru(encoding)
+        encoding = self.classifier(encoding)
+        return encoding
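
As a quick sanity check beyond the zero-input test above, the new `Attention` layer can be exercised standalone. A minimal sketch of its shape contract (the tensor sizes below are illustrative, not taken from this PR):

```python
import torch

from variantworks.layers.attention import Attention

# Self-attention over a batch of 4 sequences, 100 positions, 10 features each.
encoding = torch.randn(4, 100, 10)

attn = Attention(10, attention_type='general')  # 'dot' would skip the linear_in projection
output, weights = attn(encoding, encoding)

assert output.shape == (4, 100, 10)    # attended features, same shape as the query
assert weights.shape == (4, 100, 100)  # one attention row per query position
# Each row of the weights is softmax-normalized over the context positions.
assert torch.allclose(weights.sum(dim=-1), torch.ones(4, 100))
```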
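The `ConsensusAttention` module itself needs a NeMo neural factory to be instantiated, but its forward pass is a plain chain of self-attention, a single-layer bidirectional GRU, and a per-position linear classifier. A rough plain-PyTorch sketch of that data flow, with layer sizes copied from the module and made-up batch/window sizes:

```python
import torch
import torch.nn as nn

from variantworks.layers.attention import Attention

input_feature_size, num_output_logits = 10, 5

attn = Attention(input_feature_size)
gru = nn.GRU(input_feature_size, 16, 1, batch_first=True, bidirectional=True)
classifier = nn.Linear(32, num_output_logits)  # 32 = 2 directions x 16 hidden units

encoding = torch.randn(4, 1000, input_feature_size)  # (batch, window, features)
encoding, _ = attn(encoding, encoding)               # (4, 1000, 10) attended features
encoding, _ = gru(encoding)                          # (4, 1000, 32) per-position GRU states
logits = classifier(encoding)                        # (4, 1000, 5) raw logits per position
```

Note that, unlike the existing consensus module shown in the hunk context (whose forward pass ends in `F.softmax`), `ConsensusAttention.forward` returns raw logits.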