Implement develop version of create_lmdata (#65)
lovit committed Oct 12, 2020
1 parent f4e9e30 commit cb97fe9
Showing 2 changed files with 208 additions and 6 deletions.
Korpora/cli.py: 4 additions & 1 deletion

@@ -44,8 +44,11 @@ def main():
     # create language model train data
     parser_lmdata = subparsers.add_parser('lmdata', help='Create language model train data')
     parser_lmdata.add_argument('--corpus', type=str, required=True, nargs='+', help='corpus names')
+    parser_lmdata.add_argument('--root_dir', type=str, default=None, help='path/to/Korpora')
     parser_lmdata.add_argument('--output_dir', type=str, required=True, help='output file path')
-    parser_lambda.add_argument('--save_each', dest='save_each', action='store_true', help='store each corpus as a file')
+    parser_lmdata.add_argument('--force_download', dest='force_download', action='store_true')
+    parser_lmdata.add_argument('--multilingual', dest='multilingual', action='store_true', help='if set, include foreign-language text in the train data')
+    parser_lmdata.add_argument('--save_each', dest='save_each', action='store_true', help='store each corpus as a separate file')
     parser_lmdata.set_defaults(func=create_lmdata)

     # Do task
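With these options in place, the new subcommand can be exercised from the shell. A hypothetical invocation (it assumes the `korpora` console script that wraps `main()`; the corpus names and paths are illustrative):

    korpora lmdata --corpus nsmc kcbert --output_dir ./lmdata --save_each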
Korpora/task_lmdata.py: 204 additions & 5 deletions
@@ -1,6 +1,205 @@
import os
from tqdm import tqdm

from .loader import Korpora
from .utils import default_korpora_path


def create_lmdata(args):
-    args.corpus
-    args.output_dir
-    args.save_each
-    # TODO
-    raise NotImplementedError()
    corpus_names = check_corpus(args.corpus)
    os.makedirs(os.path.abspath(args.output_dir), exist_ok=True)

    root_dir = args.root_dir
    if root_dir is None:
        root_dir = default_korpora_path
    force_download = args.force_download
    multilingual = args.multilingual

    # one status row per corpus: [done mark, corpus name, num sents, file name]
    status = [['', name, ' - ', ''] for name in corpus_names]

    for i_corpus, name in enumerate(corpus_names):
        # append to a single train file unless each corpus gets its own file
        if not args.save_each and i_corpus > 0:
            mode = 'a'
        else:
            mode = 'w'

        filename = f'{name}.train' if args.save_each else 'all.train'
        lmdata_path = f'{args.output_dir}/{filename}'

        sent_iterator = tqdm(
            ITERATE_TEXTS[name](root_dir, force_download, multilingual),
            desc=f'Create train data from {name}'
        )
        print_status(status)

        with open(lmdata_path, mode, encoding='utf-8') as f:
            for i_sent, sent in enumerate(sent_iterator):
                f.write(f'{sent}\n')
        status[i_corpus][0] = ' x '
        status[i_corpus][2] = i_sent + 1
        status[i_corpus][3] = filename
    print_status(status)
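# Editor's note on the resulting layout (illustrative): with save_each set and
# corpus=['nsmc', 'kcbert'], output_dir receives nsmc.train and kcbert.train;
# without save_each, every corpus is appended into a single all.train.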


def check_corpus(corpus_names):
    if (corpus_names == 'all') or (corpus_names[0] == 'all'):
        corpus_names = list(ITERATE_TEXTS)
    if isinstance(corpus_names, str):
        corpus_names = [corpus_names]
    available = []
    for name in corpus_names:
        if name not in ITERATE_TEXTS:
            print(f'Korpora does not provide the {name} corpus. Check the `corpus` argument')
            continue
        available.append(name)
    if not available:
        raise ValueError('No available corpus name was found. Check the `corpus` argument')
    return available
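# e.g. check_corpus(['all']) expands to every key of ITERATE_TEXTS, while
# check_corpus(['nsmc', 'typo_name']) returns ['nsmc'] and prints a warning
# for the unknown name (hypothetical inputs).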


def print_status(status):
    max_len = max(max(len(row[3]) for row in status), 9)
    form = '| {:4} | {:25} | {:10} | {} |'
    print('\n\n' + form.format('Done', 'Corpus name', 'Num sents', 'File name' + ' ' * (max_len - 9)))
    print(form.format('-' * 4, '-' * 25, '-' * 10, '-' * max_len))
    for finish, name, num_sent, filename in status:
        if not filename:
            filename = ' ' * max_len
        else:
            filename += ' ' * (max_len - len(filename))
        print(form.format(finish, name, num_sent, filename))
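# Illustrative output of print_status (editor's example; corpus and count are hypothetical):
# | Done | Corpus name               | Num sents  | File name |
# | ---- | ------------------------- | ---------- | --------- |
# |  x   | nsmc                      |     200000 | all.train |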


def iterate_kcbert(root_dir, force_download, multilingual=False):
    Korpora.fetch('kcbert', root_dir, force_download)
    with open(f'{root_dir}/kcbert/20190101_20200611_v2.txt', encoding='utf-8') as f:
        # for line in f:
        for i, line in enumerate(f):  # DEVELOP
            if i >= 1000: break  # DEVELOP
            line = line.strip()
            if not line:
                continue
            yield line


def iterate_korean_chatbot_data(root_dir, force_download, multilingual=False):
    corpus = Korpora.load('korean_chatbot_data', root_dir, force_download)
    for sents in [corpus.train.texts, corpus.train.pairs]:
        for sent in sents:
            if not sent:
                continue
            yield sent


def iterate_korean_hate_speech(root_dir, force_download, multilingual=False):
    corpus = Korpora.load('korean_hate_speech', root_dir, force_download)
    for sents in [corpus.train.texts, corpus.dev.texts, corpus.unlabeled.texts]:
        for sent in sents:
            yield sent


def iterate_korean_parallel_koen_news(root_dir, force_download, multilingual):
    corpus = Korpora.load('korean_parallel_koen_news', root_dir, force_download)
    data = [corpus.train.texts, corpus.dev.texts, corpus.test.texts]
    if multilingual:
        data += [corpus.train.pairs, corpus.dev.pairs, corpus.test.pairs]
    for sents in data:
        for sent in sents:
            yield sent


def iterate_korean_petitions(root_dir, force_download, multilingual=False):
    corpus = Korpora.load('korean_petitions', root_dir, force_download)
    for example in corpus.train:
        yield example.title
        yield example.text


def iterate_kornli(root_dir, force_download, multilingual=False):
    corpus = Korpora.load('kornli', root_dir, force_download)
    for data in [corpus.multinli_train, corpus.snli_train, corpus.xnli_dev, corpus.xnli_test]:
        for sent in data.texts:
            yield sent
        for sent in data.pairs:
            yield sent


def iterate_korsts(root_dir, force_download, multilingual=False):
    corpus = Korpora.load('korsts', root_dir, force_download)
    for data in [corpus.train, corpus.dev, corpus.test]:
        for sent in data.texts:
            yield sent
        for sent in data.pairs:
            yield sent


def iterate_kowikitext(root_dir, force_download, multilingual=False):
    Korpora.fetch('kowikitext', root_dir, force_download)
    paths = [
        f'{root_dir}/kowiki/kowikitext_20200920.train',
        f'{root_dir}/kowiki/kowikitext_20200920.dev',
        f'{root_dir}/kowiki/kowikitext_20200920.test'
    ]
    for path in paths:
        with open(path, encoding='utf-8') as f:
            # for line in f:
            for i, line in enumerate(f):  # DEVELOP
                if i >= 1000: break  # DEVELOP
                line = line.strip()
                if not line or (line[0] == '=' and line[-1] == '='):
                    continue
                yield line


def iterate_namuwikitext(root_dir, force_download, multilingual=False):
    Korpora.fetch('namuwikitext', root_dir, force_download)
    paths = [
        f'{root_dir}/namuwiki/namuwikitext_20200302.train',
        f'{root_dir}/namuwiki/namuwikitext_20200302.dev',
        f'{root_dir}/namuwiki/namuwikitext_20200302.test'
    ]
    for path in paths:
        with open(path, encoding='utf-8') as f:
            # for line in f:
            for i, line in enumerate(f):  # DEVELOP
                if i >= 1000: break  # DEVELOP
                line = line.strip()
                if not line or (line[0] == '=' and line[-1] == '='):
                    continue
                yield line


def iterate_naver_changwon_ner(root_dir, force_download, multilingual=False):
    corpus = Korpora.load('naver_changwon_ner', root_dir, force_download)
    for sent in corpus.train.texts:
        yield sent


def iterate_nsmc(root_dir, force_download, multilingual=False):
    corpus = Korpora.load('nsmc', root_dir, force_download)
    for sents in [corpus.train.texts, corpus.test.texts]:
        for sent in sents:
            yield sent


def iterate_question_pair(root_dir, force_download, multilingual=False):
    corpus = Korpora.load('question_pair', root_dir, force_download)
    for sents in [corpus.train.texts, corpus.train.pairs]:
        for sent in sents:
            yield sent


# Maps every corpus name accepted by `--corpus` to its sentence iterator;
# check_corpus() validates requested names against this registry.
ITERATE_TEXTS = {
    'kcbert': iterate_kcbert,
    'korean_chatbot_data': iterate_korean_chatbot_data,
    'korean_hate_speech': iterate_korean_hate_speech,
    'korean_parallel_koen_news': iterate_korean_parallel_koen_news,
    'korean_petitions': iterate_korean_petitions,
    'kornli': iterate_kornli,
    'korsts': iterate_korsts,
    'kowikitext': iterate_kowikitext,
    'namuwikitext': iterate_namuwikitext,
    'naver_changwon_ner': iterate_naver_changwon_ner,
    'nsmc': iterate_nsmc,
    'question_pair': iterate_question_pair
}
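Outside the CLI, the same task can be driven directly. A minimal sketch based only on what this diff defines; `SimpleNamespace` stands in for the parsed `argparse` namespace, and the corpus choices and output path are illustrative:

    from types import SimpleNamespace

    from Korpora.task_lmdata import create_lmdata

    # Mirror the attributes the `lmdata` subparser sets on the parsed args.
    args = SimpleNamespace(
        corpus=['nsmc', 'korean_petitions'],  # any keys of ITERATE_TEXTS, or ['all']
        root_dir=None,              # None falls back to default_korpora_path
        output_dir='./lmdata',
        force_download=False,
        multilingual=False,
        save_each=True,             # write one <name>.train file per corpus
    )
    create_lmdata(args)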
