
Commit

[Refactor] Upgrade the ark-nlp code structure (#80)
* refactor(task): rename variables and update comments

* refactor(*): update dataset and task comments, remove redundant code in pure

* fix(model): fix the evaluation-stage computation issue in pointer-style models

* refactor(*): adjust most of the comments

* refactor(task): adjust the naming of methods and parameters in the Task class

* refactor(task): standardize the bio_named_entity_recognition and prompt_masked_language_model tasks

* refactor(task): rename input_length in evaluate_logs to sequence_length

* refactor(task): sync the changes

* refactor(*): fix task variable naming according to the convention

* refactor(*): standardize code variable names

* refactor(test): adjust formatting to the standard form

* refactor(task): change the structure of methods

* refactor(task): continue the code structure changes

* refactor(task): convert all tasks currently in the library to the new form

* refactor(task): continue the code structure changes

* refactor(task): standardize the casrel task format

* refactor(task): continue the code structure changes

* feature(early_stopping): add early_stopping, currently only supports global_point

* Revert "feature(early_stopping): add early_stopping, currently only supports global_point"

This reverts commit 3697dcf.

* refactor(dataset): update the dataset format and add a processing progress display

* fix(logs): fix accumulated evaluation metrics not being reset to 0 during training

* feat(task): continued update - add per-step evaluation and saving by the best metric

* feat(*): use a handler to control the global process and state

* feat(*): add the metric logic and restructure TCTask to fit it

* refactor(task): change the task structure according to the new metric

* feat(crf task): update the crf task code structure

* refactor(span task): update the span task code structure

* refactor(global pointer task): update the global pointer task code structure

* refactor(task): update the w2ner and pure code structure

* fix(*): resolve conflicts

* refactor(casrel task): update the casrel task code format

* refactor(simcse task): update the simcse task code structure

* refactor(prgc task): update the prgc task code structure

* refactor(attack): update the adversarial attack code structure

* refactor(uie task): update the uie task code structure

* fix(w2ner): fix a bug introduced by the w2ner update

* feat(simcse): add the supervised version

* feat(early_stopping): add early_stopping

* feat(*): add the Sentence-BERT (sbert) model

* feat(*): add tensorboard and wandb for logging

Co-authored-by: jimme <jimme.shen123@gmail.com>
xiangking and jimme0421 authored Sep 15, 2022
1 parent 4472a33 commit e2f20d3
Showing 109 changed files with 5,169 additions and 3,264 deletions.
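Several of the commits above add training-loop conveniences: per-step evaluation, checkpointing on the best metric, and early stopping. The snippet below is only a generic sketch of that pattern under assumed names (EarlyStopping, patience, greater_is_better); it is not ark-nlp's actual API, which lives in the task and callback code changed by this PR.

# Generic sketch of "save on best metric + early stopping"; all names here are
# assumptions for illustration, not ark-nlp's own classes or signatures.
class EarlyStopping:
    def __init__(self, patience=3, greater_is_better=True):
        self.patience = patience
        self.greater_is_better = greater_is_better
        self.best_metric = None
        self.counter = 0
        self.should_stop = False

    def step(self, metric):
        improved = (
            self.best_metric is None
            or (metric > self.best_metric) == self.greater_is_better
        )
        if improved:
            self.best_metric = metric
            self.counter = 0
            return True  # caller should save a checkpoint now
        self.counter += 1
        if self.counter >= self.patience:
            self.should_stop = True
        return False

# Hypothetical usage inside an evaluation loop:
stopper = EarlyStopping(patience=3)
for epoch, dev_f1 in enumerate([0.81, 0.84, 0.83, 0.84, 0.82]):
    if stopper.step(dev_f1):
        print(f"epoch {epoch}: new best {dev_f1}, save checkpoint")
    if stopper.should_stop:
        print(f"epoch {epoch}: no improvement for 3 epochs, stop")
        break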
110 changes: 56 additions & 54 deletions ark_nlp/dataset/base/_dataset.py
@@ -15,7 +15,6 @@
# Author: Xiang Wang, xiangking1995@163.com
# Status: Active


import json
import copy
import codecs
@@ -31,28 +30,30 @@ class BaseDataset(Dataset):
Dataset base class
Args:
data (:obj:`DataFrame` or :obj:`string`): the data or the path to the data
categories (:obj:`list`, optional, defaults to `None`): the data categories
is_retain_df (:obj:`bool`, optional, defaults to False): whether to copy the raw DataFrame-format data into the retain_df attribute
is_retain_dataset (:obj:`bool`, optional, defaults to False): whether to copy the raw data converted to dataset format into the retain_dataset attribute
is_train (:obj:`bool`, optional, defaults to True): whether the dataset is training data
is_test (:obj:`bool`, optional, defaults to False): whether the dataset is test data
data (DataFrame or string): the data or the path to the data
categories (list or None, optional, defaults to `None`): the data categories
do_retain_df (bool, optional, defaults to False): whether to copy the raw DataFrame-format data into the retain_df attribute
do_retain_dataset (bool, optional, defaults to False): whether to copy the raw data converted to dataset format into the retain_dataset attribute
is_train (bool, optional, defaults to True): whether the dataset is training data
is_test (bool, optional, defaults to False): whether the dataset is test data
progress_verbose (bool, optional): whether to display data processing progress, defaults to True
""" # noqa: ignore flake8"

def __init__(
self,
data,
categories=None,
is_retain_df=False,
is_retain_dataset=False,
is_train=True,
is_test=False
):
def __init__(self,
data,
categories=None,
do_retain_df=False,
do_retain_dataset=False,
is_train=True,
is_test=False,
progress_verbose=True):

self.is_test = is_test
self.is_train = is_train
self.is_retain_df = is_retain_df
self.is_retain_dataset = is_retain_dataset
self.do_retain_df = do_retain_df
self.do_retain_dataset = do_retain_dataset

self.progress_verbose = progress_verbose

if self.is_test is True:
self.is_train = False
@@ -61,7 +62,7 @@ def __init__(
if 'label' in data.columns:
data['label'] = data['label'].apply(lambda x: str(x))

if self.is_retain_df:
if self.do_retain_df:
self.df = data

self.dataset = self._convert_to_dataset(data)
@@ -87,32 +88,27 @@ def _load_dataset(self, data_path):
Load the dataset
Args:
data_path (:obj:`string`): the data path
data_path (string): the data path
""" # noqa: ignore flake8"

data_df = self._read_data(data_path)

if self.is_retain_df:
if self.do_retain_df:
self.df = data_df

return self._convert_to_dataset(data_df)

def _convert_to_dataset(self, data_df):
pass

def _read_data(
self,
data_path,
data_format=None,
skiprows=-1
):
def _read_data(self, data_path, data_format=None, skiprows=-1):
"""
Read the required data
Args:
data_path (:obj:`string`): the data path
data_format (:obj:`string`, defaults to `None`): the data storage format
skiprows (:obj:`int`, defaults to -1): index of the row to skip when reading, defaults to no skipping
data_path (string): the data path
data_format (string or None, optional): the data storage format, defaults to None
skiprows (int, optional): index of the row to skip when reading, defaults to -1 (no row skipped)
""" # noqa: ignore flake8"

if data_format is None:
@@ -123,7 +119,7 @@ def _read_data(
elif data_format == 'json':
try:
data_df = pd.read_json(data_path, dtype={'label': str})
except:
except Exception:
data_df = self.read_line_json(data_path)
elif data_format == 'tsv':
data_df = pd.read_csv(data_path, sep='\t', dtype={'label': str})
@@ -134,17 +130,13 @@ def read_line_json(

return data_df

def read_line_json(
self,
data_path,
skiprows=-1
):
def read_line_json(self, data_path, skiprows=-1):
"""
Read the required data
Args:
data_path (:obj:`string`): the path where the data is located
skiprows (:obj:`int`, defaults to -1): index of the row to skip when reading, defaults to no skipping
data_path (string): the data path
skiprows (int, optional): index of the row to skip when reading, defaults to -1 (no row skipped)
"""
datasets = []

@@ -153,10 +145,7 @@
for index, line in enumerate(reader):
if index == skiprows:
continue
line = json.loads(line)
tokens = line['text']
label = line['label']
datasets.append({'text': tokens.strip(), 'label': label})
datasets.append(json.loads(line))

return pd.DataFrame(datasets)

@@ -170,18 +159,18 @@ def convert_to_ids(self, tokenizer):
if tokenizer.tokenizer_type == 'vanilla':
features = self._convert_to_vanilla_ids(tokenizer)
elif tokenizer.tokenizer_type == 'transformer':
features = self._convert_to_transfomer_ids(tokenizer)
features = self._convert_to_transformer_ids(tokenizer)
elif tokenizer.tokenizer_type == 'customized':
features = self._convert_to_customized_ids(tokenizer)
else:
raise ValueError("The tokenizer type does not exist")

if self.is_retain_dataset:
if self.do_retain_dataset:
self.retain_dataset = copy.deepcopy(self.dataset)

self.dataset = features

def _convert_to_transfomer_ids(self, bert_tokenizer):
def _convert_to_transformer_ids(self, bert_tokenizer):
pass

def _convert_to_vanilla_ids(self, vanilla_tokenizer):
@@ -190,7 +179,7 @@ def _convert_to_vanilla_ids(self, vanilla_tokenizer):
def _convert_to_customized_ids(self, customized_tokenizer):
pass

def _get_input_length(self, text, bert_tokenizer):
def _get_sequence_length(self, text, bert_tokenizer):
pass

@property
@@ -206,17 +195,30 @@ def sample_num(self):
return len(self.dataset)

@property
def dataset_analysis(self):

_result = defaultdict(list)
for _row in self.dataset:
for _col in self.dataset_cols:
if type(_row[_col]) == str:
_result[_col].append(len(_row[_col]))

_report = pd.DataFrame(_result).describe()

return _report

def dataset_report(self):

result = defaultdict(list)
for row in self.dataset:
for col_name in self.dataset_cols:
if type(row[col_name]) == str:
result[col_name].append(len(row[col_name]))

report_df = pd.DataFrame(result).describe()

return report_df

@property
def max_text_length(self):

records = dict()
if 'text' in self.dataset[0]:
records['text'] = max([len(row['text']) for row in self.dataset])
if 'text_a' in self.dataset[0]:
records['text_a'] = max([len(row['text_a']) for row in self.dataset])
if 'text_b' in self.dataset[0]:
records['text_b'] = max([len(row['text_b']) for row in self.dataset])

return records

def __getitem__(self, index):
return self.dataset[index]
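Taken together, the changes to this file rename is_retain_df/is_retain_dataset to do_retain_df/do_retain_dataset, add progress_verbose, rename dataset_analysis to dataset_report, and add max_text_length. A minimal usage sketch follows, assuming ark-nlp at this commit is installed; the ToyDataset subclass and its data are hypothetical, and concrete ark-nlp dataset classes implement _convert_to_dataset (and typically set dataset_cols) themselves.

import pandas as pd
from ark_nlp.dataset.base._dataset import BaseDataset

class ToyDataset(BaseDataset):
    # Hypothetical subclass: BaseDataset leaves _convert_to_dataset abstract.
    def _convert_to_dataset(self, data_df):
        self.dataset_cols = list(data_df.columns)  # mirrors what concrete dataset classes do
        return data_df.to_dict('records')

df = pd.DataFrame({'text': ['今天天气不错', '下雨了'], 'label': ['1', '0']})

dataset = ToyDataset(
    df,
    categories=['0', '1'],
    do_retain_df=True,        # was is_retain_df before this commit
    do_retain_dataset=False,  # was is_retain_dataset
    progress_verbose=True     # new argument introduced by this commit
)

print(dataset.sample_num)       # 2
print(dataset.max_text_length)  # {'text': 6}
print(dataset.dataset_report)   # length statistics for the string columns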
