diff --git a/DOTS/feat.py b/DOTS/feat.py index 5ae8a3d..ba30eb0 100644 --- a/DOTS/feat.py +++ b/DOTS/feat.py @@ -35,18 +35,18 @@ # # Load models and tokenizers -# model_name = "distilroberta-base" -# model = AutoModel.from_pretrained(model_name) -# tokenizer = AutoTokenizer.from_pretrained(model_name) -# # !python -m spacy download en_core_web_sm -# nlp = spacy.load('en_core_web_sm') +model_name = "distilroberta-base" +model = AutoModel.from_pretrained(model_name) +tokenizer = AutoTokenizer.from_pretrained(model_name) +# !python -m spacy download en_core_web_sm +nlp = spacy.load('en_core_web_sm') # # Define constants -# n_gram_range = (1, 2) -# stop_words = "english" -# embeddings=[] -# device = torch.device("cuda" if torch.cuda.is_available() else "cpu") -# model.to(device) +n_gram_range = (1, 2) +stop_words = "english" +embeddings=[] +device = torch.device("cuda" if torch.cuda.is_available() else "cpu") +model.to(device) # Define functions def chunk_text(text, max_len):