Skip to content

Commit 85d0755

Browse files
author
Thilina Rajapakse
committed
Formatting
1 parent 50525b9 commit 85d0755

File tree

8 files changed

+57
-53
lines changed

8 files changed

+57
-53
lines changed

examples/seq2seq/paraphrasing/predict.py

+1-3
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,7 @@
77
transformers_logger = logging.getLogger("transformers")
88
transformers_logger.setLevel(logging.ERROR)
99

10-
model = Seq2SeqModel(
11-
encoder_decoder_type="bart", encoder_decoder_name="outputs"
12-
)
10+
model = Seq2SeqModel(encoder_decoder_type="bart", encoder_decoder_name="outputs")
1311

1412

1513
while True:

examples/seq2seq/paraphrasing/train.py

+7-29
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,8 @@
2020
train_df = train_df.loc[train_df["label"] == "1"]
2121
eval_df = eval_df.loc[eval_df["label"] == "1"]
2222

23-
train_df = train_df.rename(
24-
columns={"sentence1": "input_text", "sentence2": "target_text"}
25-
)
26-
eval_df = eval_df.rename(
27-
columns={"sentence1": "input_text", "sentence2": "target_text"}
28-
)
23+
train_df = train_df.rename(columns={"sentence1": "input_text", "sentence2": "target_text"})
24+
eval_df = eval_df.rename(columns={"sentence1": "input_text", "sentence2": "target_text"})
2925

3026
train_df = train_df[["input_text", "target_text"]]
3127
eval_df = eval_df[["input_text", "target_text"]]
@@ -34,25 +30,13 @@
3430
eval_df["prefix"] = "paraphrase"
3531

3632
# MSRP Data
37-
train_df = pd.concat(
38-
[
39-
train_df,
40-
load_data("data/msr_paraphrase_train.txt", "#1 String", "#2 String", "Quality"),
41-
]
42-
)
43-
eval_df = pd.concat(
44-
[
45-
eval_df,
46-
load_data("data/msr_paraphrase_test.txt", "#1 String", "#2 String", "Quality"),
47-
]
48-
)
33+
train_df = pd.concat([train_df, load_data("data/msr_paraphrase_train.txt", "#1 String", "#2 String", "Quality"),])
34+
eval_df = pd.concat([eval_df, load_data("data/msr_paraphrase_test.txt", "#1 String", "#2 String", "Quality"),])
4935

5036
# Quora Data
5137

5238
# The Quora Dataset is not separated into train/test, so we do it manually the first time.
53-
df = load_data(
54-
"data/quora_duplicate_questions.tsv", "question1", "question2", "is_duplicate"
55-
)
39+
df = load_data("data/quora_duplicate_questions.tsv", "question1", "question2", "is_duplicate")
5640
q_train, q_test = train_test_split(df)
5741

5842
q_train.to_csv("data/quora_train.tsv", sep="\t")
@@ -107,11 +91,7 @@
10791
model_args.wandb_project = "Paraphrasing with BART"
10892

10993

110-
model = Seq2SeqModel(
111-
encoder_decoder_type="bart",
112-
encoder_decoder_name="facebook/bart-large",
113-
args=model_args,
114-
)
94+
model = Seq2SeqModel(encoder_decoder_type="bart", encoder_decoder_name="facebook/bart-large", args=model_args,)
11595

11696
model.train_model(train_df, eval_data=eval_df)
11797

@@ -136,6 +116,4 @@
136116
f.write("Prediction:\n")
137117
for pred in preds[i]:
138118
f.write(str(pred) + "\n")
139-
f.write(
140-
"________________________________________________________________________________\n"
141-
)
119+
f.write("________________________________________________________________________________\n")

examples/seq2seq/paraphrasing/utils.py

+2-6
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,10 @@
33
import pandas as pd
44

55

6-
def load_data(
7-
file_path, input_text_column, target_text_column, label_column, keep_label=1
8-
):
6+
def load_data(file_path, input_text_column, target_text_column, label_column, keep_label=1):
97
df = pd.read_csv(file_path, sep="\t", error_bad_lines=False)
108
df = df.loc[df[label_column] == keep_label]
11-
df = df.rename(
12-
columns={input_text_column: "input_text", target_text_column: "target_text"}
13-
)
9+
df = df.rename(columns={input_text_column: "input_text", target_text_column: "target_text"})
1410
df = df[["input_text", "target_text"]]
1511
df["prefix"] = "paraphrase"
1612

simpletransformers/streamlit/classification_view.py

+7-4
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,12 @@ def get_states(model, session_state=None):
2828
return session_state, model
2929

3030

31-
@st.cache(hash_funcs={ClassificationModel: simple_transformers_model, MultiLabelClassificationModel: simple_transformers_model})
31+
@st.cache(
32+
hash_funcs={
33+
ClassificationModel: simple_transformers_model,
34+
MultiLabelClassificationModel: simple_transformers_model,
35+
}
36+
)
3237
def get_prediction(model, input_text):
3338
prediction, raw_values = model.predict([input_text])
3439

@@ -71,9 +76,7 @@ def classification_viewer(model, model_class):
7176
try:
7277
session_state, model = get_states(model)
7378
except AttributeError:
74-
session_state = get(
75-
max_seq_length=model.args.max_seq_length,
76-
)
79+
session_state = get(max_seq_length=model.args.max_seq_length,)
7780
session_state, model = get_states(model, session_state)
7881

