Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fixing senteval-probing preprocessing #951

Merged
merged 31 commits into from
Nov 5, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
31 commits
Select commit Hold shift + click to select a range
fb7ad55
Copying configs from superglue
zphang Oct 16, 2019
5eee543
adding senteval probing config commands
Oct 23, 2019
f5925d8
adding meta-script for transfer and probing exps
Oct 23, 2019
501d7a1
Adding meta bash script fixed
Oct 23, 2019
c7ac87f
give_permissions script
zphang Oct 24, 2019
2516e6e
Merge branch 'master' of https://github.com/nyu-mll/jiant into taskma…
Oct 26, 2019
dc4f69d
Merge remote-tracking branch 'origin/master' into taskmaster
zphang Oct 27, 2019
c07efbb
small fix transfer_analysis.sh (#946)
yzpang Oct 27, 2019
ddf94e4
lr_patience fix
HaokunLiu Oct 27, 2019
e2f9e07
Merge remote-tracking branch 'origin/master' into taskmaster
zphang Oct 28, 2019
b74252a
target_task training -> pretrain training
Oct 28, 2019
fe1b6af
adding edgeprobing configs and command
Oct 29, 2019
d472094
Merge branch 'taskmaster' of https://github.com/nyu-mll/jiant into ta…
Oct 29, 2019
d600f3e
adding edge probing conf
Oct 30, 2019
50479ba
fix load_target_train bug
Oct 30, 2019
d9a4546
add hyperparameter sweeping
HaokunLiu Nov 1, 2019
70749c9
val_interval change
HaokunLiu Nov 1, 2019
5595bb0
adding sweep function
Nov 1, 2019
b687778
Task specific val_intervals
zphang Nov 1, 2019
24f78d8
add reload_vocab to hyperparameter sweep
Nov 2, 2019
86be6e9
adding batch_size specification
Nov 3, 2019
61bf126
fixing senteval-word-content
Nov 3, 2019
aad9672
fixing senteval preprocess script
Nov 4, 2019
55cfd81
revert extra delete
Nov 4, 2019
0d697c1
remove extra files
Nov 4, 2019
8f6a10f
black format
Nov 4, 2019
40d11a9
black formatting trainer.py
Nov 4, 2019
ae1dcda
Merge branch 'master' into fix_senteval
Nov 4, 2019
87657dc
remove load_data()
Nov 4, 2019
2169c68
Merge branch 'fix_senteval' of https://github.com/nyu-mll/jiant into …
Nov 4, 2019
485fbdc
removing extra changes
Nov 5, 2019
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
20 changes: 16 additions & 4 deletions jiant/tasks/senteval_probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -272,28 +272,40 @@ def __init__(self, path, max_seq_len, name, **kw):
self.train_data_text = None
self.val_data_text = None
self.test_data_text = None
self.labels = []

def get_all_labels(self):
return list(set(self.labels))

def get_sentences(self):
return self.sentences

def process_split(self, split, indexers, model_preprocessing_interface):
return process_single_pair_task_split(
split,
indexers,
model_preprocessing_interface,
label_namespace=self._label_namespace,
is_pair=False,
skip_indexing=False,
)

def load_data(self):
""" Load data """

def load_csv(data_file):
pruksmhc marked this conversation as resolved.
Show resolved Hide resolved
rows = pd.read_csv(data_file, encoding="utf-8")
rows["s1"] = rows["2"].apply(
labels = rows["1"].apply(lambda x: x.split("\t")[0])
s1 = rows["1"].apply(lambda x: x.split("\t")[1])
s1 = s1.apply(
lambda x: tokenize_and_truncate(self._tokenizer_name, x, self.max_seq_len)
)
self.labels.append(rows["1"].tolist())
return rows["s1"].tolist(), [], rows["1"].tolist(), list(range(len(rows)))
self.labels = list(set(labels.tolist()))
return s1.tolist(), [], labels.tolist(), list(range(len(rows)))

self.train_data_text = load_csv(os.path.join(self.path, "train.csv"))
self.val_data_text = load_csv(os.path.join(self.path, "val.csv"))
self.test_data_text = load_csv(os.path.join(self.path, "test.csv"))

sentences = []
for split in ["train", "val", "test"]:
split_data = getattr(self, "%s_data_text" % split)
Expand Down
6 changes: 5 additions & 1 deletion jiant/trainer.py
Original file line number Diff line number Diff line change
Expand Up @@ -951,7 +951,11 @@ def _validate(self, val_pass, tasks, batch_size, periodic_save=True):

# Get validation numbers for each task
for task in tasks:
n_examples_overall, task_infos, all_val_metrics = self._calculate_validation_performance( # noqa
(
n_examples_overall,
task_infos,
all_val_metrics,
) = self._calculate_validation_performance(
task, task_infos, tasks, batch_size, all_val_metrics, n_examples_overall
)
# scale the micro avg contributions w/ total size of validation set.
Expand Down
14 changes: 7 additions & 7 deletions scripts/preprocess_senteval_probing.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,21 +9,21 @@


def parse_senteval_probing(args):
files = [x for x in os.listdir(args.senteval_path) if "txt" in x]
files = [x for x in os.listdir(args.senteval_probing_path) if "txt" in x]
for file in files:
file_pd = pd.read_fwf(os.path.join(args.senteval_path, file), header=None)
file_pd = pd.read_fwf(os.path.join(args.senteval_probing_path, file), header=None)
files_train = file_pd[file_pd[0] == "tr"]
task_name = file.split(".")[0]
if not os.path.exists(os.path.join(args.senteval_path, task_name)):
os.mkdir(os.path.join(args.senteval_path, task_name))
files_train.to_csv(os.path.join(args.senteval_path, task_name, "train.tsv"))
if not os.path.exists(os.path.join(args.senteval_probing_path, task_name)):
os.mkdir(os.path.join(args.senteval_probing_path, task_name))
files_train.to_csv(os.path.join(args.senteval_probing_path, task_name, "train.csv"))
files_val = file_pd[file_pd[0] == "va"]
task_name = file.split(".")[0]
files_val.to_csv(os.path.join(args.senteval_path, task_name, "val.tsv"))
files_val.to_csv(os.path.join(args.senteval_probing_path, task_name, "val.csv"))

files_test = file_pd[file_pd[0] == "te"]
task_name = file.split(".")[0]
files_test.to_csv(os.path.join(args.senteval_path, task_name, "test.tsv"))
files_test.to_csv(os.path.join(args.senteval_probing_path, task_name, "test.csv"))


parser = argparse.ArgumentParser()
Expand Down