support pooling to be compatible with ST (Sentence Transformers).
SeanLee97 committed Oct 19, 2024
1 parent 9acf28f commit c371424
Showing 1 changed file with 6 additions and 6 deletions.
12 changes: 6 additions & 6 deletions angle_emb/angle.py
@@ -254,7 +254,7 @@ def get_pooling(outputs: torch.Tensor,
     :param outputs: torch.Tensor. Model outputs (without pooling)
     :param inputs: Dict. Model inputs
-    :param pooling_strategy: str. Pooling strategy ['cls', 'cls_avg', 'cls_max', 'last', 'avg', 'max', 'all', index]
+    :param pooling_strategy: str. Pooling strategy ['cls', 'cls_avg', 'cls_max', 'last', 'avg', 'mean', 'max', 'all', index]
     :param padding_side: str. Padding strategy of tokenizers (`left` or `right`).
         It can be obtained by `tokenizer.padding_side`.
     """
@@ -271,7 +271,7 @@ def get_pooling(outputs: torch.Tensor,
         batch_size = inputs['input_ids'].shape[0]
         sequence_lengths = -1 if padding_side == 'left' else inputs["attention_mask"].sum(dim=1) - 1
         outputs = outputs[torch.arange(batch_size, device=outputs.device), sequence_lengths]
-    elif pooling_strategy == 'avg':
+    elif pooling_strategy in ['avg', 'mean']:
         outputs = torch.sum(outputs * inputs["attention_mask"][:, :, None], dim=1) / inputs["attention_mask"].sum(dim=1).unsqueeze(1)
     elif pooling_strategy == 'max':
         outputs, _ = torch.max(outputs * inputs["attention_mask"][:, :, None], dim=1)
@@ -283,7 +283,7 @@ def get_pooling(outputs: torch.Tensor,
         outputs = outputs[:, int(pooling_strategy)]
     else:
         raise NotImplementedError(
-            'please specify pooling_strategy from [`cls`, `last`, `avg`, `max`, `last_avg`, `all`, int]')
+            'please specify pooling_strategy from [`cls`, `cls_avg`, `cls_max`, `last`, `avg`, `mean`, `max`, `all`, int]')
     return outputs
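
Since `'mean'` is now just an alias for `'avg'`, the two strategies should return identical tensors. A hedged sanity check, assuming `get_pooling` is importable from `angle_emb.angle` and that `padding_side` is accepted as a keyword argument:

    import torch
    from angle_emb.angle import get_pooling

    # Dummy model outputs and inputs, following the shapes described in
    # get_pooling's docstring.
    outputs = torch.randn(2, 5, 8)
    inputs = {
        'input_ids': torch.randint(0, 100, (2, 5)),
        'attention_mask': torch.tensor([[1, 1, 1, 0, 0],
                                        [1, 1, 1, 1, 1]]),
    }
    avg = get_pooling(outputs, inputs, 'avg', padding_side='right')
    mean = get_pooling(outputs, inputs, 'mean', padding_side='right')
    assert torch.allclose(avg, mean)  # the alias dispatches to the same branch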


@@ -689,7 +689,7 @@ class Pooler:
     Using Pooler to obtain sentence embeddings.
 
     :param model: PreTrainedModel
-    :param pooling_strategy: Optional[str]. Currently support [`cls`, `last`, `avg`, `cls_avg`, `max`]. Default None.
+    :param pooling_strategy: Optional[str]. Currently support [`cls`, `cls_avg`, `cls_max`, `last`, `avg`, `mean`, `max`, `all`, int]. Default None.
     :param padding_side: Optional[str]. `left` or `right`. Default None.
     :param is_llm: bool. Default False
     """
@@ -717,7 +717,7 @@ def __call__(self,
         :param embedding_size: int. Set embedding size for sentence embeddings.
         :param return_all_layer_outputs: bool. Return all layer outputs or not. Default False.
         :param pooling_strategy: Optional[str].
-            Currently support [`cls`, `last`, `avg`, `cls_avg`, `max`]. Default None.
+            Currently support [`cls`, `cls_avg`, `cls_max`, `last`, `avg`, `mean`, `max`, `all`, int]. Default None.
         :param return_mlm_logits: bool. Return logits or not. Default False.
         """
         ret = self.model(output_hidden_states=True, return_dict=True, **inputs)
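
A usage sketch of `Pooler` with the new alias; the constructor and call arguments below are inferred from the two docstrings above, so the exact signatures may differ:

    import torch
    from transformers import AutoModel, AutoTokenizer
    from angle_emb.angle import Pooler

    tokenizer = AutoTokenizer.from_pretrained('bert-base-uncased')
    model = AutoModel.from_pretrained('bert-base-uncased')
    # Arguments inferred from the docstring; not a confirmed signature.
    pooler = Pooler(model, pooling_strategy='mean', padding_side=tokenizer.padding_side)
    inputs = tokenizer(['hello world'], return_tensors='pt')
    with torch.no_grad():
        embeddings = pooler(inputs)  # expected shape: (1, hidden_size)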
@@ -1093,7 +1093,7 @@ class AnglE(AngleBase):
     :param lora_config_kwargs: Optional[Dict]. kwargs for peft lora_config.
         details refer to: https://huggingface.co/docs/peft/tutorial/peft_model_config
     :param pooling_strategy: Optional[str]. Pooling strategy.
-        Currently support [`cls`, `last`, `avg`, `cls_avg`, `max`]
+        Currently support [`cls`, `cls_avg`, `cls_max`, `last`, `avg`, `mean`, `max`, `all`, int]
     :param apply_lora: Optional[bool]. Whether apply lora. Default None.
     :param train_mode: bool. Whether load for training. Default True.
     :param load_kbit: Optional[int]. Specify kbit training from [4, 8, 16]. Default None.
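
At the user-facing level, this lets the sentence-transformers-style strategy name work end to end. A hedged example following the usual `angle_emb` loading pattern (the checkpoint name is only an example):

    from angle_emb import AnglE

    # 'mean' (sentence-transformers' naming) now behaves like 'avg'.
    angle = AnglE.from_pretrained('WhereIsAI/UAE-Large-V1', pooling_strategy='mean')
    vecs = angle.encode(['hello world', 'another sentence'])
    print(vecs.shape)  # (2, embedding_dim)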
