# word_freq.py — compute per-token occurrence counts for bert-base-uncased over the lm1b training split.
import os
from transformers import BertTokenizer, RobertaTokenizer
import torch
from dataloader import DiffusionLoader
import numpy as np
import diffusion_word_freq
import math
from tqdm import tqdm
# tokenizer = RobertaTokenizer.from_pretrained('roberta-base')
tokenizer = BertTokenizer.from_pretrained('bert-base-uncased')

# Load the lm1b training split; each element exposes an 'input_ids'
# sequence of token ids (flat per example, as the original iterated it).
train_data = DiffusionLoader(tokenizer=tokenizer).my_load(task_name='lm1b', splits=['train'])[0]

# Per-token occurrence counts over the whole corpus, indexed by token id.
word_freq = torch.zeros((tokenizer.vocab_size,), dtype=torch.int64)
for data in tqdm(train_data):
    # Vectorized counting: one bincount per example instead of a Python-level
    # loop incrementing word_freq[iid] token by token — same totals, much faster.
    ids = torch.as_tensor(data['input_ids'], dtype=torch.int64)
    word_freq += torch.bincount(ids, minlength=tokenizer.vocab_size)

# exist_ok=True avoids the check-then-create race of the original
# `if not exists: mkdir` pair, and uses the same './word_freq' path the
# save below writes into.
os.makedirs('./word_freq', exist_ok=True)
# Plain string literal: the original used an f-string with no placeholders.
torch.save(word_freq, './word_freq/bert-base-uncased_lm1b.pt')