diff --git a/delft/applications/softwareContextClassifier.py b/delft/applications/softwareContextClassifier.py
index 7db257fe..c0dc8b47 100644
--- a/delft/applications/softwareContextClassifier.py
+++ b/delft/applications/softwareContextClassifier.py
@@ -40,7 +40,7 @@ def configure(architecture):
     if architecture == "bert":
         batch_size = 32
         early_stop = False
-        max_epoch = 5
+        max_epoch = 6
         maxlen = 100
 
     return batch_size, maxlen, patience, early_stop, max_epoch
@@ -99,6 +99,44 @@ def train_and_eval(embeddings_name, fold_count, architecture="gru", transformer=
     model.save()
 
 
+def train_binary(embeddings_name, fold_count, architecture="gru", transformer=None):
+    """
+    Train and save one binary classifier per software context class.
+
+    For each class in list_classes, the multiclass labels are reduced to two
+    columns ([1, 0] for the class, [0, 1] otherwise) and a dedicated model
+    named 'software_context_<class>_<architecture>' is trained on the loaded
+    corpus, using n-fold training when fold_count is not equal to 1.
+    """
+    print('loading multiclass software context dataset...')
+    x_train, y_train = load_software_context_corpus_json("data/textClassification/software/software-contexts.json.gz")
+
+    report_training_contexts(y_train)
+
+    # hyper-parameters only depend on the architecture, so resolve them once
+    batch_size, maxlen, patience, early_stop, max_epoch = configure(architecture)
+    class_weights = None
+
+    for class_rank, class_name in enumerate(list_classes):
+        model_name = 'software_context_' + class_name + '_' + architecture
+
+        # one-vs-rest labels: [1, 0] for the current class, [0, 1] for all others
+        y_train_class_rank = np.array([[1, 0] if y[class_rank] == 1.0 else [0, 1] for y in y_train])
+
+        list_classes_rank = [class_name, "not_" + class_name]
+
+        model = Classifier(model_name, architecture=architecture, list_classes=list_classes_rank, max_epoch=max_epoch, fold_number=fold_count, patience=patience,
+            use_roc_auc=True, embeddings_name=embeddings_name, batch_size=batch_size, maxlen=maxlen, early_stop=early_stop,
+            class_weights=class_weights, transformer_name=transformer)
+
+        if fold_count == 1:
+            model.train(x_train, y_train_class_rank)
+        else:
+            model.train_nfold(x_train, y_train_class_rank)
+        # saving the model
+        model.save()
+
+
 def train_and_eval_binary(embeddings_name, fold_count, architecture="gru", transformer=None):
     print('loading multiclass software context dataset...')
     xtr, y = load_software_context_corpus_json("data/textClassification/software/software-contexts.json.gz")
@@ -203,8 +241,8 @@ def report_training_contexts(y):
 
     args = parser.parse_args()
 
-    if args.action not in ('train', 'train_eval', 'classify'):
-        print('action not specified, must be one of [train,train_eval,classify]')
+    if args.action not in ('train', 'train_eval', 'classify', 'train_binary', 'train_eval_binary'):
+        print('action not specified, must be one of [train,train_binary,train_eval,train_eval_binary,classify]')
 
     embeddings_name = args.embedding
     transformer = args.transformer
@@ -223,6 +261,12 @@ def report_training_contexts(y):
 
         train(embeddings_name, args.fold_count, architecture=architecture, transformer=transformer)
 
+    if args.action == 'train_binary':
+        if args.fold_count < 1:
+            raise ValueError("fold-count should be equal or more than 1")
+
+        train_binary(embeddings_name, args.fold_count, architecture=architecture, transformer=transformer)
+
     if args.action == 'train_eval':
         if args.fold_count < 1:
             raise ValueError("fold-count should be equal or more than 1")
diff --git a/setup.py b/setup.py
index e035e845..cbcfd928 100644
--- a/setup.py
+++ b/setup.py
@@ -27,7 +27,8 @@
         'truecase',
         'requests',
         'pandas==1.3.5',
-        'pytest'
+        'pytest',
+        'tensorflow-addons==0.15.0'
     ],
     classifiers=[
         "Programming Language :: Python :: 3.7",