From 6c64da708d632894a25a8d613665f9294b77e638 Mon Sep 17 00:00:00 2001 From: Ruan Chaves Rodrigues Date: Wed, 24 May 2023 22:38:07 +0200 Subject: [PATCH 01/11] feat: add minicons backend --- src/hashformers/beamsearch/bert_lm.py | 45 ++------ src/hashformers/beamsearch/gpt2_lm.py | 120 ++-------------------- src/hashformers/beamsearch/minicons_lm.py | 28 +++++ src/hashformers/beamsearch/model_lm.py | 20 +++- src/hashformers/segmenter/segmenter.py | 30 ++++++ 5 files changed, 90 insertions(+), 153 deletions(-) create mode 100644 src/hashformers/beamsearch/minicons_lm.py diff --git a/src/hashformers/beamsearch/bert_lm.py b/src/hashformers/beamsearch/bert_lm.py index 63e0328..99281a3 100644 --- a/src/hashformers/beamsearch/bert_lm.py +++ b/src/hashformers/beamsearch/bert_lm.py @@ -1,10 +1,6 @@ -import mxnet as mx -import numpy as np -import pandas as pd -from mlm.models import get_pretrained -from mlm.scorers import MLMScorerPT +from hashformers.beamsearch.minicons_lm import MiniconsLM -class BertLM(object): +class BertLM(MiniconsLM): """ Implements a BERT-based language model scorer, to compute sentence probabilities. This class uses a transformer-based Masked Language Model (MLM) for scoring. @@ -18,34 +14,9 @@ class BertLM(object): """ def __init__(self, model_name_or_path, gpu_batch_size=1, gpu_id=0): - mx_device = [mx.gpu(gpu_id)] - self.scorer = MLMScorerPT(*get_pretrained(mx_device, model_name_or_path), mx_device) - self.gpu_batch_size = gpu_batch_size - - def get_probs(self, list_of_candidates): - """ - Returns probabilities for a list of candidate sentences. - - Args: - list_of_candidates (list): A list of sentences for which the probability is to be - calculated. Each sentence should be a string. - - Returns: - list: A list of probabilities corresponding to the input sentences. If an exception is encountered - while computing the probability for a sentence (e.g., if the sentence is not a string or - is NaN), the corresponding score in the output list is NaN. - """ - scores = [] - try: - scores = self.scorer.score_sentences(list_of_candidates, split_size=self.gpu_batch_size) - scores = [ x * -1 for x in scores ] - return scores - except: - for candidate in list_of_candidates: - try: - score = self.scorer.score_sentences([candidate])[0] if not pd.isna(candidate) else np.nan - score = score * -1 - except IndexError: - score = np.nan - scores.append(score) - return scores \ No newline at end of file + super().__init__( + model_name_or_path=model_name_or_path, + device='cuda', + gpu_batch_size=gpu_batch_size, + model_type='MaskedLMScorer' + ) \ No newline at end of file diff --git a/src/hashformers/beamsearch/gpt2_lm.py b/src/hashformers/beamsearch/gpt2_lm.py index 0b2fdd0..b234ab7 100644 --- a/src/hashformers/beamsearch/gpt2_lm.py +++ b/src/hashformers/beamsearch/gpt2_lm.py @@ -1,102 +1,6 @@ -from lm_scorer.models.auto import GPT2LMScorer -from typing import * # pylint: disable=wildcard-import,unused-wildcard-import -import torch -from transformers import AutoTokenizer, GPT2LMHeadModel -from transformers.tokenization_utils import BatchEncoding +from hashformers.beamsearch.minicons_lm import MiniconsLM -class PaddedGPT2LMScorer(GPT2LMScorer): - """A Language Model (LM) scorer using GPT2 and supporting token padding. - - Inherits from the GPT2LMScorer class to score sentences. Additionally, supports padding of tokens for - consistent length sequences. - - Args: - *args: Variable length argument list. - **kwargs: Arbitrary keyword arguments. 
- """ - def __init__(self, *args, **kwargs): - super().__init__(*args, **kwargs) - - def _build(self, model_name: str, options: Dict[str, Any]) -> None: - """Initializes the tokenizer and model with the given model_name and options. - - Args: - model_name (str): Name of the model to be used for tokenization and generation. - options (Dict[str, Any]): Additional options for the model, such as 'device'. - """ - super()._build(model_name, options) - - # pylint: disable=attribute-defined-outside-init - self.tokenizer = AutoTokenizer.from_pretrained( - model_name, use_fast=True, add_special_tokens=False - ) - # Add the pad token to GPT2 dictionary. - # len(tokenizer) = vocab_size + 1 - self.tokenizer.add_special_tokens({"additional_special_tokens": ["<|pad|>"]}) - self.tokenizer.pad_token = "<|pad|>" - - self.model = GPT2LMHeadModel.from_pretrained(model_name) - # We need to resize the embedding layer because we added the pad token. - self.model.resize_token_embeddings(len(self.tokenizer)) - self.model.eval() - if "device" in options: - self.model.to(options["device"]) - - def _tokens_log_prob_for_batch( - self, text: List[str] - ) -> List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]: - """Calculates the log probability of tokens for a batch of sentences. - - Args: - text (List[str]): List of sentences to calculate log probabilities for. - - Returns: - List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]]: List of tuples, each containing the - log probabilities, ids, and tokens for a sentence. - """ - outputs: List[Tuple[torch.DoubleTensor, torch.LongTensor, List[str]]] = [] - if len(text) == 0: - return outputs - - # TODO: Handle overflowing elements for long sentences - text = list(map(self._add_special_tokens, text)) - encoding: BatchEncoding = self.tokenizer.batch_encode_plus( - text, return_tensors="pt", padding=True, truncation=True - ) - with torch.no_grad(): - ids = encoding["input_ids"].to(self.model.device) - attention_mask = encoding["attention_mask"].to(self.model.device) - nopad_mask = ids != self.tokenizer.pad_token_id - logits: torch.Tensor = self.model(ids, attention_mask=attention_mask)[0] - - for sent_index in range(len(text)): - sent_nopad_mask = nopad_mask[sent_index] - # len(tokens) = len(text[sent_index]) + 1 - sent_tokens = [ - tok - for i, tok in enumerate(encoding.tokens(sent_index)) - if sent_nopad_mask[i] and i != 0 - ] - - # sent_ids.shape = [len(text[sent_index]) + 1] - sent_ids = ids[sent_index, sent_nopad_mask][1:] - # logits.shape = [len(text[sent_index]) + 1, vocab_size] - sent_logits = logits[sent_index, sent_nopad_mask][:-1, :] - sent_logits[:, self.tokenizer.pad_token_id] = float("-inf") - # ids_scores.shape = [seq_len + 1] - sent_ids_scores = sent_logits.gather(1, sent_ids.unsqueeze(1)).squeeze(1) - # log_prob.shape = [seq_len + 1] - sent_log_probs = sent_ids_scores - sent_logits.logsumexp(1) - - sent_log_probs = cast(torch.DoubleTensor, sent_log_probs) - sent_ids = cast(torch.LongTensor, sent_ids) - - output = (sent_log_probs, sent_ids, sent_tokens) - outputs.append(output) - - return outputs - -class GPT2LM(object): +class GPT2LM(MiniconsLM): """A Language Model (LM) scorer using GPT2. This class utilizes the PaddedGPT2LMScorer for scoring sentences. @@ -107,17 +11,9 @@ class GPT2LM(object): gpu_batch_size (int): The batch size for GPU processing. Default is 20. 
""" def __init__(self, model_name_or_path, device='cuda', gpu_batch_size=20): - self.scorer = PaddedGPT2LMScorer(model_name_or_path, device=device, batch_size=gpu_batch_size) - - def get_probs(self, list_of_candidates): - """Calculates the probabilities for a list of candidate sentences. - - Args: - list_of_candidates (List[str]): List of candidate sentences to calculate probabilities for. - - Returns: - List[float]: List of probabilities corresponding to the input sentences. - """ - scores = self.scorer.sentence_score(list_of_candidates, log=True) - scores = [ 1-x for x in scores ] - return scores \ No newline at end of file + super().__init__( + model_name_or_path=model_name_or_path, + device=device, + gpu_batch_size=gpu_batch_size, + model_type='IncrementalLMScorer' + ) \ No newline at end of file diff --git a/src/hashformers/beamsearch/minicons_lm.py b/src/hashformers/beamsearch/minicons_lm.py new file mode 100644 index 0000000..a24668a --- /dev/null +++ b/src/hashformers/beamsearch/minicons_lm.py @@ -0,0 +1,28 @@ +from minicons import scorer +from torch.utils.data import DataLoader +import warnings + +class MiniconsLM(object): + + def __init__(self, model_name_or_path, device='cuda', gpu_batch_size=20, model_type='IncrementalLMScorer'): + self.scorer = getattr(scorer, model_type)(model_name_or_path, device) + self.gpu_batch_size = gpu_batch_size + self.model_type = model_type + + def get_probs(self, list_of_candidates): + probs = [] + dl = DataLoader(list_of_candidates, batch_size=self.gpu_batch_size) + for batch in dl: + probs.extend(self.get_batch_scores(batch)) + return probs + + def get_batch_scores(self, batch): + if self.model_type == 'IncrementalLMScorer': + return self.scorer.sentence_score(batch, reduction = lambda x: -x.sum(0).item()) + elif self.model_type == 'MaskedLMScorer': + return self.scorer.sentence_score(batch, reduction = lambda x: -x.sum(0).item()) + elif self.model_type == 'Seq2SeqScorer': + return self.scorer.sentence_score(batch, source_format = 'blank') + else: + warnings.warn(f"Model type {self.model_type} not implemented. Assuming reduction = lambda x: -x.sum(0).item()") + return self.scorer.sentence_score(batch, reduction = lambda x: -x.sum(0).item()) \ No newline at end of file diff --git a/src/hashformers/beamsearch/model_lm.py b/src/hashformers/beamsearch/model_lm.py index 7cc847f..a444a47 100644 --- a/src/hashformers/beamsearch/model_lm.py +++ b/src/hashformers/beamsearch/model_lm.py @@ -1,3 +1,7 @@ +from hashformers.beamsearch.gpt2_lm import GPT2LM +from hashformers.beamsearch.minicons_lm import MiniconsLM +from hashformers.beamsearch.bert_lm import BertLM + class ModelLM(object): """ A Language Model (LM) class that supports both GPT2 and BERT models. 
@@ -19,9 +23,17 @@ class ModelLM(object): """ def __init__(self, model_name_or_path=None, model_type=None, device=None, gpu_batch_size=None, gpu_id=0): self.gpu_batch_size = gpu_batch_size - if model_type == 'gpt2': - from hashformers.beamsearch.gpt2_lm import GPT2LM + if model_type is None: + self.model = None + elif model_type == 'gpt2': self.model = GPT2LM(model_name_or_path, device=device, gpu_batch_size=gpu_batch_size) elif model_type == 'bert': - from hashformers.beamsearch.bert_lm import BertLM - self.model = BertLM(model_name_or_path, gpu_batch_size=gpu_batch_size, gpu_id=gpu_id) \ No newline at end of file + self.model = BertLM(model_name_or_path, gpu_batch_size=gpu_batch_size, gpu_id=gpu_id) + elif model_type == 'seq2seq': + self.model = MiniconsLM(model_name_or_path, device=device, gpu_batch_size=gpu_batch_size, model_type='Seq2SeqScorer') + elif model_type == 'masked': + self.model = MiniconsLM(model_name_or_path, device=device, gpu_batch_size=gpu_batch_size, model_type='MaskedLMScorer') + elif model_type == 'incremental': + self.model = MiniconsLM(model_name_or_path, device=device, gpu_batch_size=gpu_batch_size, model_type='IncrementalLMScorer') + else: + self.model = MiniconsLM(model_name_or_path, device=device, gpu_batch_size=gpu_batch_size, model_type=model_type) \ No newline at end of file diff --git a/src/hashformers/segmenter/segmenter.py b/src/hashformers/segmenter/segmenter.py index c92e4ed..3087d72 100644 --- a/src/hashformers/segmenter/segmenter.py +++ b/src/hashformers/segmenter/segmenter.py @@ -33,6 +33,36 @@ def __init__( self.reranker_model = reranker self.ensembler = ensembler + def get_segmenter(self): + """ + Returns the segmenter model. + """ + return self.segmenter_model.model + + def get_reranker(self): + """ + Returns the reranker model. + """ + return self.reranker_model.model + + def set_segmenter(self, segmenter): + """ + Sets the segmenter model. + + Args: + segmenter: The model used for initial hashtag segmentation. + """ + self.segmenter_model.model = segmenter + + def set_reranker(self, reranker): + """ + Sets the reranker model. + + Args: + reranker: The model used for reranking the segmented hashtags. 
+ """ + self.reranker_model.model = reranker + def segment( self, word_list: List[str], From 7838acf6113d99c80263935df17bcfb3ade5877a Mon Sep 17 00:00:00 2001 From: Ruan Chaves Rodrigues Date: Thu, 25 May 2023 09:51:35 +0200 Subject: [PATCH 02/11] change versioning and requirements --- requirements.txt | 5 +---- setup.py | 5 ++--- 2 files changed, 3 insertions(+), 7 deletions(-) diff --git a/requirements.txt b/requirements.txt index 42130bc..a211729 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,7 +1,4 @@ -sphinx_markdown_tables -recommonmark -mlm-hashformers -lm-scorer-hashformers +minicons twitter-text-python ekphrasis pandas \ No newline at end of file diff --git a/setup.py b/setup.py index 55a7b64..7665ae1 100644 --- a/setup.py +++ b/setup.py @@ -2,15 +2,14 @@ setup( name='hashformers', - version='1.2.8', + version='2.0.0', author='Ruan Chaves Rodrigues', author_email='ruanchave93@gmail.com', description='Word segmentation with transformers', packages=find_packages('src'), package_dir={'': 'src'}, install_requires=[ - "mlm-hashformers", - "lm-scorer-hashformers", + "minicons", "twitter-text-python", "pandas" ] From b11020cea247655485b731544a36dd4ff8a7ae59 Mon Sep 17 00:00:00 2001 From: Ruan Chaves Rodrigues Date: Thu, 25 May 2023 21:50:20 +0200 Subject: [PATCH 03/11] feat: minor changes to improve usability --- src/hashformers/segmenter/auto.py | 6 ++++-- src/hashformers/segmenter/segmenter.py | 17 ++++++++++++++++- 2 files changed, 20 insertions(+), 3 deletions(-) diff --git a/src/hashformers/segmenter/auto.py b/src/hashformers/segmenter/auto.py index 580079a..3096780 100644 --- a/src/hashformers/segmenter/auto.py +++ b/src/hashformers/segmenter/auto.py @@ -16,7 +16,8 @@ def __init__( segmenter_gpu_batch_size = 1000, reranker_gpu_batch_size = 1000, reranker_model_name_or_path = None, - reranker_model_type = "bert" + reranker_model_type = "bert", + reranker_device = "cuda" ): """Word segmentation API initialization. A GPT-2 model must be passed to `segmenter_model_name_or_path`, and optionally a BERT model to `reranker_model_name_or_path`. @@ -43,7 +44,8 @@ def __init__( reranker_model = Reranker( model_name_or_path=reranker_model_name_or_path, model_type=reranker_model_type, - gpu_batch_size=reranker_gpu_batch_size + gpu_batch_size=reranker_gpu_batch_size, + device=reranker_device ) else: reranker_model = None diff --git a/src/hashformers/segmenter/segmenter.py b/src/hashformers/segmenter/segmenter.py index 3087d72..7c83b1f 100644 --- a/src/hashformers/segmenter/segmenter.py +++ b/src/hashformers/segmenter/segmenter.py @@ -45,6 +45,12 @@ def get_reranker(self): """ return self.reranker_model.model + def get_ensembler(self): + """ + Returns the ensembler model. + """ + return self.ensembler + def set_segmenter(self, segmenter): """ Sets the segmenter model. @@ -63,6 +69,15 @@ def set_reranker(self, reranker): """ self.reranker_model.model = reranker + def set_ensembler(self, ensembler): + """ + Sets the ensembler model. + + Args: + ensembler: The model used for ensemble operations over the segmenter and reranker models. + """ + self.ensembler = ensembler + def segment( self, word_list: List[str], @@ -163,7 +178,7 @@ def __call__(self, tweets): tweets: A list of strings, where each string is a tweet. Returns: - A list of tags for each tweet. + A list of hashtags for each tweet. 
""" return [ self.parser.parse(x).tags for x in tweets ] From 92b0d8485de73ee8ecf4f35c400ecfcee03c490b Mon Sep 17 00:00:00 2001 From: Ruan Chaves Rodrigues Date: Thu, 25 May 2023 22:26:04 +0200 Subject: [PATCH 04/11] fix: typo - replace sentence score by sequence score method call --- src/hashformers/beamsearch/minicons_lm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/hashformers/beamsearch/minicons_lm.py b/src/hashformers/beamsearch/minicons_lm.py index a24668a..5579e45 100644 --- a/src/hashformers/beamsearch/minicons_lm.py +++ b/src/hashformers/beamsearch/minicons_lm.py @@ -18,11 +18,11 @@ def get_probs(self, list_of_candidates): def get_batch_scores(self, batch): if self.model_type == 'IncrementalLMScorer': - return self.scorer.sentence_score(batch, reduction = lambda x: -x.sum(0).item()) + return self.scorer.sequence_score(batch, reduction = lambda x: -x.sum(0).item()) elif self.model_type == 'MaskedLMScorer': - return self.scorer.sentence_score(batch, reduction = lambda x: -x.sum(0).item()) + return self.scorer.sequence_score(batch, reduction = lambda x: -x.sum(0).item()) elif self.model_type == 'Seq2SeqScorer': - return self.scorer.sentence_score(batch, source_format = 'blank') + return self.scorer.sequence_score(batch, source_format = 'blank') else: warnings.warn(f"Model type {self.model_type} not implemented. Assuming reduction = lambda x: -x.sum(0).item()") - return self.scorer.sentence_score(batch, reduction = lambda x: -x.sum(0).item()) \ No newline at end of file + return self.scorer.sequence_score(batch, reduction = lambda x: -x.sum(0).item()) \ No newline at end of file From 1d5d178883cb0dae5c585acd15ca45dfc38d1f5c Mon Sep 17 00:00:00 2001 From: Ruan Chaves Rodrigues Date: Thu, 25 May 2023 23:47:17 +0200 Subject: [PATCH 05/11] fix: replace sum by prod in incremental --- src/hashformers/beamsearch/minicons_lm.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/hashformers/beamsearch/minicons_lm.py b/src/hashformers/beamsearch/minicons_lm.py index 5579e45..53ca83b 100644 --- a/src/hashformers/beamsearch/minicons_lm.py +++ b/src/hashformers/beamsearch/minicons_lm.py @@ -18,11 +18,11 @@ def get_probs(self, list_of_candidates): def get_batch_scores(self, batch): if self.model_type == 'IncrementalLMScorer': - return self.scorer.sequence_score(batch, reduction = lambda x: -x.sum(0).item()) + return self.scorer.sequence_score(batch, reduction = lambda x: x.prod(0).item()) elif self.model_type == 'MaskedLMScorer': - return self.scorer.sequence_score(batch, reduction = lambda x: -x.sum(0).item()) + return self.scorer.sequence_score(batch, reduction = lambda x: x.sum(0).item()) elif self.model_type == 'Seq2SeqScorer': return self.scorer.sequence_score(batch, source_format = 'blank') else: - warnings.warn(f"Model type {self.model_type} not implemented. Assuming reduction = lambda x: -x.sum(0).item()") - return self.scorer.sequence_score(batch, reduction = lambda x: -x.sum(0).item()) \ No newline at end of file + warnings.warn(f"Model type {self.model_type} not implemented. 
Assuming reduction = lambda x: x.sum(0).item()") + return self.scorer.sequence_score(batch, reduction = lambda x: x.sum(0).item()) \ No newline at end of file From fb28f61454b2552fac0c687c3fddf017b96e4eb2 Mon Sep 17 00:00:00 2001 From: Ruan Chaves Rodrigues Date: Thu, 25 May 2023 23:57:04 +0200 Subject: [PATCH 06/11] fix: correct sum function --- src/hashformers/beamsearch/minicons_lm.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/hashformers/beamsearch/minicons_lm.py b/src/hashformers/beamsearch/minicons_lm.py index 53ca83b..679b6ac 100644 --- a/src/hashformers/beamsearch/minicons_lm.py +++ b/src/hashformers/beamsearch/minicons_lm.py @@ -18,7 +18,7 @@ def get_probs(self, list_of_candidates): def get_batch_scores(self, batch): if self.model_type == 'IncrementalLMScorer': - return self.scorer.sequence_score(batch, reduction = lambda x: x.prod(0).item()) + return self.scorer.sequence_score(batch, reduction = lambda x: x.sum(0).item()) elif self.model_type == 'MaskedLMScorer': return self.scorer.sequence_score(batch, reduction = lambda x: x.sum(0).item()) elif self.model_type == 'Seq2SeqScorer': From 8e8b76ce5cd36ae78cf66eda474bd56453e14421 Mon Sep 17 00:00:00 2001 From: Ruan Chaves Rodrigues Date: Fri, 2 Jun 2023 15:18:21 +0200 Subject: [PATCH 07/11] fix: minicons incremental lm scorer --- src/hashformers/beamsearch/minicons_lm.py | 5 ++++- 1 file changed, 4 insertions(+), 1 deletion(-) diff --git a/src/hashformers/beamsearch/minicons_lm.py b/src/hashformers/beamsearch/minicons_lm.py index 679b6ac..75b904c 100644 --- a/src/hashformers/beamsearch/minicons_lm.py +++ b/src/hashformers/beamsearch/minicons_lm.py @@ -1,6 +1,7 @@ from minicons import scorer from torch.utils.data import DataLoader import warnings +import numpy as np class MiniconsLM(object): @@ -18,7 +19,9 @@ def get_probs(self, list_of_candidates): def get_batch_scores(self, batch): if self.model_type == 'IncrementalLMScorer': - return self.scorer.sequence_score(batch, reduction = lambda x: x.sum(0).item()) + tokens = self.scorer.prepare_text(batch, bos_token=True, eos_token=True) + stats = self.scorer.compute_stats(tokens, prob=True) + return np.array(stats).sum(axis=1) elif self.model_type == 'MaskedLMScorer': return self.scorer.sequence_score(batch, reduction = lambda x: x.sum(0).item()) elif self.model_type == 'Seq2SeqScorer': From da9f1a139280d5cc339847f0f5b7d761754870d0 Mon Sep 17 00:00:00 2001 From: Ruan Chaves Rodrigues Date: Sat, 3 Jun 2023 13:00:15 +0200 Subject: [PATCH 08/11] feat: add incremental sequence scores function --- src/hashformers/beamsearch/minicons_lm.py | 14 ++++++++++---- 1 file changed, 10 insertions(+), 4 deletions(-) diff --git a/src/hashformers/beamsearch/minicons_lm.py b/src/hashformers/beamsearch/minicons_lm.py index 75b904c..3e0a72c 100644 --- a/src/hashformers/beamsearch/minicons_lm.py +++ b/src/hashformers/beamsearch/minicons_lm.py @@ -1,7 +1,7 @@ from minicons import scorer from torch.utils.data import DataLoader import warnings -import numpy as np +import math class MiniconsLM(object): @@ -17,11 +17,17 @@ def get_probs(self, list_of_candidates): probs.extend(self.get_batch_scores(batch)) return probs + def incremental_sequence_score(self, batch): + tokens = self.scorer.prepare_text(batch, bos_token=True, eos_token=True) + stats = self.scorer.compute_stats(tokens, prob=True) + log_stats = [ [ math.log(x) for x in sequence ] for sequence in stats ] + sum_log_stats = [ sum(x) for x in log_stats ] + pos_sum_log_stats = [ 1 - x for x in sum_log_stats ] + return 
pos_sum_log_stats + def get_batch_scores(self, batch): if self.model_type == 'IncrementalLMScorer': - tokens = self.scorer.prepare_text(batch, bos_token=True, eos_token=True) - stats = self.scorer.compute_stats(tokens, prob=True) - return np.array(stats).sum(axis=1) + return self.incremental_sequence_score(batch) elif self.model_type == 'MaskedLMScorer': return self.scorer.sequence_score(batch, reduction = lambda x: x.sum(0).item()) elif self.model_type == 'Seq2SeqScorer': From 8c1a8007eaf7c0a73a6b6478292bbcf88753e1ee Mon Sep 17 00:00:00 2001 From: Ruan Chaves Rodrigues Date: Sat, 3 Jun 2023 16:05:06 +0200 Subject: [PATCH 09/11] feat: include new models on tutorial --- README.md | 35 ++++++++++++----------------------- hashformers.ipynb | 34 +++++++++++++++++++++++++--------- 2 files changed, 37 insertions(+), 32 deletions(-) diff --git a/README.md b/README.md index 844ea39..7eb1856 100644 --- a/README.md +++ b/README.md @@ -5,9 +5,9 @@ Hashtag segmentation is the task of automatically adding spaces between the words on a hashtag. -[Hashformers](https://github.com/ruanchaves/hashformers) is the current **state-of-the-art** for hashtag segmentation. On average, hashformers is **10% more accurate** than the second best hashtag segmentation library ( [Learn More](https://github.com/ruanchaves/hashformers/blob/master/tutorials/EVALUATION.md) ). +[Hashformers](https://github.com/ruanchaves/hashformers) is the current **state-of-the-art** for hashtag segmentation, as demonstrated on [this paper accepted at LREC 2022](https://aclanthology.org/2022.lrec-1.782.pdf). -Hashformers is also **language-agnostic**: you can use it to segment hashtags not just in English, but also in any language with a GPT-2 model on the [Hugging Face Model Hub](https://huggingface.co/models). +Hashformers is also **language-agnostic**: you can use it to segment hashtags not just with English models, but also using any language model available on the [Hugging Face Model Hub](https://huggingface.co/models).

@@ -26,7 +26,9 @@ from hashformers import TransformerWordSegmenter as WordSegmenter ws = WordSegmenter( segmenter_model_name_or_path="gpt2", - reranker_model_name_or_path="bert-base-uncased" + segmenter_model_type="incremental", + reranker_model_name_or_path="google/flan-t5-base", + reranker_model_type="seq2seq" ) segmentations = ws.segment([ @@ -40,28 +42,11 @@ print(segmentations) # 'ice cold' ] ``` -## Installation - -Hashformers is compatible with Python 3.7. - -``` -pip install hashformers -``` - -It is possible to use **hashformers** without a reranker: - -```python -from hashformers import TransformerWordSegmenter as WordSegmenter -ws = WordSegmenter( - segmenter_model_name_or_path="gpt2", - reranker_model_name_or_path=None -) -``` +It is also possible to use hashformers without a reranker by setting the `reranker_model_name_or_path` and the `reranker_model_type` to `None`. -If you want to use a BERT model as a reranker, you must install [mxnet](https://pypi.org/project/mxnet/). Here we install **hashformers** with `mxnet-cu110`, which is compatible with Hugging Face Spaces. If installing in another environment, replace it by the [mxnet package](https://pypi.org/project/mxnet/) compatible with your CUDA version. +## Installation ``` -pip install mxnet-cu110 pip install hashformers ``` @@ -81,6 +66,10 @@ pip install -e . This is a collection of papers that have utilized the *hashformers* library as a tool in their research. +### hashformers v1.3 + +These papers have utilized `hashformers` version 1.3 or below. + * [Zero-shot hashtag segmentation for multilingual sentiment analysis](https://arxiv.org/abs/2112.03213) * [HashSet -- A Dataset For Hashtag Segmentation (LREC 2022)](https://aclanthology.org/2022.lrec-1.782/) @@ -104,4 +93,4 @@ This is a collection of papers that have utilized the *hashformers* library as a archivePrefix={arXiv}, primaryClass={cs.CL} } -``` +``` \ No newline at end of file diff --git a/hashformers.ipynb b/hashformers.ipynb index 4ec93a8..ffd755d 100644 --- a/hashformers.ipynb +++ b/hashformers.ipynb @@ -29,10 +29,7 @@ "id": "geWaMgWXu1f5" }, "source": [ - "Here we install `mxnet-cu110` and `hashformers`. \n", - "\n", - "\n", - "**Deprecation Notice**: Support for `mxnet-cu110` has been discontinued on Google Colab. If you intend to execute cells involving the reranker, please consider using an alternative environment." + "Here we install `hashformers`. " ] }, { @@ -45,7 +42,6 @@ "source": [ "%%capture\n", "\n", - "!pip install mxnet-cu110 \n", "!pip install hashformers" ] }, @@ -81,6 +77,7 @@ "\n", "ws = WordSegmenter(\n", " segmenter_model_name_or_path=\"distilgpt2\",\n", + " segmenter_model_type=\"incremental\",\n", " reranker_model_name_or_path=None\n", ")" ] @@ -137,9 +134,22 @@ "id": "1F0rTjzQWY6q" }, "source": [ - "You can use **hashformers** to segment hashtags in any language, not just English. Visit the [HuggingFace Model Hub](https://huggingface.co/models) and choose any GPT-2 and a BERT models for the WordSegmenter class.\n", + "## What models can I use?\n", + "\n", + "You can use **hashformers** to segment hashtags in any language, not just English. Visit the [HuggingFace Model Hub](https://huggingface.co/models) and choose your language model for the `WordSegmenter` class.\n", + "\n", + "You can use any language model supported by the [minicons](https://github.com/kanishkamisra/minicons) library. Currently it supports the following model types:\n", + "\n", + "* Auto-regressive models like GPT-2 and XLNet. 
You can load them using the model type `incremental`. \n", "\n", - "The GPT-2 model should be informed as `segmenter_model_name_or_path` and the BERT model as `reranker_model_name_or_path`. A segmenter is required, however a reranker is optional. " + "* Masked language models like BERT. Their model type is `masked`.\n", + "\n", + "* Seq2seq models like T5. Their model type is `seq2seq`.\n", + "\n", + "\n", + "Best results are usually achieved by using an auto-regressive model like GPT-2 as the `segmenter_model_name_or_path` and a BERT-like or seq2seq model as the `reranker_model_name_or_path`. \n", + "\n", + "A segmenter is required, however a reranker is optional. " ] }, { @@ -156,7 +166,9 @@ "\n", "portuguese_ws = WordSegmenter(\n", " segmenter_model_name_or_path=\"pierreguillou/gpt2-small-portuguese\",\n", - " reranker_model_name_or_path=\"neuralmind/bert-base-portuguese-cased\"\n", + " segmenter_model_type=\"incremental\",\n", + " reranker_model_name_or_path=\"neuralmind/bert-base-portuguese-cased\",\n", + " segmenter_model_type=\"masked\"\n", ")" ] }, @@ -235,7 +247,9 @@ "\n", "ws = WordSegmenter(\n", " segmenter_model_name_or_path=\"distilgpt2\",\n", + " segmenter_model_type=\"incremental\",\n", " reranker_model_name_or_path=\"distilbert-base-uncased\",\n", + " reranker_model_type=\"masked\",\n", " segmenter_gpu_batch_size=1,\n", " reranker_gpu_batch_size=2000\n", ")" @@ -1121,7 +1135,9 @@ "\n", "ws = TransformerWordSegmenter(\n", " segmenter_model_name_or_path=\"distilgpt2\",\n", - " reranker_model_name_or_path=None\n", + " segmenter_model_type=\"incremental\",\n", + " reranker_model_name_or_path=None,\n", + " reranker_model_type=None,\n", ")\n", "\n", "def generate_experiments(datasets, splits, samples=100):\n", From b3d16dcec8f0239c908d069bc82fc7e21e79a6bb Mon Sep 17 00:00:00 2001 From: Ruan Chaves Rodrigues Date: Sat, 3 Jun 2023 16:12:33 +0200 Subject: [PATCH 10/11] feat: add information to readme --- README.md | 32 +++++++++++++++++++++++++++++++- hashformers.ipynb | 26 ++++++++++++++++++++------ 2 files changed, 51 insertions(+), 7 deletions(-) diff --git a/README.md b/README.md index 7eb1856..95ee625 100644 --- a/README.md +++ b/README.md @@ -18,7 +18,6 @@ Hashformers is also **language-agnostic**: you can use it to segment hashtags no

- ## Basic usage ```python @@ -50,6 +49,37 @@ It is also possible to use hashformers without a reranker by setting the `rerank pip install hashformers ``` +## What models can I use? + +Visit the [HuggingFace Model Hub](https://huggingface.co/models) and choose your models for the `WordSegmenter` class. + +You can use any model supported by the [minicons](https://github.com/kanishkamisra/minicons) library. Currently `hashformers` supports the following model types as the `segmenter_model_type` or `reranker_model_type`: + +### `incremental` + +Auto-regressive models like GPT-2 and XLNet, or any model that can be loaded with `AutoModelForCausalLM`. This includes large language models (LLMs) such as Alpaca-LoRA ( `chainyo/alpaca-lora-7b` ) and GPT-J ( `EleutherAI/gpt-j-6b` ). + +```python +ws = WordSegmenter( + segmenter_model_name_or_path="EleutherAI/gpt-j-6b", + segmenter_model_type="incremental", + reranker_model_name_or_path=None, + reranker_model_type=None +) +``` + +### `masked` + +Masked language models like BERT, or any model that can be loaded with `AutoModelForMaskedLM`. + +### `seq2seq` + +Seq2Seq models like FLAN-T5 ( `google/flan-t5-base` ), or any model that can be loaded with `AutoModelForSeq2SeqLM`. + +Best results are usually achieved by using an `incremental` model as the `segmenter_model_name_or_path` and a `masked` or `seq2seq` model as the `reranker_model_name_or_path`. + +A segmenter is always required, however a reranker is optional. + ## Contributing Pull requests are welcome! [Read our paper](https://arxiv.org/abs/2112.03213) for more details on the inner workings of our framework. diff --git a/hashformers.ipynb b/hashformers.ipynb index ffd755d..4da28a0 100644 --- a/hashformers.ipynb +++ b/hashformers.ipynb @@ -136,22 +136,36 @@ "source": [ "## What models can I use?\n", "\n", - "You can use **hashformers** to segment hashtags in any language, not just English. Visit the [HuggingFace Model Hub](https://huggingface.co/models) and choose your language model for the `WordSegmenter` class.\n", + "You can use **hashformers** to segment hashtags in any language, not just English. Visit the [HuggingFace Model Hub](https://huggingface.co/models) and choose your models for the `WordSegmenter` class.\n", "\n", - "You can use any language model supported by the [minicons](https://github.com/kanishkamisra/minicons) library. Currently it supports the following model types:\n", + "You can use any model supported by the [minicons](https://github.com/kanishkamisra/minicons) library. Currently `hashformers` supports the following model types as the `segmenter_model_type` or `reranker_model_type`:\n", "\n", - "* Auto-regressive models like GPT-2 and XLNet. You can load them using the model type `incremental`. \n", + "### `incremental`\n", "\n", - "* Masked language models like BERT. Their model type is `masked`.\n", + "Auto-regressive models like GPT-2 and XLNet, or any model that can be loaded with `AutoModelForCausalLM`. This includes recent large language models (LLMs) such as Alpaca-LoRA ( `chainyo/alpaca-lora-7b` ) and GPT-J ( `EleutherAI/gpt-j-6b` ).\n", "\n", - "* Seq2seq models like T5. Their model type is `seq2seq`.\n", + "### `masked`\n", "\n", + "Masked language models like BERT, or any model that can be loaded with `AutoModelForMaskedLM`.\n", "\n", - "Best results are usually achieved by using an auto-regressive model like GPT-2 as the `segmenter_model_name_or_path` and a BERT-like or seq2seq model as the `reranker_model_name_or_path`. 
\n", + "### `seq2seq`\n", + "\n", + "Seq2Seq models like FLAN-T5 ( `google/flan-t5-base` ), or any model that can be loaded with `AutoModelForSeq2SeqLM`.\n", + "\n", + "\n", + "Best results are usually achieved by using an `incremental` model as the `segmenter_model_name_or_path` and a `masked` or `seq2seq` model as the `reranker_model_name_or_path`. \n", "\n", "A segmenter is required, however a reranker is optional. " ] }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Here we segment hashtags in Portuguese using a GPT-2 model and a BERT model pretrained on Portuguese data." + ] + }, { "cell_type": "code", "execution_count": 6, From c63ccc4ef04754e199919cd9979e826e41e640da Mon Sep 17 00:00:00 2001 From: Ruan Chaves Rodrigues Date: Sat, 3 Jun 2023 16:15:14 +0200 Subject: [PATCH 11/11] feat: read the docs --- README.md | 2 ++ 1 file changed, 2 insertions(+) diff --git a/README.md b/README.md index 95ee625..51f4d8d 100644 --- a/README.md +++ b/README.md @@ -15,6 +15,8 @@ Hashformers is also **language-agnostic**: you can use it to segment hashtags no

✂️ Get started - Google Colab tutorial

+

✂️ Read the Docs

+
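
The sketch below is a reviewer's illustration of the minicons backend introduced by this patch series; it is not part of any patch. The model name `distilgpt2`, the CPU device, and the candidate strings are assumptions chosen only to keep the example small.

```python
# Illustration only: exercises MiniconsLM as added in PATCH 01/11 and
# refined in PATCHES 04-08. "distilgpt2" and device="cpu" are example
# choices, not values taken from the patches.
from hashformers.beamsearch.minicons_lm import MiniconsLM

scorer = MiniconsLM(
    model_name_or_path="distilgpt2",     # any causal LM from the Hugging Face Hub
    device="cpu",                        # the patches default to "cuda"
    gpu_batch_size=2,                    # DataLoader batch size used by get_probs
    model_type="IncrementalLMScorer",    # or "MaskedLMScorer" / "Seq2SeqScorer"
)

# get_probs batches the candidates through a DataLoader; for incremental
# models each score is 1 - sum(log p(token)), so more probable
# segmentations receive lower scores.
candidates = ["we love music", "welovemusic", "we lovemusic"]
print(scorer.get_probs(candidates))
```

These three scorer classes are the same ones that `ModelLM` in `model_lm.py` maps the user-facing `incremental`, `masked`, and `seq2seq` model types onto.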