Merge pull request #1 from Hoiy/v0.2
v0.2.1
Hoiy authored Feb 9, 2019
2 parents bcaf504 + 2642833 commit 42bedc5
Showing 21 changed files with 1,760 additions and 878 deletions.
43 changes: 24 additions & 19 deletions README.md
@@ -14,30 +14,35 @@ berserker.load_model() # An one-off installation
berserker.tokenize('姑姑想過過過兒過過的生活。') # ['姑姑', '想', '過', '過', '過兒', '過過', '的', '生活', '。']
```

## Training
Berserker is fine-tuned on TPU from the [pretrained Chinese BERT model](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip). A single dense layer is applied to every token output to produce a sequence of values in [0, 1].

## Benchmark
A quick test shows that Berserker achieved state-of-the-art F1 measure on the [SIGHAN 2005](http://sighan.cs.uchicago.edu/bakeoff2005/) [dataset](http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip).

The results below were obtained by training for 15 epochs on each dataset with a batch size of 64 and a cutoff threshold of 0.1.

| | PKU | CITYU | MSR |
|--------------------|----------|----------|----------|
| Liu et al. (2016) | **96.8** | -- | 97.3 |
| Yang et al. (2017) | 96.3 | 96.9 | 97.5 |
| Zhou et al. (2017) | 96.0 | -- | 97.8 |
| Cai et al. (2017) | 95.8 | 95.6 | 97.1 |
| Chen et al. (2017) | 94.3 | 95.6 | 96.0 |
| Wang and Xu (2017) | 96.5 | -- | 98.0 |
| Ma et al. (2018) | 96.1 | **97.2** | 98.1 |
|--------------------|----------|----------|----------|
| Berserker | 96.6 | 97.1 | **98.4** |
The table below shows that Berserker achieved state-of-the-art F1 measure on the [SIGHAN 2005](http://sighan.cs.uchicago.edu/bakeoff2005/) [dataset](http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip).

The results below were obtained by training for 15 epochs on each dataset with a batch size of 64.

| | PKU | CITYU | MSR | AS |
|--------------------|----------|----------|----------|----------|
| Liu et al. (2016) | **96.8** | -- | 97.3 | -- |
| Yang et al. (2017) | 96.3 | 96.9 | 97.5 | 95.7 |
| Zhou et al. (2017) | 96.0 | -- | 97.8 | -- |
| Cai et al. (2017) | 95.8 | 95.6 | 97.1 | -- |
| Chen et al. (2017) | 94.3 | 95.6 | 96.0 | 94.6 |
| Wang and Xu (2017) | 96.5 | -- | 98.0 | -- |
| Ma et al. (2018) | 96.1 | **97.2** | 98.1 | 96.2 |
|--------------------|----------|----------|----------|----------|
| Berserker | 96.6 | 97.1 | **98.4** | **96.5** |

Reference: [Ji Ma, Kuzman Ganchev, David Weiss - State-of-the-art Chinese Word Segmentation with Bi-LSTMs](https://arxiv.org/pdf/1808.06511.pdf)

More to come...
## Limitation
Since Berserker ~~is muscular~~ is based on BERT, it has a large model size (~300MB) and runs slowly on CPU. Berserker is just a proof of concept of what can be achieved with BERT.

Currently the default model provided is trained on the [SIGHAN 2005](http://sighan.cs.uchicago.edu/bakeoff2005/) [PKU dataset](http://sighan.cs.uchicago.edu/bakeoff2005/data/icwb2-data.zip). We plan to release more pretrained models in the future.

## Architecture
Berserker is fine-tuned on TPU from the [pretrained Chinese BERT model](https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip). A single dense layer is applied to every token output to produce a sequence of values in [0, 1], where 1 denotes a split.
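
For illustration, here is a minimal sketch of that classification head, assuming TensorFlow/Keras and invented shapes; it is not the repository's actual training code, which lives under `trainer`:

```python
# Minimal sketch (assumptions: TensorFlow/Keras, illustrative shapes) of a single
# dense layer over BERT token outputs producing per-token split probabilities in [0, 1].
import numpy as np
import tensorflow as tf

seq_len, hidden_size = 512, 768

# Stand-in for the final hidden states of the pretrained Chinese BERT model.
token_states = tf.keras.Input(shape=(seq_len, hidden_size))

# One dense unit shared across all token positions; sigmoid keeps outputs in [0, 1].
split_probs = tf.keras.layers.Dense(1, activation="sigmoid")(token_states)
head = tf.keras.Model(token_states, split_probs)

dummy_states = np.random.rand(2, seq_len, hidden_size).astype("float32")
probs = head.predict(dummy_states)  # shape (2, 512, 1): one split probability per token
print(probs.shape)
```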

## Training
We provide the source code for training under the `trainer` subdirectory. Feel free to contact me if you need any help reproducing the results.

## Bonus Video
[<img src="https://img.youtube.com/vi/H_xmyvABZnE/maxres1.jpg" alt="Yachae!! BERSERKER!!"/>](https://www.youtube.com/watch?v=H_xmyvABZnE)
53 changes: 9 additions & 44 deletions berserker/__init__.py
@@ -1,18 +1,16 @@
from .utils import maybe_download_unzip
from .tokenization import FullTokenizer
from berserker.utils import maybe_download_unzip
from pathlib import Path
from .transform import text_to_bert_inputs, preprocess, postprocess
import tensorflow as tf
import numpy as np


_assets_path = Path(__file__).parent / 'assets'
ASSETS_PATH = str(Path(__file__).parent / 'assets')
_models_path = Path(__file__).parent / 'models'

from berserker.transform import batch_preprocess, batch_postprocess

MAX_SEQ_LENGTH = 512
SEQ_LENGTH = MAX_SEQ_LENGTH - 2

BATCH_SIZE = 8

def load_model(model_name=None, verbose=True, force_download=False):
maybe_download_unzip(
@@ -25,46 +23,13 @@ def load_model(model_name=None, verbose=True, force_download=False):

def tokenize(text):
load_model()
tokenizer = FullTokenizer(
vocab_file=_assets_path / 'vocab.txt',
do_lower_case=False
)

bert_inputs_lens = []
bert_inputs = []
temp = text
while len(temp) > 0:
bert_input = text_to_bert_inputs(temp[:SEQ_LENGTH], MAX_SEQ_LENGTH, tokenizer)
bert_inputs_lens.append(len(preprocess(temp[:SEQ_LENGTH], tokenizer)[0]))
bert_inputs.append(bert_input)
temp = temp[SEQ_LENGTH:]
texts = [text]
bert_inputs, mappings, sizes = batch_preprocess(texts, MAX_SEQ_LENGTH, BATCH_SIZE)

berserker = tf.contrib.predictor.from_saved_model(
str(_models_path / '1547563491')
)
output = berserker({
'input_ids': [bi[0] for bi in bert_inputs],
'input_mask': [bi[1] for bi in bert_inputs],
'segment_ids': [bi[2] for bi in bert_inputs],
'truths': [bi[3] for bi in bert_inputs]
})

results = output['predictions']

results_itr = iter(results)
bert_inputs_itr = iter(bert_inputs)
bert_inputs_lens_itr = iter(bert_inputs_lens)

prediction = np.array([])
bert_tokens = []
temp = text

while len(temp) > 0:
(input_ids, _, _, _) = next(bert_inputs_itr)
bert_inputs_len = next(bert_inputs_lens_itr)
result = next(results_itr)
prediction = np.concatenate((prediction, result[1:1+bert_inputs_len]))
bert_tokens += tokenizer.convert_ids_to_tokens(input_ids[1:1+bert_inputs_len])
temp = temp[SEQ_LENGTH:]
bert_outputs = berserker(bert_inputs)
bert_outputs = [{'predictions': bo} for bo in bert_outputs['predictions']]

return postprocess(text, bert_tokens, prediction, threshold=0.5)
return batch_postprocess(texts, mappings, sizes, bert_inputs, bert_outputs, MAX_SEQ_LENGTH)[0]
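
Conceptually, the postprocessing step turns per-character split probabilities into words by cutting wherever the probability exceeds the threshold. A toy illustration with invented probabilities follows; it is not the actual `postprocess`/`batch_postprocess` implementation:

```python
# Toy illustration (probabilities invented) of thresholding per-character split
# probabilities into words; not the project's postprocess/batch_postprocess code.
text = "姑姑想過過過兒過過的生活"
# Probability that a word boundary follows each character.
split_probs = [0.02, 0.93, 0.88, 0.76, 0.81, 0.04, 0.90, 0.03, 0.95, 0.89, 0.05, 0.97]
threshold = 0.5

words, current = [], ""
for char, prob in zip(text, split_probs):
    current += char
    if prob > threshold:  # cut after this character
        words.append(current)
        current = ""
if current:
    words.append(current)

print(words)  # ['姑姑', '想', '過', '過', '過兒', '過過', '的', '生活']
```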
File renamed without changes.