@@ -401,8 +401,7 @@ def __init__(
401401 def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
402402 r"""
403403 Args:
404- hidden_states
405- torch.Tensor of *encoder* input embeddings.
404+ hidden_states: torch.Tensor of *encoder* input embeddings.
406405 Returns:
407406 Encoder layer output torch.Tensor
408407 """
@@ -490,10 +489,8 @@ def forward(
490489 ) -> torch .Tensor :
491490 r"""
492491 Args:
493- decoder_hidden_states
494- torch.Tensor of *decoder* input embeddings.
495- encoder_hidden_states
496- torch.Tensor of *encoder* input embeddings.
492+ decoder_hidden_states: torch.Tensor of *decoder* input embeddings.
493+ encoder_hidden_states: torch.Tensor of *encoder* input embeddings.
497494 Returns:
498495 Decoder layer output torch.Tensor
499496 """
@@ -584,12 +581,10 @@ def forward(
584581 ) -> torch .Tensor :
585582 r"""
586583 Args:
587- input_ids
588- Indices of *encoder* input sequence tokens in the vocabulary.
589- Padding will be ignored by default should you
590- provide it.
591- positions
592- Positions of *encoder* input sequence tokens.
584+ input_ids: Indices of *encoder* input sequence tokens in the
585+ vocabulary.
586+ Padding will be ignored by default should you provide it.
587+ positions: Positions of *encoder* input sequence tokens.
593588 Returns:
594589 Decoder output torch.Tensor
595590 """
@@ -663,14 +658,11 @@ def forward(
663658 ) -> torch .Tensor :
664659 r"""
665660 Args:
666- decoder_input_ids
667- Indices of *decoder* input sequence tokens in the vocabulary.
668- Padding will be ignored by default should you
669- provide it.
670- decoder_positions
671- Positions of *decoder* input sequence tokens.
672- encoder_hidden_states:
673- Tensor of encoder output embeddings
661+ decoder_input_ids: Indices of *decoder* input sequence tokens
662+ in the vocabulary.
663+ Padding will be ignored by default should you provide it.
664+ decoder_positions: Positions of *decoder* input sequence tokens.
665+ encoder_hidden_states: Tensor of encoder output embeddings.
674666 Returns:
675667 Decoder output torch.Tensor
676668 """
@@ -732,16 +724,13 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
732724 encoder_positions : torch .Tensor ) -> torch .Tensor :
733725 r"""
734726 Args:
735- input_ids
736- Indices of *decoder* input sequence tokens in the vocabulary.
737- Padding will be ignored by default should you
738- provide it.
739- positions
740- Positions of *decoder* input sequence tokens.
741- encoder_input_ids
742- Indices of *encoder* input sequence tokens in the vocabulary.
743- encoder_positions:
744- Positions of *encoder* input sequence tokens.
727+ input_ids: Indices of *decoder* input sequence tokens
728+ in the vocabulary.
729+ Padding will be ignored by default should you provide it.
730+ positions: Positions of *decoder* input sequence tokens.
731+ encoder_input_ids: Indices of *encoder* input sequence tokens
732+ in the vocabulary.
733+ encoder_positions: Positions of *encoder* input sequence tokens.
745734 Returns:
746735 Model output torch.Tensor
747736 """
@@ -848,14 +837,10 @@ def forward(
848837 ) -> torch .Tensor :
849838 r"""
850839 Args:
851- input_ids
852- torch.Tensor of *decoder* input token ids.
853- positions
854- torch.Tensor of *decoder* position indices.
855- encoder_input_ids
856- torch.Tensor of *encoder* input token ids.
857- encoder_positions
858- torch.Tensor of *encoder* position indices
840+ input_ids: torch.Tensor of *decoder* input token ids.
841+ positions: torch.Tensor of *decoder* position indices.
842+ encoder_input_ids: torch.Tensor of *encoder* input token ids.
843+ encoder_positions: torch.Tensor of *encoder* position indices.
859844 Returns:
860845 Output torch.Tensor
861846 """
@@ -912,8 +897,7 @@ class MBartEncoderLayer(BartEncoderLayer):
912897 def forward (self , hidden_states : torch .Tensor ) -> torch .Tensor :
913898 r"""
914899 Args:
915- hidden_states
916- torch.Tensor of *encoder* input embeddings.
900+ hidden_states: torch.Tensor of *encoder* input embeddings.
917901 Returns:
918902 Encoder layer output torch.Tensor
919903 """
@@ -1035,12 +1019,10 @@ def forward(
10351019 ) -> torch .Tensor :
10361020 r"""
10371021 Args:
1038- input_ids
1039- Indices of *encoder* input sequence tokens in the vocabulary.
1040- Padding will be ignored by default should you
1041- provide it.
1042- positions
1043- Positions of *encoder* input sequence tokens.
1022+ input_ids: Indices of *encoder* input sequence tokens in the
1023+ vocabulary.
1024+ Padding will be ignored by default should you provide it.
1025+ positions: Positions of *encoder* input sequence tokens.
10441026 Returns:
10451027 Decoder output torch.Tensor
10461028 """
@@ -1116,14 +1098,11 @@ def forward(
11161098 ) -> torch .Tensor :
11171099 r"""
11181100 Args:
1119- decoder_input_ids
1120- Indices of *decoder* input sequence tokens in the vocabulary.
1121- Padding will be ignored by default should you
1122- provide it.
1123- decoder_positions
1124- Positions of *decoder* input sequence tokens.
1125- encoder_hidden_states:
1126- Tensor of encoder output embeddings
1101+ decoder_input_ids: Indices of *decoder* input sequence tokens
1102+ in the vocabulary.
1103+ Padding will be ignored by default should you provide it.
1104+ decoder_positions: Positions of *decoder* input sequence tokens.
1105+ encoder_hidden_states: Tensor of encoder output embeddings.
11271106 Returns:
11281107 Decoder output torch.Tensor
11291108 """
@@ -1185,16 +1164,13 @@ def forward(self, input_ids: torch.Tensor, positions: torch.Tensor,
11851164 encoder_positions : torch .Tensor ) -> torch .Tensor :
11861165 r"""
11871166 Args:
1188- input_ids
1189- Indices of *decoder* input sequence tokens in the vocabulary.
1190- Padding will be ignored by default should you
1191- provide it.
1192- positions
1193- Positions of *decoder* input sequence tokens.
1194- encoder_input_ids
1195- Indices of *encoder* input sequence tokens in the vocabulary.
1196- encoder_positions:
1197- Positions of *encoder* input sequence tokens.
1167+ input_ids: Indices of *decoder* input sequence tokens
1168+ in the vocabulary.
1169+ Padding will be ignored by default should you provide it.
1170+ positions: Positions of *decoder* input sequence tokens.
1171+ encoder_input_ids: Indices of *encoder* input sequence tokens
1172+ in the vocabulary.
1173+ encoder_positions: Positions of *encoder* input sequence tokens.
11981174 Returns:
11991175 Model output torch.Tensor
12001176 """
0 commit comments