[Refactor] Upgrade the ark-nlp code structure #80

Merged
merged 42 commits into from Sep 15, 2022

Commits (42)
60cb70d
refactor(task): rename variables and update comments
Jul 31, 2022
4c164ad
refactor(*): update dataset and task comments, remove redundant code from pure
jimme0421 Aug 1, 2022
0de1601
fix(model): fix computation issue of pointer-based models during evaluation
xiangking Aug 5, 2022
2e0b79a
refactor(*): revise most of the comments
jimme0421 Aug 5, 2022
fa3ebe1
refactor(task): rename methods and parameters in the Task class
xiangking Aug 5, 2022
dbfe935
refactor(task): standardize bio_named_entity_recognition and prompt_masked_languag…
xiangking Aug 5, 2022
389388d
refactor(task): rename input_length in evaluate_logs to sequence_length
xiangking Aug 5, 2022
a75bf27
refactor(task): sync changes
Aug 8, 2022
6827340
refactor(*): fix task variable naming to follow the conventions
xiangking Aug 8, 2022
35b487e
refactor(*): standardize code variable naming
xiangking Aug 10, 2022
1215e67
refactor(test): adjust formatting to the standard form
xiangking Aug 11, 2022
23d46d3
refactor(task): restructure methods
xiangking Aug 11, 2022
6a11d3c
refactor(task): ongoing code structure changes
xiangking Aug 12, 2022
72de15c
refactor(task): migrate all tasks currently in the library to the new form
xiangking Aug 12, 2022
356658a
refactor(task): ongoing code structure changes
xiangking Aug 12, 2022
f0752a0
refactor(task): standardize the casrel task format
xiangking Aug 13, 2022
c0a38ca
refactor(task): ongoing code structure changes
Aug 14, 2022
3697dcf
feature(early_stopping): add early_stopping, currently only supports global_point
jimme0421 Aug 15, 2022
2b56dec
Revert "feature(early_stopping): add early_stopping, currently only supports global_point"
jimme0421 Aug 16, 2022
7f7946f
refactor(dataset): update the dataset format and add a processing progress display
Aug 17, 2022
7d746f7
Merge branch 'refactor' of github.com:xiangking/ark-nlp into refactor
Aug 17, 2022
b30bea3
fix(logs): fix accumulated evaluation metrics not being reset to zero during training
Aug 17, 2022
1bd096e
feat(task): ongoing updates - add step-level evaluation and saving by best metric
xiangking Aug 24, 2022
71c3f0a
feat(*): use a handler to control the global process and state
xiangking Aug 25, 2022
325c1a9
feat(*): add metric logic and restructure TCTask to fit it
xiangking Aug 26, 2022
bacb723
refactor(task): restructure tasks according to the new metric logic
xiangking Aug 26, 2022
6e0de67
feat(crf task): update the crf task code structure
xiangking Aug 27, 2022
90fca74
refactor(span task): update the span task code structure
xiangking Aug 27, 2022
db9d910
refactor(global pointer task): update the global pointer task code structure
xiangking Aug 27, 2022
0f4f695
refactor(task): update the w2ner and pure code structure
jimme0421 Aug 29, 2022
1f4b861
fix(*): resolve merge conflicts
xiangking Aug 29, 2022
44af93e
refactor(casrel task): update the casrel task code format
xiangking Aug 29, 2022
d950885
refactor(simcse task): update the simcse task code structure
xiangking Aug 29, 2022
cec8238
refactor(prgc task): update the prgc task code structure
xiangking Aug 29, 2022
cfe60a4
refactor(attack): update the adversarial attack code structure
xiangking Aug 30, 2022
4b0e5ba
refactor(uie task): update the uie task code structure
xiangking Aug 31, 2022
139b93d
fix(w2ner): fix bugs introduced by the w2ner update
jimme0421 Sep 1, 2022
be16cff
feat(simcse): add a supervised version
xiangking Sep 4, 2022
9b95215
feat(early_stopping): add early_stopping
jimme0421 Sep 6, 2022
71b552c
feat(*): add the Sentence-BERT (sbert) model
jimme0421 Sep 8, 2022
9af9939
feat(*): add tensorboard and wandb logging
Sep 12, 2022
a33d310
Merge branch 'refactor' of github.com:xiangking/ark-nlp into refactor
Sep 12, 2022
110 changes: 56 additions & 54 deletions ark_nlp/dataset/base/_dataset.py
@@ -15,7 +15,6 @@
 # Author: Xiang Wang, xiangking1995@163.com
 # Status: Active
 
-
 import json
 import copy
 import codecs
@@ -31,28 +30,30 @@ class BaseDataset(Dataset):
     Base dataset class
 
     Args:
-        data (:obj:`DataFrame` or :obj:`string`): the data or a path to the data
-        categories (:obj:`list`, optional, defaults to `None`): the data categories
-        is_retain_df (:obj:`bool`, optional, defaults to False): whether to copy the raw DataFrame-format data into the attribute retain_df
-        is_retain_dataset (:obj:`bool`, optional, defaults to False): whether to copy the raw data converted into dataset format into the attribute retain_dataset
-        is_train (:obj:`bool`, optional, defaults to True): whether the dataset is training data
-        is_test (:obj:`bool`, optional, defaults to False): whether the dataset is test data
+        data (DataFrame or string): the data or a path to the data
+        categories (list or None, optional, defaults to `None`): the data categories
+        do_retain_df (bool, optional, defaults to False): whether to copy the raw DataFrame-format data into the attribute retain_df
+        do_retain_dataset (bool, optional, defaults to False): whether to copy the raw data converted into dataset format into the attribute retain_dataset
+        is_train (bool, optional, defaults to True): whether the dataset is training data
+        is_test (bool, optional, defaults to False): whether the dataset is test data
+        progress_verbose (bool, optional): whether to display processing progress, defaults to True
     """ # noqa: ignore flake8"
 
-    def __init__(
-        self,
-        data,
-        categories=None,
-        is_retain_df=False,
-        is_retain_dataset=False,
-        is_train=True,
-        is_test=False
-    ):
+    def __init__(self,
+                 data,
+                 categories=None,
+                 do_retain_df=False,
+                 do_retain_dataset=False,
+                 is_train=True,
+                 is_test=False,
+                 progress_verbose=True):
 
         self.is_test = is_test
         self.is_train = is_train
-        self.is_retain_df = is_retain_df
-        self.is_retain_dataset = is_retain_dataset
+        self.do_retain_df = do_retain_df
+        self.do_retain_dataset = do_retain_dataset
+
+        self.progress_verbose = progress_verbose
 
         if self.is_test is True:
             self.is_train = False
@@ -61,7 +62,7 @@ def __init__(
         if 'label' in data.columns:
             data['label'] = data['label'].apply(lambda x: str(x))
 
-        if self.is_retain_df:
+        if self.do_retain_df:
             self.df = data
 
         self.dataset = self._convert_to_dataset(data)
@@ -87,32 +88,27 @@ def _load_dataset(self, data_path):
         Load the dataset
 
         Args:
-            data_path (:obj:`string`): path to the data
+            data_path (string): path to the data
         """ # noqa: ignore flake8"
 
         data_df = self._read_data(data_path)
 
-        if self.is_retain_df:
+        if self.do_retain_df:
             self.df = data_df
 
         return self._convert_to_dataset(data_df)
 
     def _convert_to_dataset(self, data_df):
         pass
 
-    def _read_data(
-        self,
-        data_path,
-        data_format=None,
-        skiprows=-1
-    ):
+    def _read_data(self, data_path, data_format=None, skiprows=-1):
         """
         Read the required data
 
         Args:
-            data_path (:obj:`string`): path to the data
-            data_format (:obj:`string`, defaults to `None`): storage format of the data
-            skiprows (:obj:`int`, defaults to -1): number of rows to skip when reading, -1 means no rows are skipped
+            data_path (string): path to the data
+            data_format (string or None, optional): storage format of the data, defaults to None
+            skiprows (int, optional): number of rows to skip when reading, defaults to -1, i.e. no rows are skipped
         """ # noqa: ignore flake8"
 
         if data_format is None:
@@ -123,7 +119,7 @@ def _read_data(
         elif data_format == 'json':
             try:
                 data_df = pd.read_json(data_path, dtype={'label': str})
-            except:
+            except Exception:
                 data_df = self.read_line_json(data_path)
         elif data_format == 'tsv':
             data_df = pd.read_csv(data_path, sep='\t', dtype={'label': str})
@@ -134,17 +130,13 @@
 
         return data_df
 
-    def read_line_json(
-        self,
-        data_path,
-        skiprows=-1
-    ):
+    def read_line_json(self, data_path, skiprows=-1):
         """
         Read the required data
 
         Args:
-            data_path (:obj:`string`): path where the data is located
-            skiprows (:obj:`int`, defaults to -1): number of rows to skip when reading, -1 means no rows are skipped
+            data_path (string): path to the data
+            skiprows (int, optional): number of rows to skip when reading, defaults to -1, i.e. no rows are skipped
         """
         datasets = []
 
@@ -153,10 +145,7 @@ def read_line_json(
             for index, line in enumerate(reader):
                 if index == skiprows:
                     continue
-                line = json.loads(line)
-                tokens = line['text']
-                label = line['label']
-                datasets.append({'text': tokens.strip(), 'label': label})
+                datasets.append(json.loads(line))
 
         return pd.DataFrame(datasets)
 
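Note for reviewers: `read_line_json` now appends each parsed JSON object as-is, so every field in a line survives into the DataFrame instead of only a stripped `text` plus `label`. A minimal sketch of the behavioral difference (the extra `source` field is made up for illustration):

```python
import json

import pandas as pd

lines = ['{"text": " 天气不错 ", "label": "1", "source": "weibo"}']

# Old behavior: keep only a stripped 'text' and 'label'; extra fields are dropped.
old_rows = [{'text': json.loads(line)['text'].strip(),
             'label': json.loads(line)['label']} for line in lines]

# New behavior: keep the whole record, extra fields included, text left untouched.
new_rows = [json.loads(line) for line in lines]

print(pd.DataFrame(old_rows).columns.tolist())  # ['text', 'label']
print(pd.DataFrame(new_rows).columns.tolist())  # ['text', 'label', 'source']
```

One consequence: callers that relied on the implicit `.strip()` now have to normalize whitespace themselves.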
@@ -170,18 +159,18 @@ def convert_to_ids(self, tokenizer):
         if tokenizer.tokenizer_type == 'vanilla':
             features = self._convert_to_vanilla_ids(tokenizer)
         elif tokenizer.tokenizer_type == 'transformer':
-            features = self._convert_to_transfomer_ids(tokenizer)
+            features = self._convert_to_transformer_ids(tokenizer)
         elif tokenizer.tokenizer_type == 'customized':
             features = self._convert_to_customized_ids(tokenizer)
         else:
             raise ValueError("The tokenizer type does not exist")
 
-        if self.is_retain_dataset:
+        if self.do_retain_dataset:
             self.retain_dataset = copy.deepcopy(self.dataset)
 
         self.dataset = features
 
-    def _convert_to_transfomer_ids(self, bert_tokenizer):
+    def _convert_to_transformer_ids(self, bert_tokenizer):
         pass
 
     def _convert_to_vanilla_ids(self, vanilla_tokenizer):
@@ -190,7 +179,7 @@ def _convert_to_vanilla_ids(self, vanilla_tokenizer):
     def _convert_to_customized_ids(self, customized_tokenizer):
         pass
 
-    def _get_input_length(self, text, bert_tokenizer):
+    def _get_sequence_length(self, text, bert_tokenizer):
         pass
 
     @property
@@ -206,17 +195,30 @@ def sample_num(self):
         return len(self.dataset)
 
     @property
-    def dataset_analysis(self):
+    def dataset_report(self):
 
-        result = defaultdict(list)
-        for row in self.dataset:
-            for col_name in self.dataset_cols:
-                if type(row[col_name]) == str:
-                    result[col_name].append(len(row[col_name]))
+        _result = defaultdict(list)
+        for _row in self.dataset:
+            for _col in self.dataset_cols:
+                if type(_row[_col]) == str:
+                    _result[_col].append(len(_row[_col]))
 
-        report_df = pd.DataFrame(result).describe()
+        _report = pd.DataFrame(_result).describe()
 
-        return report_df
+        return _report
+
+    @property
+    def max_text_length(self):
+
+        records = dict()
+        if 'text' in self.dataset[0]:
+            records['text'] = max([len(row['text']) for row in self.dataset])
+        if 'text_a' in self.dataset[0]:
+            records['text_a'] = max([len(row['text_a']) for row in self.dataset])
+        if 'text_b' in self.dataset[0]:
+            records['text_b'] = max([len(row['text_b']) for row in self.dataset])
+
+        return records
 
     def __getitem__(self, index):
         return self.dataset[index]
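To see the renamed surface in one place, here is a minimal usage sketch against the refactored `BaseDataset`. The subclass and toy data are hypothetical, and the base class has members not shown in this diff, so treat this as an illustration of the renames (`do_retain_df`, `do_retain_dataset`, `progress_verbose`, `_convert_to_transformer_ids`, `dataset_report`, `max_text_length`), not a tested snippet:

```python
import pandas as pd

from ark_nlp.dataset.base._dataset import BaseDataset


# Hypothetical subclass for illustration; the PR itself only changes the base class.
class ToyDataset(BaseDataset):

    def _convert_to_dataset(self, data_df):
        # One dict per row, mirroring what read_line_json now produces
        return data_df.to_dict('records')


df = pd.DataFrame({'text': ['天气不错', '太差了'], 'label': ['1', '0']})

dataset = ToyDataset(
    df,
    do_retain_df=True,       # renamed from is_retain_df
    do_retain_dataset=True,  # renamed from is_retain_dataset
    progress_verbose=False   # new in this PR
)

print(dataset.dataset_report)   # renamed from dataset_analysis
print(dataset.max_text_length)  # new property, e.g. {'text': 4}
```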