7982
model.args.max_seq_length = st.sidebar.slider(

simpletransformers/streamlit/ner_view.py

+9-7
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@
55
from simpletransformers.streamlit.streamlit_utils import get, simple_transformers_model, get_color
66

77

8-
ENTITY_WRAPPER = (
9-
"""<mark style="background: rgba{}; font-weight: 450; border-radius: 0.5rem; margin: 0.1em; padding: 0.25rem; display: inline-block">{} {}</mark>"""
10-
)
8+
ENTITY_WRAPPER = """<mark style="background: rgba{}; font-weight: 450; border-radius: 0.5rem; margin: 0.1em; padding: 0.25rem; display: inline-block">{} {}</mark>"""
119
ENTITY_LABEL_WRAPPER = """<span style="background: #fff; font-size: 0.56em; font-weight: bold; padding: 0.3em 0.3em; vertical-align: middle; margin: 0 0 0.15rem 0.5rem; line-height: 1; display: inline-block">{}</span>"""
1210

1311

@@ -26,9 +24,7 @@ def get_prediction(model, input_text):
2624

2725

2826
def ner_viewer(model):
29-
session_state = get(
30-
max_seq_length=model.args.max_seq_length,
31-
)
27+
session_state = get(max_seq_length=model.args.max_seq_length,)
3228
model.args.max_seq_length = session_state.max_seq_length
3329

3430
entity_list = model.args.labels_list
@@ -47,7 +43,13 @@ def ner_viewer(model):
4743

4844
prediction = get_prediction(model, input_text)[0]
4945

50-
to_write = " ".join([format_word(word, entity, entity_checkboxes, entity_color_map) for pred in prediction for word, entity in pred.items()])
46+
to_write = " ".join(
47+
[
48+
format_word(word, entity, entity_checkboxes, entity_color_map)
49+
for pred in prediction
50+
for word, entity in pred.items()
51+
]
52+
)
5153

5254
st.subheader(f"Predictions")
5355
st.write(to_write, unsafe_allow_html=True)

simpletransformers/streamlit/qa_view.py

+7-2
Original file line numberDiff line numberDiff line change
@@ -77,9 +77,14 @@ def qa_viewer(model):
7777

7878
if answers[0] != "empty":
7979
if len(context_pieces) == 2:
80-
st.write(QA_ANSWER_WRAPPER.format(context_pieces[0], answers[0], context_pieces[-1]), unsafe_allow_html=True)
80+
st.write(
81+
QA_ANSWER_WRAPPER.format(context_pieces[0], answers[0], context_pieces[-1]), unsafe_allow_html=True
82+
)
8183
else:
82-
st.write(QA_ANSWER_WRAPPER.format(context_pieces[0], answers[0], answers[0].join(context_pieces[1:])), unsafe_allow_html=True)
84+
st.write(
85+
QA_ANSWER_WRAPPER.format(context_pieces[0], answers[0], answers[0].join(context_pieces[1:])),
86+
unsafe_allow_html=True,
87+
)
8388
else:
8489
st.write(QA_EMPTY_ANSWER_WRAPPER.format("", answers[0], ""), unsafe_allow_html=True)
8590

simpletransformers/streamlit/streamlit_utils.py

+24-1
Original file line numberDiff line numberDiff line change
@@ -168,7 +168,30 @@ def simple_transformers_model(model):
168168

169169
def get_color(i):
170170
# Colors taken from Sasha Trubetskoy's list of colors - https://sashamaps.net/docs/tools/20-colors/
171-
colors = [(60, 180, 75, 0.4), (255, 225, 25, 0.4), (0, 130, 200, 0.4), (245, 130, 48, 0.4), (145, 30, 180, 0.4), (70, 240, 240, 0.4), (240, 50, 230, 0.4), (210, 245, 60, 0.4), (250, 190, 212, 0.4), (0, 128, 128, 0.4), (220, 190, 255, 0.4), (170, 110, 40, 0.4), (255, 250, 200, 0.4), (128, 0, 0, 0.4), (170, 255, 195, 0.4), (128, 128, 0, 0.4), (255, 215, 180, 0.4), (0, 0, 128, 0.4), (128, 128, 128, 0.4), (255, 255, 255, 0.4), (0, 0, 0, 0.4), (230, 25, 75, 0.4)]
171+
colors = [
172+
(60, 180, 75, 0.4),
173+
(255, 225, 25, 0.4),
174+
(0, 130, 200, 0.4),
175+
(245, 130, 48, 0.4),
176+
(145, 30, 180, 0.4),
177+
(70, 240, 240, 0.4),
178+
(240, 50, 230, 0.4),
179+
(210, 245, 60, 0.4),
180+
(250, 190, 212, 0.4),
181+
(0, 128, 128, 0.4),
182+
(220, 190, 255, 0.4),
183+
(170, 110, 40, 0.4),
184+
(255, 250, 200, 0.4),
185+
(128, 0, 0, 0.4),
186+
(170, 255, 195, 0.4),
187+
(128, 128, 0, 0.4),
188+
(255, 215, 180, 0.4),
189+
(0, 0, 128, 0.4),
190+
(128, 128, 128, 0.4),
191+
(255, 255, 255, 0.4),
192+
(0, 0, 0, 0.4),
193+
(230, 25, 75, 0.4),
194+
]
172195
try:
173196
return str(colors[i])
174197
except IndexError:

simpletransformers/t5/run_simple_transformers_streamlit_app.py

-1
Original file line numberDiff line numberDiff line change
@@ -3,4 +3,3 @@
33

44

55
streamlit_runner()
6-

0 commit comments

Comments
 (0)