Revert the training code to the previous approach #30
krikit committed Jan 29, 2019
1 parent 2744ba5 commit 586fc01
Showing 2 changed files with 58 additions and 59 deletions.
111 changes: 54 additions & 57 deletions src/main/python/khaiii/train/trainer.py
@@ -19,15 +19,15 @@
import os
import pathlib
import pprint
from typing import Iterator, List, Tuple
from typing import List, Tuple

from tensorboardX import SummaryWriter
import torch
import torch.nn as nn
import torch.nn.functional as F
from tqdm import tqdm

from khaiii.train.dataset import PosDataset, PosSentTensor
from khaiii.train.dataset import PosDataset
from khaiii.train.evaluator import Evaluator
from khaiii.train.models import CnnModel
from khaiii.resource.resource import Resource
@@ -56,6 +56,7 @@ def __init__(self, cfg: Namespace):
self.evaler = Evaluator()
self._load_dataset()
if 'step' not in cfg.__dict__:
setattr(cfg, 'epoch', 0)
setattr(cfg, 'step', 0)
setattr(cfg, 'best_step', 0)
self.log_file = None # tab separated log file
@@ -66,6 +67,7 @@ def __init__(self, cfg: Namespace):
self.acc_words = []
self.f_scores = []
self.learning_rates = []
self.batch_sizes = []

@classmethod
def model_id(cls, cfg: Namespace) -> str:
@@ -86,6 +88,7 @@ def model_id(cls, cfg: Namespace) -> str:
'lrd{}'.format(cfg.lr_decay),
'bs{}'.format(cfg.batch_size),
'cs{}'.format(cfg.check_step),
'bg{}'.format(cfg.batch_grow),
]
return '.'.join(model_cfgs)
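As a quick illustration of the new 'bg' component, here is a minimal sketch of how the visible tail of the model ID is assembled. The values are the defaults from train.py at the bottom of this diff, and the real model_cfgs list also contains earlier fields that are not shown in this hunk:

from argparse import Namespace

# hypothetical values for illustration only; only the fields visible in this hunk are shown
cfg = Namespace(lr_decay=0.9, batch_size=500, check_step=10000, batch_grow=10000)
model_cfgs = [
    'lrd{}'.format(cfg.lr_decay),
    'bs{}'.format(cfg.batch_size),
    'cs{}'.format(cfg.check_step),
    'bg{}'.format(cfg.batch_grow),
]
print('.'.join(model_cfgs))  # lrd0.9.bs500.cs10000.bg10000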

@@ -149,35 +152,25 @@ def _restore_prev_train(self):
line = line.rstrip('\r\n')
if not line:
continue
(step, loss_train, loss_dev, acc_char, acc_word, f_score, learning_rate) \
= line.split('\t')
self.cfg.step = self.cfg.best_step = int(step) * self.cfg.check_step
(epoch, step, loss_train, loss_dev, acc_char, acc_word, f_score, learning_rate,
batch_size) = line.split('\t')
self.cfg.epoch = int(epoch) + 1
self.cfg.step = self.cfg.best_step = (int(step)+1) * self.cfg.check_step
self.loss_trains.append(float(loss_train))
self.loss_devs.append(float(loss_dev))
self.acc_chars.append(float(acc_char))
self.acc_words.append(float(acc_word))
self.f_scores.append(float(f_score))
self.learning_rates.append(float(learning_rate))
self.batch_sizes.append(int(batch_size))
if float(f_score) > f_score_best:
f_score_best = float(f_score)
best_idx = idx
logging.info('---- [%d] los(trn/dev): %.4f / %.4f, acc(chr/wrd): %.4f / %.4f, ' \
'f-score: %.4f, lr: %.8f ----', self.cfg.step // self.cfg.check_step,
self.loss_trains[best_idx], self.loss_devs[best_idx], self.acc_chars[best_idx],
self.acc_words[best_idx], self.f_scores[best_idx], self.learning_rates[-1])

@classmethod
def _inf_data_iterator(cls, dataset: PosDataset) -> Iterator[PosSentTensor]:
"""
A generator that iterates over the dataset indefinitely, yielding sentences.
Args:
dataset: the dataset
Yields:
PosSentTensor objects
"""
for _ in range(1000000):
for sent in dataset:
yield sent
logging.info('---- [%d|%d] los(trn/dev): %.4f / %.4f, acc(chr/wrd): %.4f / %.4f, ' \
'f-score: %.4f, lr: %.8f, bs: %d ----', self.cfg.epoch,
self.cfg.step // self.cfg.check_step, self.loss_trains[best_idx],
self.loss_devs[best_idx], self.acc_chars[best_idx], self.acc_words[best_idx],
self.f_scores[best_idx], self.learning_rates[-1], self.batch_sizes[-1])
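The restored _restore_prev_train now expects a leading epoch column and a trailing batch_size column in log.tsv, nine tab-separated fields in total. A minimal sketch of resuming from one such line, with made-up sample values and an assumed --check-step of 10000:

# sample log.tsv line: epoch, check index, train/dev loss, char/word accuracy, f-score, lr, batch size
line = '3\t41\t0.1234\t0.2345\t0.9512\t0.9023\t0.9107\t0.00072900\t504'
(epoch, step, loss_train, loss_dev, acc_char, acc_word, f_score, learning_rate,
 batch_size) = line.split('\t')
check_step = 10000                               # assumed --check-step value
resume_epoch = int(epoch) + 1                    # resume from the next epoch
resume_step = (int(step) + 1) * check_step       # convert check index back to steps
print(resume_epoch, resume_step, float(f_score), int(batch_size))  # 4 420000 0.9107 504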

def train(self):
"""
@@ -193,12 +186,10 @@ def train(self):
pathlib.Path(self.cfg.out_dir).mkdir(parents=True, exist_ok=True)
self.log_file = open('{}/log.tsv'.format(self.cfg.out_dir), 'at')
self.sum_wrt = SummaryWriter(self.cfg.out_dir)
check_start = (1 if self.cfg.step == 0 else (self.cfg.step // self.cfg.check_step + 1))
patience = self.cfg.patience
train_iter = self._inf_data_iterator(self.dataset_train)
for check_id in range(check_start, 1000000):
is_best = self._train_and_check(check_id, train_iter)
if is_best:
for _ in range(1000000):
has_best = self._train_epoch()
if has_best:
patience = self.cfg.patience
continue
if patience <= 0:
@@ -210,8 +201,8 @@

train_end = datetime.now()
train_elapsed = self._elapsed(train_end - train_begin)
logging.info('}}}} training end: %s, elapsed: %s, step: %dk }}}}',
self._dt_str(train_end), train_elapsed, self.cfg.step // 1000)
logging.info('}}}} training end: %s, elapsed: %s, epoch: %d, step: %dk }}}}',
self._dt_str(train_end), train_elapsed, self.cfg.epoch, self.cfg.step // 1000)
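The loop above replaces the fixed check-step loop with an epoch loop guarded by patience: each epoch that fails to produce a new best f-score costs one unit of patience (after reverting to the best checkpoint and decaying the learning rate), and training stops once patience runs out. A standalone sketch of that control flow, with the model-specific parts replaced by a random stub:

import random

PATIENCE_MAX = 10                  # assumed --patience value
patience = PATIENCE_MAX
best_f_score = 0.0

def train_one_epoch() -> bool:
    """Stand-in for Trainer._train_epoch(): True if this epoch set a new best f-score."""
    global best_f_score
    f_score = random.random()      # placeholder for the real dev-set evaluation
    if f_score > best_f_score:
        best_f_score = f_score
        return True
    return False

for _ in range(1000000):
    if train_one_epoch():
        patience = PATIENCE_MAX    # reset patience whenever the model improves
        continue
    if patience <= 0:
        break                      # patience exhausted: stop training
    # here the trainer reverts to the best checkpoint and decays the learning rate
    patience -= 1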

def _revert_to_best(self, is_decay_lr: bool):
"""
@@ -224,21 +215,21 @@ def _revert_to_best(self, is_decay_lr: bool):
self.cfg.learning_rate *= self.cfg.lr_decay
self._load_optim('{}/optim.state'.format(self.cfg.out_dir), self.cfg.learning_rate)

def _train_and_check(self, check_id: int, train_iter: Iterator[PosSentTensor]) -> bool:
def _train_epoch(self) -> bool:
"""
Run cfg.check_step training steps, then run an evaluation.
Args:
check_id: check ID
train_iter: training data iterator
Train one epoch; batches are accumulated in units of characters.
Returns:
whether a step with the best f-score was recorded
"""
start_step = self.cfg.step
has_best = False
loss_batch = torch.tensor(0.0) # pylint: disable=not-callable
batch_size = 0
num_in_batch = 0
batch_size = self.cfg.batch_size
if self.cfg.batch_grow > 0:
batch_size = self.cfg.batch_size + self.cfg.step // self.cfg.batch_grow
loss_trains = []
train_sents = tqdm(train_iter, '[{}]'.format(check_id), mininterval=1, ncols=100)
for train_sent in train_sents:
for train_sent in tqdm(self.dataset_train, 'EPOCH[{}]'.format(self.cfg.epoch),
mininterval=1, ncols=100):
train_labels, train_contexts = train_sent.to_tensor(self.cfg, self.rsc, True)
if torch.cuda.is_available():
train_labels = train_labels.cuda()
@@ -252,39 +243,41 @@ def _train_and_check(self, check_id: int, train_iter: Iterator[PosSentTensor]) -> bool:
loss_train.backward()
loss_trains.append(loss_train.item())
loss_batch += loss_train
batch_size += len(train_labels)
if batch_size < self.cfg.batch_size:
num_in_batch += len(train_labels)
if num_in_batch < batch_size:
continue

self.optimizer.step()
self.optimizer.zero_grad()
self.sum_wrt.add_scalar('loss-batch', loss_batch.item(), self.cfg.step)
self.cfg.step += 1
loss_batch = torch.tensor(0.0) # pylint: disable=not-callable
batch_size = 0
num_in_batch = 0

if (self.cfg.step - start_step) >= self.cfg.check_step:
train_sents.close()
break
if self.cfg.step % self.cfg.check_step == 0:
avg_loss_dev, acc_char, acc_word, f_score = self.evaluate()
has_best |= self._check(loss_trains, avg_loss_dev, acc_char, acc_word, f_score,
batch_size)
if self.cfg.batch_grow > 0 and self.cfg.step % self.cfg.batch_grow == 0:
batch_size = self.cfg.batch_size + self.cfg.step // self.cfg.batch_grow

avg_loss_dev, acc_char, acc_word, f_score = self.evaluate()
return self._check(check_id, loss_trains, avg_loss_dev, acc_char, acc_word, f_score)
self.cfg.epoch += 1
return has_best
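In _train_epoch a "batch" is accumulated by character count rather than by sentence count: per-sentence losses are summed until num_in_batch reaches the current batch_size, and only then does the optimizer step. With --batch-grow > 0 the threshold itself grows by one character every batch_grow optimizer steps. A minimal sketch of that schedule, assuming the defaults from train.py below:

def effective_batch_size(step, base=500, batch_grow=10000):
    """Character-count threshold used for gradient accumulation at a given step."""
    if batch_grow <= 0:
        return base
    return base + step // batch_grow

# with the assumed defaults the threshold grows by 1 every 10000 optimizer steps
for step in (0, 9999, 10000, 50000, 250000):
    print(step, effective_batch_size(step))
# 0 500, 9999 500, 10000 501, 50000 505, 250000 525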

def _check(self, check_id: int, loss_trains: List[float], avg_loss_dev: float, acc_char: float,
acc_word: float, f_score: float) -> bool:
def _check(self, loss_trains: List[float], avg_loss_dev: float, acc_char: float,
acc_word: float, f_score: float, batch_size: int) -> bool:
"""
A check performed every cfg.check_step steps.
Args:
check_id: check ID
loss_trains: list of per-batch losses on the train corpus
avg_loss_dev: average per-sentence loss on the dev corpus
acc_char: character (syllable) accuracy
acc_word: word (eojeol) accuracy
f_score: f-score
batch_size: batch size
Returns:
whether the current step achieved the best performance
"""
assert check_id == self.cfg.step // self.cfg.check_step
avg_loss_train = sum(loss_trains) / len(loss_trains)
loss_trains.clear()
self.loss_trains.append(avg_loss_train)
@@ -293,22 +286,26 @@ def _check(self, loss_trains: List[float], avg_loss_dev: float, acc_char: float,
self.acc_words.append(acc_word)
self.f_scores.append(f_score)
self.learning_rates.append(self.cfg.learning_rate)
self.batch_sizes.append(batch_size)
check_id = self.cfg.step // self.cfg.check_step - 1
is_best = self._is_best()
is_best_str = 'BEST' if is_best else '< {:.4f}'.format(max(self.f_scores))
tqdm.write(' [Los trn] [Los dev] [Acc chr] [Acc wrd] [F-score] [LR]')
tqdm.write(' {:9.4f} {:9.4f} {:9.4f} {:9.4f} {:9.4f} {:8} {:.8f}'.format( \
avg_loss_train, avg_loss_dev, acc_char, acc_word, f_score, is_best_str,
self.cfg.learning_rate))
print('{}\t{}\t{}\t{}\t{}\t{}\t{}'.format(check_id, avg_loss_train, avg_loss_dev,
acc_char, acc_word, f_score,
self.cfg.learning_rate), file=self.log_file)
tqdm.write(' [{:3d}|{:5d}] [Los trn] [Los dev] [Acc chr] [Acc wrd] [F-score]'
' [LR] [BS]'.format(self.cfg.epoch, check_id))
tqdm.write(' {:9.4f} {:9.4f} {:9.4f} {:9.4f} {:9.4f} {:8} {:.8f} {}' \
.format(avg_loss_train, avg_loss_dev, acc_char, acc_word, f_score, is_best_str,
self.cfg.learning_rate, batch_size))
print('{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}\t{}'.format( \
self.cfg.epoch, check_id, avg_loss_train, avg_loss_dev, acc_char, acc_word, f_score,
self.cfg.learning_rate, batch_size), file=self.log_file)
self.log_file.flush()
self.sum_wrt.add_scalar('loss-train', avg_loss_train, check_id)
self.sum_wrt.add_scalar('loss-dev', avg_loss_dev, check_id)
self.sum_wrt.add_scalar('acc-char', acc_char, check_id)
self.sum_wrt.add_scalar('acc-word', acc_word, check_id)
self.sum_wrt.add_scalar('f-score', f_score, check_id)
self.sum_wrt.add_scalar('learning-rate', self.cfg.learning_rate, check_id)
self.sum_wrt.add_scalar('batch-size', batch_size, check_id)
return is_best
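Besides the tab-separated log file, each check also mirrors its metrics into tensorboardX. A minimal sketch of that pattern, with made-up values and an assumed log directory:

from tensorboardX import SummaryWriter

sum_wrt = SummaryWriter('logdir.example')    # assumed output directory
check_id = 41                                # check index, i.e. step // check_step - 1
sum_wrt.add_scalar('loss-train', 0.1234, check_id)
sum_wrt.add_scalar('f-score', 0.9107, check_id)
sum_wrt.add_scalar('batch-size', 504, check_id)
sum_wrt.close()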

def _is_best(self) -> bool:
6 changes: 4 additions & 2 deletions train/train.py
@@ -57,10 +57,12 @@ def main():
type=float, default=0.001)
parser.add_argument('--lr-decay', help='learning rate decay <default: 0.9>', metavar='REAL',
type=float, default=0.9)
parser.add_argument('--batch-size', help='batch size <default: 1000>', metavar='INT', type=int,
default=1000)
parser.add_argument('--batch-size', help='batch size <default: 500>', metavar='INT', type=int,
default=500)
parser.add_argument('--check-step', help='check every N step <default: 10000>', metavar='INT',
type=int, default=10000)
parser.add_argument('--batch-grow', help='grow batch size by 1 per N step <default: 10000>',
metavar='INT', type=int, default=10000)
parser.add_argument('--patience', help='maximum patience count to revert model <default: 10>',
metavar='INT', type=int, default=10)
parser.add_argument('--gpu-num', help='GPU number to use <default: 0>', metavar='INT', type=int,
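With the changed defaults, training starts from a 500-character batch threshold and --batch-grow raises it by one character every 10000 steps. A small sketch wiring just these two options through argparse (the other options are omitted, and the 123456-step example is arbitrary):

from argparse import ArgumentParser

parser = ArgumentParser(description='batch-size options only (sketch)')
parser.add_argument('--batch-size', help='batch size <default: 500>', metavar='INT',
                    type=int, default=500)
parser.add_argument('--batch-grow', help='grow batch size by 1 per N step <default: 10000>',
                    metavar='INT', type=int, default=10000)
args = parser.parse_args([])                           # use the defaults for illustration
print(args.batch_size + 123456 // args.batch_grow)     # threshold after 123456 steps -> 512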

