-
-
Notifications
You must be signed in to change notification settings - Fork 439
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
Showing
12 changed files
with
326 additions
and
8 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,95 @@ | ||
# encoding: utf-8 | ||
|
||
# author: BrikerMan | ||
# contact: eliyar917@gmail.com | ||
# blog: https://eliyar.biz | ||
|
||
# file: scoring_processor.py | ||
# time: 11:10 上午 | ||
|
||
from typing import List, Optional | ||
|
||
import numpy as np | ||
|
||
import kashgari | ||
from kashgari import utils | ||
from kashgari.processors.base_processor import BaseProcessor | ||
|
||
|
||
class ScoringProcessor(BaseProcessor): | ||
""" | ||
Corpus Pre Processor class | ||
""" | ||
|
||
def __init__(self, output_dim=None, **kwargs): | ||
super(ScoringProcessor, self).__init__(**kwargs) | ||
self.output_dim = output_dim | ||
|
||
def info(self): | ||
info = super(ScoringProcessor, self).info() | ||
info['task'] = kashgari.SCORING | ||
return info | ||
|
||
def _build_label_dict(self, | ||
label_list: List[List[float]]): | ||
""" | ||
Build label2idx dict for sequence labeling task | ||
Args: | ||
label_list: corpus label list | ||
""" | ||
if self.output_dim is None: | ||
label_sample = label_list[0] | ||
if isinstance(label_sample, (float, int)): | ||
self.output_dim = 1 | ||
elif isinstance(label_sample, list): | ||
self.output_dim = len(label_sample) | ||
elif isinstance(label_sample, np.ndarray) and len(label_sample.shape) == 1: | ||
self.output_dim = label_sample.shape[0] | ||
else: | ||
raise ValueError('Scoring Label Sample must be a float, float array or 1D numpy array') | ||
# np_labels = np.array(label_list) | ||
# if np_labels.max() > 1 or np_labels.min() < 0: | ||
# raise ValueError('Scoring Label Sample must be in range[0,1]') | ||
|
||
def process_y_dataset(self, | ||
data: List[List[str]], | ||
max_len: Optional[int] = None, | ||
subset: Optional[List[int]] = None) -> np.ndarray: | ||
if subset is not None: | ||
target = utils.get_list_subset(data, subset) | ||
else: | ||
target = data[:] | ||
y = np.array(target) | ||
return y | ||
|
||
def numerize_token_sequences(self, | ||
sequences: List[List[str]]): | ||
|
||
result = [] | ||
for seq in sequences: | ||
if self.add_bos_eos: | ||
seq = [self.token_bos] + seq + [self.token_eos] | ||
unk_index = self.token2idx[self.token_unk] | ||
result.append([self.token2idx.get(token, unk_index) for token in seq]) | ||
return result | ||
|
||
def numerize_label_sequences(self, | ||
sequences: List[List[str]]) -> List[List[int]]: | ||
return sequences | ||
|
||
def reverse_numerize_label_sequences(self, | ||
sequences, | ||
lengths=None): | ||
return sequences | ||
|
||
|
||
if __name__ == "__main__": | ||
from kashgari.corpus import SMP2018ECDTCorpus | ||
|
||
x, y = SMP2018ECDTCorpus.load_data() | ||
x = x[:3] | ||
y = [0.2, 0.3, 0.2] | ||
p = ScoringProcessor() | ||
p.analyze_corpus(x, y) | ||
print(p.process_y_dataset(y)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# encoding: utf-8 | ||
|
||
# author: BrikerMan | ||
# contact: eliyar917@gmail.com | ||
# blog: https://eliyar.biz | ||
|
||
# file: __init__.py | ||
# time: 11:36 上午 | ||
|
||
|
||
from kashgari.tasks.scoring.models import BiLSTM_Model | ||
|
||
if __name__ == "__main__": | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# encoding: utf-8 | ||
|
||
# author: BrikerMan | ||
# contact: eliyar917@gmail.com | ||
# blog: https://eliyar.biz | ||
|
||
# file: base_model.py | ||
# time: 11:36 上午 | ||
|
||
|
||
from typing import Dict, Any, Tuple | ||
|
||
import random | ||
import logging | ||
from seqeval.metrics import classification_report | ||
from seqeval.metrics.sequence_labeling import get_entities | ||
|
||
from kashgari.tasks.base_model import BaseModel | ||
|
||
|
||
class BaseScoringModel(BaseModel): | ||
"""Base Sequence Labeling Model""" | ||
|
||
__task__ = 'scoring' | ||
|
||
@classmethod | ||
def get_default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: | ||
raise NotImplementedError | ||
|
||
def compile_model(self, **kwargs): | ||
if kwargs.get('loss') is None: | ||
kwargs['loss'] = 'mse' | ||
if kwargs.get('optimizer') is None: | ||
kwargs['optimizer'] = 'rmsprop' | ||
if kwargs.get('metrics') is None: | ||
kwargs['metrics'] = ['mae'] | ||
super(BaseScoringModel, self).compile_model(**kwargs) | ||
|
||
def evaluate(self, | ||
x_data, | ||
y_data, | ||
batch_size=None, | ||
digits=4, | ||
debug_info=False) -> Tuple[float, float, Dict]: | ||
""" | ||
Build a text report showing the main classification metrics. | ||
Args: | ||
x_data: | ||
y_data: | ||
batch_size: | ||
digits: | ||
debug_info: | ||
Returns: | ||
""" | ||
pass | ||
return {} | ||
|
||
|
||
if __name__ == "__main__": | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,57 @@ | ||
# encoding: utf-8 | ||
|
||
# author: BrikerMan | ||
# contact: eliyar917@gmail.com | ||
# blog: https://eliyar.biz | ||
|
||
# file: models.py | ||
# time: 11:38 上午 | ||
|
||
|
||
import logging | ||
from typing import Dict, Any | ||
|
||
from tensorflow import keras | ||
|
||
from kashgari.tasks.scoring.base_model import BaseScoringModel | ||
from kashgari.layers import L | ||
|
||
|
||
class BiLSTM_Model(BaseScoringModel): | ||
|
||
@classmethod | ||
def get_default_hyper_parameters(cls) -> Dict[str, Dict[str, Any]]: | ||
return { | ||
'layer_bi_lstm': { | ||
'units': 128, | ||
'return_sequences': False | ||
}, | ||
'layer_dense': { | ||
'activation': 'linear' | ||
} | ||
} | ||
|
||
def build_model_arc(self): | ||
output_dim = self.processor.output_dim | ||
config = self.hyper_parameters | ||
embed_model = self.embedding.embed_model | ||
|
||
layer_bi_lstm = L.Bidirectional(L.LSTM(**config['layer_bi_lstm'])) | ||
layer_dense = L.Dense(output_dim, **config['layer_dense']) | ||
|
||
tensor = layer_bi_lstm(embed_model.output) | ||
output_tensor = layer_dense(tensor) | ||
|
||
self.tf_model = keras.Model(embed_model.inputs, output_tensor) | ||
|
||
|
||
if __name__ == "__main__": | ||
from kashgari.corpus import SMP2018ECDTCorpus | ||
import numpy as np | ||
|
||
x, y = SMP2018ECDTCorpus.load_data('valid') | ||
y = np.random.random((len(x), 4)) | ||
model = BiLSTM_Model() | ||
model.fit(x, y) | ||
print(model.predict(x[:10])) | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
# encoding: utf-8 | ||
|
||
# author: BrikerMan | ||
# contact: eliyar917@gmail.com | ||
# blog: https://eliyar.biz | ||
|
||
# file: blstm_model.py | ||
# time: 12:17 下午 | ||
import os | ||
import tempfile | ||
import time | ||
import unittest | ||
import kashgari | ||
import numpy as np | ||
|
||
from tests.corpus import NERCorpus | ||
from kashgari.tasks.scoring import BiLSTM_Model | ||
|
||
|
||
class TestBiLSTM_Model(unittest.TestCase): | ||
@classmethod | ||
def setUpClass(cls): | ||
cls.model_class = BiLSTM_Model | ||
|
||
def test_basic_use_build(self): | ||
x, _ = NERCorpus.load_corpus() | ||
y = np.random.random((len(x), 4)) | ||
model = self.model_class() | ||
model.fit(x, y) | ||
res = model.predict(x[:20]) | ||
model_path = os.path.join(tempfile.gettempdir(), str(time.time())) | ||
model.save(model_path) | ||
|
||
pd_model_path = os.path.join(tempfile.gettempdir(), str(time.time())) | ||
kashgari.utils.convert_to_saved_model(model, | ||
pd_model_path) | ||
|
||
new_model = kashgari.utils.load_model(model_path) | ||
new_res = new_model.predict(x[:20]) | ||
assert np.array_equal(new_res, res) | ||
|
||
new_model.compile_model() | ||
model.fit(x, y, x, y, epochs=1) | ||
|
||
|
||
if __name__ == "__main__": | ||
pass |
Oops, something went wrong.