From 9cfb252694acfcbff12cbeb683ea9e4a65ad03d0 Mon Sep 17 00:00:00 2001 From: Eliyar Eziz Date: Wed, 11 Dec 2019 13:27:15 +0800 Subject: [PATCH] :sparkles: Add Scoring task. (#303) * :sparkles: Add Scoring task #301. * :sparkles: Add evaluate function for scoring task. * :bug: Fixing numeric check function. * :sparkles: Add round function to scoring model evaluation. * :green_heart: Fixing CI Build. * :memo: Writing docs. * :green_heart: Fixing CI Build. --- .gitignore | 3 +- .travis.yml | 7 +- README.md | 3 +- kashgari/__init__.py | 1 + kashgari/embeddings/base_embedding.py | 6 +- kashgari/macros.py | 1 + kashgari/processors/__init__.py | 1 + kashgari/processors/scoring_processor.py | 100 ++++++++++ kashgari/tasks/base_model.py | 12 +- kashgari/tasks/scoring/__init__.py | 14 ++ kashgari/tasks/scoring/base_model.py | 92 +++++++++ kashgari/tasks/scoring/models.py | 57 ++++++ mkdocs/docs/index.md | 166 ---------------- mkdocs/docs/tutorial/text-scoring.md | 232 +++++++++++++++++++++++ mkdocs/mkdocs.yml | 1 + mkdocs/readme.md | 2 + tests/scoring/__init__.py | 12 ++ tests/scoring/test_bi_lstm_model.py | 59 ++++++ tests/test_processor.py | 25 ++- 19 files changed, 618 insertions(+), 176 deletions(-) create mode 100644 kashgari/processors/scoring_processor.py create mode 100644 kashgari/tasks/scoring/__init__.py create mode 100644 kashgari/tasks/scoring/base_model.py create mode 100644 kashgari/tasks/scoring/models.py delete mode 100644 mkdocs/docs/index.md create mode 100644 mkdocs/docs/tutorial/text-scoring.md create mode 100644 tests/scoring/__init__.py create mode 100644 tests/scoring/test_bi_lstm_model.py diff --git a/.gitignore b/.gitignore index 7e7958bd..298fe66a 100644 --- a/.gitignore +++ b/.gitignore @@ -110,4 +110,5 @@ venv.bak/ .vscode venv-tf/* .pytype/ -mkdocs/site \ No newline at end of file +mkdocs/site +node_modules diff --git a/.travis.yml b/.travis.yml index e2c33089..929f595b 100644 --- a/.travis.yml +++ b/.travis.yml @@ -5,6 +5,8 @@ env: global: - 
COVERALLS_PARALLEL=true matrix: + # Scoring + - TEST_FILE=tests/scoring # Labeling - TEST_FILE=tests/labeling/ # classification part 1 @@ -17,6 +19,7 @@ env: - TEST_FILE=tests/test_custom_multi_output_classification.py # Embedding - TEST_FILE=tests/embedding/ + # Tokenizer - TEST_FILE=tests/test_tokenizer.py python: @@ -44,6 +47,7 @@ install: - pip install nose - python -c "import kashgari;print(f'kashgari version {kashgari.__version__}')" - git fetch --unshallow --quiet + - export PYTHONPATH=`pwd` script: nosetests --with-coverage --cover-html --cover-html-dir=htmlcov --cover-xml --cover-xml-file=coverage.xml --with-xunit @@ -68,8 +72,9 @@ jobs: - stage: Document python: "3.6" install: - - echo -e "machine github.com\n login ${GITHUB_TOKEN}" > ~/.netrc + - echo -e "machine github.com\n login ${GITHUB_TOKEN}" > ~/.netrc - pip install mkdocs mkdocs-material pymdown-extensions script: + - cp README.md mkdocs/docs/index.md - cd mkdocs - mkdocs gh-deploy --force --clean diff --git a/README.md b/README.md index d1d73dc8..1e89d0fd 100644 --- a/README.md +++ b/README.md @@ -67,7 +67,8 @@ Here is a set of quick tutorials to get you started with the library: - [Tutorial 1: Text Classification](https://kashgari.bmio.net/tutorial/text-classification/) - [Tutorial 2: Text Labeling](https://kashgari.bmio.net/tutorial/text-labeling/) -- [Tutorial 3: Language Embedding](https://kashgari.bmio.net/embeddings/) +- [Tutorial 3: Text Scoring](https://kashgari.bmio.net/tutorial/text-scoring/) +- [Tutorial 4: Language Embedding](https://kashgari.bmio.net/embeddings/) There are also articles and posts that illustrate how to use Kashgari: diff --git a/kashgari/__init__.py b/kashgari/__init__.py index 6368e17f..1cd711a9 100644 --- a/kashgari/__init__.py +++ b/kashgari/__init__.py @@ -23,6 +23,7 @@ custom_objects = keras_bert.get_custom_objects() CLASSIFICATION = TaskType.CLASSIFICATION LABELING = TaskType.LABELING +SCORING = TaskType.SCORING from kashgari.version import __version__ diff 
--git a/kashgari/embeddings/base_embedding.py b/kashgari/embeddings/base_embedding.py index b3d3bdc1..effe0aa4 100644 --- a/kashgari/embeddings/base_embedding.py +++ b/kashgari/embeddings/base_embedding.py @@ -16,7 +16,7 @@ from tensorflow import keras import kashgari -from kashgari.processors import ClassificationProcessor, LabelingProcessor +from kashgari.processors import ClassificationProcessor, LabelingProcessor, ScoringProcessor from kashgari.processors.base_processor import BaseProcessor L = keras.layers @@ -74,8 +74,10 @@ def __init__(self, self.processor = ClassificationProcessor() elif task == kashgari.LABELING: self.processor = LabelingProcessor() + elif task == kashgari.SCORING: + self.processor = ScoringProcessor() else: - raise ValueError() + raise ValueError('Need to set the processor param, value: {labeling, classification, scoring}') else: self.processor = processor diff --git a/kashgari/macros.py b/kashgari/macros.py index 898f6977..9888297c 100644 --- a/kashgari/macros.py +++ b/kashgari/macros.py @@ -23,6 +23,7 @@ class TaskType(object): CLASSIFICATION = 'classification' LABELING = 'labeling' + SCORING = 'scoring' class Config(object): diff --git a/kashgari/processors/__init__.py b/kashgari/processors/__init__.py index fc0abc08..4f72d3e6 100644 --- a/kashgari/processors/__init__.py +++ b/kashgari/processors/__init__.py @@ -10,3 +10,4 @@ from kashgari.processors.classification_processor import ClassificationProcessor from kashgari.processors.labeling_processor import LabelingProcessor +from kashgari.processors.scoring_processor import ScoringProcessor diff --git a/kashgari/processors/scoring_processor.py b/kashgari/processors/scoring_processor.py new file mode 100644 index 00000000..899f5981 --- /dev/null +++ b/kashgari/processors/scoring_processor.py @@ -0,0 +1,100 @@ +# encoding: utf-8 + +# author: BrikerMan +# contact: eliyar917@gmail.com +# blog: https://eliyar.biz + +# file: scoring_processor.py +# time: 11:10 上午 + +from typing import List, 
Optional + +import numpy as np + +import kashgari +from kashgari import utils +from kashgari.processors.base_processor import BaseProcessor + + +def is_numeric(obj): + attrs = ['__add__', '__sub__', '__mul__', '__truediv__', '__pow__'] + return all(hasattr(obj, attr) for attr in attrs) + + +class ScoringProcessor(BaseProcessor): + """ + Corpus Pre Processor class + """ + + def __init__(self, output_dim=None, **kwargs): + super(ScoringProcessor, self).__init__(**kwargs) + self.output_dim = output_dim + + def info(self): + info = super(ScoringProcessor, self).info() + info['task'] = kashgari.SCORING + return info + + def _build_label_dict(self, + label_list: List[List[float]]): + """ + Build label2idx dict for sequence labeling task + + Args: + label_list: corpus label list + """ + if self.output_dim is None: + label_sample = label_list[0] + if isinstance(label_sample, np.ndarray) and len(label_sample.shape) == 1: + self.output_dim = label_sample.shape[0] + elif is_numeric(label_sample): + self.output_dim = 1 + elif isinstance(label_sample, list): + self.output_dim = len(label_sample) + else: + raise ValueError('Scoring Label Sample must be a float, float array or 1D numpy array') + # np_labels = np.array(label_list) + # if np_labels.max() > 1 or np_labels.min() < 0: + # raise ValueError('Scoring Label Sample must be in range[0,1]') + + def process_y_dataset(self, + data: List[List[str]], + max_len: Optional[int] = None, + subset: Optional[List[int]] = None) -> np.ndarray: + if subset is not None: + target = utils.get_list_subset(data, subset) + else: + target = data[:] + y = np.array(target) + return y + + def numerize_token_sequences(self, + sequences: List[List[str]]): + + result = [] + for seq in sequences: + if self.add_bos_eos: + seq = [self.token_bos] + seq + [self.token_eos] + unk_index = self.token2idx[self.token_unk] + result.append([self.token2idx.get(token, unk_index) for token in seq]) + return result + + def numerize_label_sequences(self, + sequences: 
List[List[str]]) -> List[List[int]]: + return sequences + + def reverse_numerize_label_sequences(self, + sequences, + lengths=None): + return sequences + + +if __name__ == "__main__": + from kashgari.corpus import SMP2018ECDTCorpus + + x, y = SMP2018ECDTCorpus.load_data() + x = x[:3] + y = [0.2, 0.3, 0.2] + p = ScoringProcessor() + p.analyze_corpus(x, y) + print(p.process_y_dataset(y)) diff --git a/kashgari/tasks/base_model.py b/kashgari/tasks/base_model.py index 64d5abf1..3234e693 100644 --- a/kashgari/tasks/base_model.py +++ b/kashgari/tasks/base_model.py @@ -414,12 +414,16 @@ def predict(self, lengths = [len(sen) for sen in x_data] tensor = self.embedding.process_x_dataset(x_data) pred = self.tf_model.predict(tensor, batch_size=batch_size, **predict_kwargs) - res = self.embedding.reverse_numerize_label_sequences(pred.argmax(-1), + if self.task == 'scoring': + t_pred = pred + else: + t_pred = pred.argmax(-1) + res = self.embedding.reverse_numerize_label_sequences(t_pred, lengths) if debug_info: - logging.info('input: {}'.format(tensor)) - logging.info('output: {}'.format(pred)) - logging.info('output argmax: {}'.format(pred.argmax(-1))) + print('input: {}'.format(tensor)) + print('output: {}'.format(pred)) + print('output argmax: {}'.format(t_pred)) return res def evaluate(self, diff --git a/kashgari/tasks/scoring/__init__.py b/kashgari/tasks/scoring/__init__.py new file mode 100644 index 00000000..9e64f574 --- /dev/null +++ b/kashgari/tasks/scoring/__init__.py @@ -0,0 +1,14 @@ +# encoding: utf-8 + +# author: BrikerMan +# contact: eliyar917@gmail.com +# blog: https://eliyar.biz + +# file: __init__.py +# time: 11:36 上午 + + +from kashgari.tasks.scoring.models import BiLSTM_Model + +if __name__ == "__main__": + pass diff --git a/kashgari/tasks/scoring/base_model.py b/kashgari/tasks/scoring/base_model.py new file mode 100644 index 00000000..fbe0b5d3 --- /dev/null +++ b/kashgari/tasks/scoring/base_model.py @@ -0,0 +1,92 @@ +# encoding: utf-8 + +# author: BrikerMan +# 
contact: eliyar917@gmail.com +# blog: https://eliyar.biz + +# file: base_model.py +# time: 11:36 上午 + + +from typing import Callable +from typing import Dict, Any + +import numpy as np +from sklearn import metrics + +from kashgari.tasks.base_model import BaseModel + + +class BaseScoringModel(BaseModel): + """Base Sequence Labeling Model""" + + __task__ = 'scoring' + + @classmethod + def get_default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: + raise NotImplementedError + + def compile_model(self, **kwargs): + if kwargs.get('loss') is None: + kwargs['loss'] = 'mse' + if kwargs.get('optimizer') is None: + kwargs['optimizer'] = 'rmsprop' + if kwargs.get('metrics') is None: + kwargs['metrics'] = ['mae'] + super(BaseScoringModel, self).compile_model(**kwargs) + + def evaluate(self, + x_data, + y_data, + batch_size=None, + should_round: bool = False, + round_func: Callable = None, + digits=4, + debug_info=False) -> Dict: + """ + Build a text report showing the main classification metrics. 
+ + Args: + x_data: + y_data: + batch_size: + should_round: + round_func: + digits: + debug_info: + + Returns: + + """ + y_pred = self.predict(x_data, batch_size=batch_size) + + if should_round: + if round_func is None: + round_func = np.round + print(self.processor.output_dim) + if self.processor.output_dim != 1: + raise ValueError('Evaluate with round function only accept 1D output') + y_pred = [round_func(i) for i in y_pred] + report = metrics.classification_report(y_data, + y_pred, + digits=digits) + + report_dic = metrics.classification_report(y_data, + y_pred, + output_dict=True, + digits=digits) + print(report) + else: + mean_squared_error = metrics.mean_squared_error(y_data, y_pred) + r2_score = metrics.r2_score(y_data, y_pred) + report_dic = { + 'mean_squared_error': mean_squared_error, + 'r2_score': r2_score + } + print(f"mean_squared_error : {mean_squared_error}\n" + f"r2_score : {r2_score}") + return report_dic + + +if __name__ == "__main__": + pass diff --git a/kashgari/tasks/scoring/models.py b/kashgari/tasks/scoring/models.py new file mode 100644 index 00000000..fcd8b398 --- /dev/null +++ b/kashgari/tasks/scoring/models.py @@ -0,0 +1,57 @@ +# encoding: utf-8 + +# author: BrikerMan +# contact: eliyar917@gmail.com +# blog: https://eliyar.biz + +# file: models.py +# time: 11:38 上午 + + +import logging +from typing import Dict, Any + +from tensorflow import keras + +from kashgari.tasks.scoring.base_model import BaseScoringModel +from kashgari.layers import L + + +class BiLSTM_Model(BaseScoringModel): + + @classmethod + def get_default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: + return { + 'layer_bi_lstm': { + 'units': 128, + 'return_sequences': False + }, + 'layer_dense': { + 'activation': 'linear' + } + } + + def build_model_arc(self): + output_dim = self.processor.output_dim + config = self.hyper_parameters + embed_model = self.embedding.embed_model + + layer_bi_lstm = L.Bidirectional(L.LSTM(**config['layer_bi_lstm'])) + layer_dense = 
L.Dense(output_dim, **config['layer_dense']) + + tensor = layer_bi_lstm(embed_model.output) + output_tensor = layer_dense(tensor) + + self.tf_model = keras.Model(embed_model.inputs, output_tensor) + + +if __name__ == "__main__": + from kashgari.corpus import SMP2018ECDTCorpus + import numpy as np + + x, y = SMP2018ECDTCorpus.load_data('valid') + y = np.random.random((len(x), 4)) + model = BiLSTM_Model() + model.fit(x, y) + print(model.predict(x[:10])) + diff --git a/mkdocs/docs/index.md b/mkdocs/docs/index.md deleted file mode 100644 index cea77c65..00000000 --- a/mkdocs/docs/index.md +++ /dev/null @@ -1,166 +0,0 @@ -

- Kashgari -

- -

- - GitHub - - - - - - Coverage Status - - - - - - PyPI - -

- -🎉🎉🎉 We are proud to announce that we entirely rewritten Kashgari with tf.keras, now Kashgari comes with easier to understand API and is faster! 🎉🎉🎉 - -## Overview - -Kashgari is a simple and powerful NLP Transfer learning framework, build a state-of-art model in 5 minutes for named entity recognition (NER), part-of-speech tagging (PoS), and text classification tasks. - -- **Human-friendly**. Kashgari's code is straightforward, well documented and tested, which makes it very easy to understand and modify. -- **Powerful and simple**. Kashgari allows you to apply state-of-the-art natural language processing (NLP) models to your text, such as named entity recognition (NER), part-of-speech tagging (PoS) and classification. -- **Built-in transfer learning**. Kashgari built-in pre-trained BERT and Word2vec embedding models, which makes it very simple to transfer learning to train your model. -- **Fully scalable**. Kashgari provide a simple, fast, and scalable environment for fast experimentation, train your models and experiment with new approaches using different embeddings and model structure. -- **Production Ready**. Kashgari could export model with `SavedModel` format for tensorflow serving, you could directly deploy it on cloud. - -## Our Goal - -- **Academic users** Easier Experimentation to prove their hypothesis without coding from scratch. -- **NLP beginners** Learn how to build an NLP project with production level code quality. -- **NLP developers** Build a production level classification/labeling model within minutes. 
- -## Performance - -| Task | Language | Dataset | Score | Detail | -| ------------------------ | -------- | ------------------------- | -------------- | ------------------------------------------------------------------------------------------------------------------ | -| Named Entity Recognition | Chinese | People's Daily Ner Corpus | **94.46** (F1) | [Text Labeling Performance Report](https://kashgari.bmio.net/tutorial/text-labeling/#performance-report) | - -## Tutorials - -Here is a set of quick tutorials to get you started with the library: - -- [Tutorial 1: Text Classification](https://kashgari.bmio.net/tutorial/text-classification/) -- [Tutorial 2: Text Labeling](https://kashgari.bmio.net/tutorial/text-labeling/) -- [Tutorial 3: Language Embedding](https://kashgari.bmio.net/embeddings/) - -There are also articles and posts that illustrate how to use Kashgari: - -- [15 分钟搭建中文文本分类模型](https://eliyar.biz/nlp_chinese_text_classification_in_15mins/) -- [基于 BERT 的中文命名实体识别(NER)](https://eliyar.biz/nlp_chinese_bert_ner/) -- [BERT/ERNIE 文本分类和部署](https://eliyar.biz/nlp_train_and_deploy_bert_text_classification/) -- [五分钟搭建一个基于BERT的NER模型](https://www.jianshu.com/p/1d6689851622) -- [Multi-Class Text Classification with Kashgari in 15 minutes](https://medium.com/@BrikerMan/multi-class-text-classification-with-kashgari-in-15mins-c3e744ce971d) - -## Quick start - -### Requirements and Installation - -🎉🎉🎉 We renamed again for consistency and clarity. From now on, it is all `kashgari`. 🎉🎉🎉 - -The project is based on Python 3.6+, because it is 2019 and type hinting is cool. 
- -| Backend | pypi version | desc | -| ---------------- | -------------------------------------- | --------------- | -| TensorFlow 2.x | `pip install 'kashgari>=2.0.0'` | coming soon | -| TensorFlow 1.14+ | `pip install 'kashgari>=1.0.0,<2.0.0'` | current version | -| Keras | `pip install 'kashgari<1.0.0'` | legacy version | - -[Find more info about the name changing.](https://github.com/BrikerMan/Kashgari/releases/tag/v1.0.0) - -### Example Usage - -lets run a NER labeling model with Bi_LSTM Model. - -```python -from kashgari.corpus import ChineseDailyNerCorpus -from kashgari.tasks.labeling import BiLSTM_Model - -train_x, train_y = ChineseDailyNerCorpus.load_data('train') -test_x, test_y = ChineseDailyNerCorpus.load_data('test') -valid_x, valid_y = ChineseDailyNerCorpus.load_data('valid') - -model = BiLSTM_Model() -model.fit(train_x, train_y, valid_x, valid_y, epochs=50) - -""" -_________________________________________________________________ -Layer (type) Output Shape Param # -================================================================= -input (InputLayer) (None, 97) 0 -_________________________________________________________________ -layer_embedding (Embedding) (None, 97, 100) 320600 -_________________________________________________________________ -layer_blstm (Bidirectional) (None, 97, 256) 235520 -_________________________________________________________________ -layer_dropout (Dropout) (None, 97, 256) 0 -_________________________________________________________________ -layer_time_distributed (Time (None, 97, 8) 2056 -_________________________________________________________________ -activation_7 (Activation) (None, 97, 8) 0 -================================================================= -Total params: 558,176 -Trainable params: 558,176 -Non-trainable params: 0 -_________________________________________________________________ -Train on 20864 samples, validate on 2318 samples -Epoch 1/50 -20864/20864 [==============================] - 9s 
417us/sample - loss: 0.2508 - acc: 0.9333 - val_loss: 0.1240 - val_acc: 0.9607 - -""" -``` - -### Run with GPT-2 Embedding - -```python -import kashgari -from kashgari.embeddings import GPT2Embedding -from kashgari.corpus import ChineseDailyNerCorpus -from kashgari.tasks.labeling import BiGRU_Model - -train_x, train_y = ChineseDailyNerCorpus.load_data('train') -valid_x, valid_y = ChineseDailyNerCorpus.load_data('valid') - -gpt2_embedding = GPT2Embedding('', - task=kashgari.LABELING, - sequence_length=30) -model = BiGRU_Model(gpt2_embedding) -model.fit(train_x, train_y, valid_x, valid_y, epochs=50) -``` - -### Run with Bert Embedding - -```python -import kashgari -from kashgari.embeddings import BERTEmbedding -from kashgari.tasks.labeling import BiGRU_Model -from kashgari.corpus import ChineseDailyNerCorpus - -bert_embedding = BERTEmbedding('', - task=kashgari.LABELING, - sequence_length=30) -model = BiGRU_Model(bert_embedding) - -train_x, train_y = ChineseDailyNerCorpus.load_data() -model.fit(train_x, train_y) -``` - -## Contributing - -Thanks for your interest in contributing! There are many ways to get involved; start with the [contributor guidelines](https://kashgari.bmio.net/about/contributing/) and then check these open issues for specific tasks. - -## Reference - -This library is inspired by and references following frameworks and papers. 
- - [flair - A very simple framework for state-of-the-art Natural Language Processing (NLP)](https://github.com/zalandoresearch/flair) -- [anago - Bidirectional LSTM-CRF and ELMo for Named-Entity Recognition, Part-of-Speech Tagging](https://github.com/Hironsan/anago) -- [Chinese-Word-Vectors](https://github.com/Embedding/Chinese-Word-Vectors) diff --git a/mkdocs/docs/tutorial/text-scoring.md b/mkdocs/docs/tutorial/text-scoring.md new file mode 100644 index 00000000..d3c9774c --- /dev/null +++ b/mkdocs/docs/tutorial/text-scoring.md @@ -0,0 +1,232 @@ +# Text Scoring Model + +Kashgari provides several models for text scoring, which could be used for sentiment analysis tasks. Model input is text and output is a continuous float value. +All scoring models inherit from the `BaseScoringModel`. +You could easily switch from one model to another just by changing one line of code. + +## Available Models + +| Name | info | +| --------------------- | ---- | +| BiLSTM\_Model | | + +## Train basic scoring model + + + +```python +# Load built-in corpus. +from kashgari.corpus import SMP2018ECDTCorpus + +# Sample x is tokenized text, y is float value +train_x = [['Hello', 'world'], ['Hello', 'Kashgari'], ['I', 'hate', 'you']] +train_y = [5.0, 5.0, 1.2] + +valid_x, valid_y = train_x, train_y +test_x, test_y = train_x, train_y +``` + +Then train our first model. All models provide the same APIs, so you could use any scoring model here. 
+ +```python +import kashgari +from kashgari.tasks.scoring import BiLSTM_Model + +import logging +logging.basicConfig(level='DEBUG') + +model = BiLSTM_Model() +model.fit(train_x, train_y, valid_x, valid_y) + +# Evaluate the model +model.evaluate(test_x, test_y) + +# Evaluate the model with round function +model.evaluate(test_x, test_y, should_round=True) + +# Model data will be saved to the `saved_scoring_model` folder +model.save('saved_scoring_model') + +# Load saved model +loaded_model = kashgari.utils.load_model('saved_scoring_model') +loaded_model.predict(test_x[:10]) + +# To continue training, compile the newly loaded model first +loaded_model.compile_model() +loaded_model.fit(train_x, train_y, valid_x, valid_y) +``` + +That's all you need to do. Easy, right? + +## Text scoring with transfer learning + +Kashgari provides various language model embeddings for transfer learning. Here is an example with BERT Embedding. + +```python +import kashgari +from kashgari.tasks.scoring import BiGRU_Model +from kashgari.embeddings import BERTEmbedding + +import logging +logging.basicConfig(level='DEBUG') + +bert_embed = BERTEmbedding('', + task=kashgari.SCORING, + sequence_length=100) +model = BiGRU_Model(bert_embed) +model.fit(train_x, train_y, valid_x, valid_y) +``` + +You could replace `bert_embed` with any Embedding class in `kashgari.embeddings`. More info about Embedding: LINK THIS. + +## Adjust model's hyper-parameters + +You could easily change a model's hyper-parameters. For example, we change the LSTM units in `BiLSTM_Model` from 128 to 32. + +```python +from kashgari.tasks.scoring import BiLSTM_Model + +hyper = BiLSTM_Model.get_default_hyper_parameters() +print(hyper) +# {'layer_bi_lstm': {'units': 128, 'return_sequences': False}, 'layer_dense': {'activation': 'linear'}} + +hyper['layer_bi_lstm']['units'] = 32 + +model = BiLSTM_Model(hyper_parameters=hyper) +``` + +## Use custom optimizer + +Kashgari already supports using a customized optimizer, like RAdam. 
+ +```python +from kashgari.corpus import SMP2018ECDTCorpus +from kashgari.tasks.scoring import BiLSTM_Model +# Remember to import kashgari before RAdam +from keras_radam import RAdam + +model = BiLSTM_Model() +# This step will build token dict, label dict and model structure +model.build_model(train_x, train_y, valid_x, valid_y) +# Compile model with custom optimizer, you can also customize loss and metrics. +optimizer = RAdam() +model.compile_model(optimizer=optimizer) + +# Train model +model.fit(train_x, train_y, valid_x, valid_y) +``` + +## Use callbacks + +Kashgari is based on tf.keras, so you can use all of the [tf.keras callbacks](https://www.tensorflow.org/api_docs/python/tf/keras/callbacks) directly with +Kashgari models. For example, here is how to visualize training with TensorBoard. + +```python +from tensorflow.python import keras +from kashgari.tasks.scoring import BiGRU_Model +from kashgari.callbacks import EvalCallBack + +import logging +logging.basicConfig(level='DEBUG') + +model = BiGRU_Model() + +tf_board_callback = keras.callbacks.TensorBoard(log_dir='./logs', update_freq=1000) + +model.fit(train_x, + train_y, + valid_x, + valid_y, + batch_size=100, + callbacks=[tf_board_callback]) +``` + +## Customize your own model + +It is very easy and straightforward to build your own customized model: +just inherit from `BaseScoringModel` and implement the `get_default_hyper_parameters()` and `build_model_arc()` functions. 
+ +```python +from typing import Dict, Any + +from tensorflow import keras + +from kashgari.tasks.scoring.base_model import BaseScoringModel +from kashgari.layers import L + +import logging +logging.basicConfig(level='DEBUG') + + +class DoubleBiLSTMModel(BaseScoringModel): + """Bidirectional LSTM Scoring Model""" + + @classmethod + def get_default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: + """ + Get hyper parameters of model + Returns: + hyper parameters dict + """ + return { + 'layer_blstm1': { + 'units': 128, + 'return_sequences': True + }, + 'layer_blstm2': { + 'units': 128, + 'return_sequences': False + }, + 'layer_dropout': { + 'rate': 0.4 + }, + 'layer_time_distributed': {}, + 'layer_activation': { + 'activation': 'linear' + } + } + + def build_model_arc(self): + """ + build model architecture + """ + output_dim = self.processor.output_dim + config = self.hyper_parameters + embed_model = self.embedding.embed_model + + # Define your layers + layer_blstm1 = L.Bidirectional(L.LSTM(**config['layer_blstm1']), + name='layer_blstm1') + layer_blstm2 = L.Bidirectional(L.LSTM(**config['layer_blstm2']), + name='layer_blstm2') + + layer_dropout = L.Dropout(**config['layer_dropout'], + name='layer_dropout') + + layer_time_distributed = L.TimeDistributed(L.Dense(output_dim, + **config['layer_time_distributed']), + name='layer_time_distributed') + layer_activation = L.Activation(**config['layer_activation']) + + # Define tensor flow + tensor = layer_blstm1(embed_model.output) + tensor = layer_blstm2(tensor) + tensor = layer_dropout(tensor) + tensor = layer_time_distributed(tensor) + output_tensor = layer_activation(tensor) + + # Init model + self.tf_model = keras.Model(embed_model.inputs, output_tensor) + +model = DoubleBiLSTMModel() +model.fit(train_x, train_y, valid_x, valid_y) +``` + +## Speed up with CuDNN cell + +You can speed up the training and inference process using [CuDNN 
cell](https://stackoverflow.com/questions/46767001/what-is-cudnn-implementation-of-rnn-cells-in-tensorflow). CuDNNLSTM and CuDNNGRU layers are much faster than LSTM and GRU layer, but they must be used on GPU. If you want to train on GPU and inferencing on CPU, you cannot use CuDNN cells. + +```python +# Enable use cudnn cell +kashgari.config.use_cudnn_cell = True +``` diff --git a/mkdocs/mkdocs.yml b/mkdocs/mkdocs.yml index 16f70db3..241d83ff 100644 --- a/mkdocs/mkdocs.yml +++ b/mkdocs/mkdocs.yml @@ -69,6 +69,7 @@ nav: - Tutorials: - tutorial/text-classification.md - tutorial/text-labeling.md + - tutorial/text-scoring.md - Embeddings: - embeddings/index.md diff --git a/mkdocs/readme.md b/mkdocs/readme.md index e646237e..89398bfb 100644 --- a/mkdocs/readme.md +++ b/mkdocs/readme.md @@ -3,7 +3,9 @@ Run with command. ```bash +cd mkdocs pip install mkdocs mkdocs-material pymdown-extensions +cp ../README.md ./docs/index.md mkdocs serve ``` diff --git a/tests/scoring/__init__.py b/tests/scoring/__init__.py new file mode 100644 index 00000000..8d3c6103 --- /dev/null +++ b/tests/scoring/__init__.py @@ -0,0 +1,12 @@ +# encoding: utf-8 + +# author: BrikerMan +# contact: eliyar917@gmail.com +# blog: https://eliyar.biz + +# file: __init__.py +# time: 10:50 上午 + + +if __name__ == "__main__": + pass diff --git a/tests/scoring/test_bi_lstm_model.py b/tests/scoring/test_bi_lstm_model.py new file mode 100644 index 00000000..8b351bb0 --- /dev/null +++ b/tests/scoring/test_bi_lstm_model.py @@ -0,0 +1,59 @@ +# encoding: utf-8 + +# author: BrikerMan +# contact: eliyar917@gmail.com +# blog: https://eliyar.biz + +# file: blstm_model.py +# time: 12:17 下午 +import os +import tempfile +import time +import unittest +import kashgari +import numpy as np + +from tests.corpus import NERCorpus +from kashgari.tasks.scoring import BiLSTM_Model + + +class TestBiLSTM_Model(unittest.TestCase): + @classmethod + def setUpClass(cls): + cls.model_class = BiLSTM_Model + + def test_basic_use_build(self): + 
x, _ = NERCorpus.load_corpus() + y = np.random.random((len(x),)) + model = self.model_class() + model.fit(x, y, epochs=1) + res = model.predict(x[:20]) + model_path = os.path.join(tempfile.gettempdir(), str(time.time())) + model.save(model_path) + + pd_model_path = os.path.join(tempfile.gettempdir(), str(time.time())) + kashgari.utils.convert_to_saved_model(model, + pd_model_path) + + new_model = kashgari.utils.load_model(model_path) + new_res = new_model.predict(x[:20]) + assert np.array_equal(new_res, res) + + new_model.compile_model() + model.fit(x, y, x, y, epochs=1) + model.evaluate(x, y) + + rounded_y = np.round(y) + model.evaluate(x, rounded_y, should_round=True) + + def test_multi_output(self): + x, _ = NERCorpus.load_corpus() + y = np.random.random((len(x), 4)) + model = self.model_class() + model.fit(x, y, x, y, epochs=1) + with self.assertRaises(ValueError): + model.evaluate(x, y, should_round=True) + + +if __name__ == "__main__": + pass diff --git a/tests/test_processor.py b/tests/test_processor.py index 09a4e014..b62dd960 100644 --- a/tests/test_processor.py +++ b/tests/test_processor.py @@ -12,8 +12,9 @@ import tempfile import unittest import numpy as np +import random from kashgari import utils -from kashgari.processors import ClassificationProcessor, LabelingProcessor +from kashgari.processors import ClassificationProcessor, LabelingProcessor, ScoringProcessor from kashgari.corpus import SMP2018ECDTCorpus, ChineseDailyNerCorpus from kashgari.tasks.classification import BiGRU_Model @@ -120,5 +121,27 @@ def test_load(self): assert np.array_equal(process_x_0, process_x_1) +class TestScoringProcessor(unittest.TestCase): + + def test_build_dict(self): + x = sample_train_x + y = [random.random() for _ in range(len(x))] + p = ScoringProcessor() + p.analyze_corpus(x, y) + assert p.output_dim == 1 + + y = [[random.random(), random.random(), random.random()] for _ in range(len(x))] + p = ScoringProcessor() + p.analyze_corpus(x, y) + + assert p.output_dim == 3 
+ + y = np.random.random((len(x), 3)) + p = ScoringProcessor() + p.analyze_corpus(x, y) + print(p.output_dim) + assert p.output_dim == 3 + + if __name__ == "__main__": print("Hello world")