Skip to content

Commit

Permalink
add train-only binary classifiers for software contexts
Browse files Browse the repository at this point in the history
  • Loading branch information
kermitt2 committed Mar 29, 2022
1 parent d4fc562 commit 13b8613
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 4 deletions.
41 changes: 38 additions & 3 deletions delft/applications/softwareContextClassifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ def configure(architecture):
if architecture == "bert":
batch_size = 32
early_stop = False
max_epoch = 5
max_epoch = 6
maxlen = 100

return batch_size, maxlen, patience, early_stop, max_epoch
Expand Down Expand Up @@ -99,6 +99,35 @@ def train_and_eval(embeddings_name, fold_count, architecture="gru", transformer=
model.save()


def train_binary(embeddings_name, fold_count, architecture="gru", transformer=None):
    """
    Train and save one binary classifier per software-context class.

    The multiclass software-context corpus is loaded once. For every entry of
    `list_classes`, the labels are turned into two-column vectors: [1, 0] when
    the value at that class position equals 1.0, and [0, 1] otherwise. A
    dedicated `Classifier` is then trained on these labels (a single training
    run when `fold_count` is 1, n-fold training otherwise) and saved.
    """
    print('loading multiclass software context dataset...')
    x_train, y_train = load_software_context_corpus_json("data/textClassification/software/software-contexts.json.gz")

    report_training_contexts(y_train)

    for rank, class_label in enumerate(list_classes):
        # each class gets its own model, named after the class and the architecture
        model_name = 'software_context_' + class_label + '_'+architecture
        class_weights = None

        batch_size, maxlen, patience, early_stop, max_epoch = configure(architecture)

        # first column: the class applies, second column: it does not
        binary_labels = np.array(
            [[1, 0] if labels[rank] == 1.0 else [0, 1] for labels in y_train]
        )

        binary_classes = [class_label, "not_"+class_label]

        model = Classifier(
            model_name,
            architecture=architecture,
            list_classes=binary_classes,
            max_epoch=max_epoch,
            fold_number=fold_count,
            patience=patience,
            use_roc_auc=True,
            embeddings_name=embeddings_name,
            batch_size=batch_size,
            maxlen=maxlen,
            early_stop=early_stop,
            class_weights=class_weights,
            transformer_name=transformer,
        )

        # single training run versus n-fold training
        train_fn = model.train if fold_count == 1 else model.train_nfold
        train_fn(x_train, binary_labels)

        # saving the model
        model.save()


def train_and_eval_binary(embeddings_name, fold_count, architecture="gru", transformer=None):
print('loading multiclass software context dataset...')
xtr, y = load_software_context_corpus_json("data/textClassification/software/software-contexts.json.gz")
Expand Down Expand Up @@ -203,8 +232,8 @@ def report_training_contexts(y):

args = parser.parse_args()

if args.action not in ('train', 'train_eval', 'classify'):
print('action not specified, must be one of [train,train_eval,classify]')
if args.action not in ('train', 'train_eval', 'classify', 'train_binary', 'train_eval_binary'):
print('action not specified, must be one of [train,train_binary,train_eval,train_eval_binary,classify]')

embeddings_name = args.embedding
transformer = args.transformer
Expand All @@ -223,6 +252,12 @@ def report_training_contexts(y):

train(embeddings_name, args.fold_count, architecture=architecture, transformer=transformer)

if args.action == 'train_binary':
if args.fold_count < 1:
raise ValueError("fold-count should be equal or more than 1")

train_binary(embeddings_name, args.fold_count, architecture=architecture, transformer=transformer)

if args.action == 'train_eval':
if args.fold_count < 1:
raise ValueError("fold-count should be equal or more than 1")
Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,8 @@
'truecase',
'requests',
'pandas==1.3.5',
'pytest'
'pytest',
'tensorflow-addons==0.15.0'
],
classifiers=[
"Programming Language :: Python :: 3.7",
Expand Down

0 comments on commit 13b8613

Please sign in to comment.