Made changes for applying ranking loss directly to the TaskNN #48

Merged
merged 14 commits into from
Sep 26, 2021
Changes from 5 commits
Empty file removed data/.gitkeep
Empty file.
@@ -0,0 +1,181 @@
local test = std.extVar('TEST'); // a test run with small dataset
local data_dir = std.extVar('DATA_DIR');
local cuda_device = std.extVar('CUDA_DEVICE');
local use_wandb = (if test == '1' then false else true);

local dataset_name = std.parseJson(std.extVar('dataset_name'));
// local dataset_name = 'genbase';
local dataset_metadata = (import '../datasets.jsonnet')[dataset_name];
local num_labels = dataset_metadata.num_labels;
local num_input_features = dataset_metadata.input_features;

// model variables
local ff_hidden = std.parseJson(std.extVar('ff_hidden'));
local label_space_dim = ff_hidden;
local ff_dropout = std.parseJson(std.extVar('ff_dropout_10x')) / 10.0;
local ff_activation = 'softplus';
local ff_linear_layers = std.parseJson(std.extVar('ff_linear_layers'));
local ff_weight_decay = std.parseJson(std.extVar('ff_weight_decay'));
local global_score_hidden_dim = std.parseJson(std.extVar('global_score_hidden_dim'));
local gain = (if ff_activation == 'tanh' then 5 / 3 else 1);
local cross_entropy_loss_weight = std.parseJson(std.extVar('cross_entropy_loss_weight'));
local dvn_score_loss_weight = std.parseJson(std.extVar('dvn_score_loss_weight'));
local task_temp = std.parseJson(std.extVar('task_nn_steps')); # variable for task_nn.steps
local task_nn_steps = (if std.toString(task_temp) == '0' then 1 else task_temp);
local num_samples = std.parseJson(std.extVar('num_samples'));
{
[if use_wandb then 'type']: 'train_test_log_to_wandb',
evaluate_on_test: true,
// Data
dataset_reader: {
type: 'arff',
num_labels: num_labels,
},
validation_dataset_reader: {
type: 'arff',
num_labels: num_labels,
},
train_data_path: (data_dir + '/' + dataset_metadata.dir_name + '/' +
dataset_metadata.train_file),
validation_data_path: (data_dir + '/' + dataset_metadata.dir_name + '/' +
dataset_metadata.validation_file),
test_data_path: (data_dir + '/' + dataset_metadata.dir_name + '/' +
dataset_metadata.test_file),

// Model
model: {
type: 'multi-label-classification-with-infnet',
sampler: {
type: 'appending-container',
log_key: 'sampler',
constituent_samplers: [],
},
task_nn: {
type: 'multi-label-classification',
feature_network: {
input_dim: num_input_features,
num_layers: ff_linear_layers,
activations: ([ff_activation for i in std.range(0, ff_linear_layers - 2)] + [ff_activation]),
hidden_dims: ff_hidden,
dropout: ([ff_dropout for i in std.range(0, ff_linear_layers - 2)] + [0]),
},
label_embeddings: {
embedding_dim: ff_hidden,
vocab_namespace: 'labels',
},
},
inference_module: {
type: 'multi-label-inference-net-normalized',
log_key: 'inference_module',
loss_fn: {
type: 'combination-loss',
log_key: 'loss',
constituent_losses: [
{
type: 'multi-label-nce-ranking-with-discrete-sampling',
log_key: 'nce-on-tasknn',
num_samples: num_samples,
sign: '+',
use_scorenn: false,
normalize_y: true,
},
{
type: 'multi-label-bce',
reduction: 'none',
log_key: 'bce',
},
],
loss_weights: [dvn_score_loss_weight, cross_entropy_loss_weight],
reduction: 'mean',
},
},
oracle_value_function: { type: 'per-instance-f1', differentiable: false },
score_nn: {
type: 'multi-label-classification',
task_nn: {
type: 'multi-label-classification',
feature_network: {
input_dim: num_input_features,
num_layers: ff_linear_layers,
activations: ([ff_activation for i in std.range(0, ff_linear_layers - 2)] + [ff_activation]),
hidden_dims: ff_hidden,
dropout: ([ff_dropout for i in std.range(0, ff_linear_layers - 2)] + [0]),
},
label_embeddings: {
embedding_dim: ff_hidden,
vocab_namespace: 'labels',
},
},
global_score: {
type: 'multi-label-feedforward',
feedforward: {
input_dim: num_labels,
num_layers: 1,
activations: ff_activation,
hidden_dims: global_score_hidden_dim,
},
},
},
loss_fn: { type: 'multi-label-dvn-bce', log_key: 'dvn_bce' },
initializer: {
regexes: [
//[@'.*_feedforward._linear_layers.0.weight', {type: 'normal'}],
[@'.*_linear_layers.*weight', (if std.member(['tanh', 'sigmoid'], ff_activation) then { type: 'xavier_uniform', gain: gain } else { type: 'kaiming_uniform', nonlinearity: 'relu' })],
[@'.*linear_layers.*bias', { type: 'zero' }],
],
},
},
data_loader: {
shuffle: true,
batch_size: 32,
},
trainer: {
type: 'gradient_descent_minimax',
num_epochs: if test == '1' then 10 else 300,
grad_norm: { task_nn: 10.0 },
patience: 20,
validation_metric: '+fixed_f1',
cuda_device: std.parseInt(cuda_device),
learning_rate_schedulers: {
task_nn: {
type: 'reduce_on_plateau',
factor: 0.5,
mode: 'max',
patience: 5,
verbose: true,
},
},
optimizer: {
optimizers: {
task_nn:
{
lr: 0.001,
weight_decay: ff_weight_decay,
type: 'adamw',
},
score_nn: {
lr: 0.005,
weight_decay: ff_weight_decay,
type: 'adamw',
},
},
},
checkpointer: {
keep_most_recent_by_count: 1,
},
callbacks: [
'track_epoch_callback',
'slurm',
] + (
if use_wandb then [
{
type: 'wandb_allennlp',
sub_callbacks: [{ type: 'log_best_validation_metrics', priority: 100 }],
},
]
else []
),
inner_mode: 'score_nn',
num_steps: { task_nn: task_nn_steps, score_nn: 1 },
},
}
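
For readers tracing the loss setup above: combination-loss pairs loss_weights with constituent_losses positionally, so the tasknn objective is dvn_score_loss_weight times the NCE ranking term plus cross_entropy_loss_weight times the BCE term, averaged over the batch. A minimal sketch of that weighting (function and tensor names here are illustrative, not the repository's classes):

import torch

def combined_tasknn_loss(nce_ranking: torch.Tensor,  # per-instance 'nce-on-tasknn' loss
                         bce: torch.Tensor,          # per-instance 'bce' loss (reduction: 'none')
                         dvn_score_loss_weight: float,
                         cross_entropy_loss_weight: float) -> torch.Tensor:
    # loss_weights: [dvn_score_loss_weight, cross_entropy_loss_weight]
    # pairs positionally with constituent_losses: ['nce-on-tasknn', 'bce']
    total = dvn_score_loss_weight * nce_ranking + cross_entropy_loss_weight * bce
    return total.mean()  # reduction: 'mean'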
83 changes: 83 additions & 0 deletions scripts/wandb_run_NCE_infnet.sh
@@ -0,0 +1,83 @@
#!/bin/sh
#SBATCH --job-name=structured_prediction
#SBATCH --output=../logs/infnet_wNCE-%j.out
#SBATCH --partition=1080ti-long
#SBATCH --gres=gpu:1
#SBATCH --cpus-per-task=4
#SBATCH --mem=20GB
#SBATCH --exclude=node072,node035,node030

export TEST=0
export CUDA_DEVICE=0
export DATA_DIR=./data/

# Model Variables
# export dataset_name=genbase
export ff_hidden=400
export ff_dropout=0.5
export ff_dropout_10x=5
export ff_linear_layers=2
export ff_weight_decay=0.0001
export global_score_hidden_dim=200
export cross_entropy_loss_weight=1.0
export inference_score_weight=0
export dvn_score_loss_weight=0
export num_samples=10
export stopping_criteria=0
export task_nn_steps=1
export score_nn_steps=1

#wandb_allennlp --subcommand=train \
# --config_file=model_configs/<path_to_config_file> \
# --include-package=structured_prediction_baselines \
# --wandb_run_name=<some_informative_name_for_run> \
# --wandb_project structure_prediction_baselines \
# --wandb_entity score-based-learning \
# --wandb_tags=baselines,as_reported

## running with jsonnet
# --config_file=./model_configs/multilabel_classification/bibtex_inference_net_wDVN.jsonnet \
# --config_file=model_configs/multilabel_classification/Infnet_wDVN.config \
# --config_file=./model_configs/multilabel_classification/bibtex_inference_net_wDVN_fscratch.jsonnet \
## running with config file.

# infnet + NCE
# allennlp train_with_wandb \
# model_configs/multilabel_classification/bibtex_infnet_wNCE_infscore.jsonnet \
# --include-package=structured_prediction_baselines \
# --wandb_name=bibtex_inference_net_wNCE_test \
# --wandb_project structured_prediction_baselines \
# --wandb_entity score-based-learning \
# --wandb_tags="bibtex,infnet_wDVN,without_sampling"

# # revNCE with task-NN zero loss + pre-trained task-NN
# allennlp train_with_wandb \
# model_configs/multilabel_classification/bibtex_revNCE_zerotasknn.jsonnet \
# --include-package=structured_prediction_baselines \
# --wandb_name="optimizer=adam (lr!=0, wd!=0) & typical loss & main_wd!=0" \
# --wandb_project structured_prediction_baselines \
# --wandb_entity score-based-learning \
# --wandb_tags="bibtex,revNCE testing"
# # --wandb_name=revNCE_test_with_pretrainedTaskNN_lossAsMetric \

# NCE with task-NN zero loss + pre-trained task-NN
# allennlp train_with_wandb \
# model_configs/multilabel_classification/bibtex_NCE_zerotasknn.jsonnet \
# --include-package=structured_prediction_baselines \
# --wandb_name=NCE_test_with_pretrainedTaskNN_lossAsMetric \
# --wandb_project structured_prediction_baselines \
# --wandb_entity score-based-learning \
# --wandb_tags="bibtex,NCE testing"

# revNCE: pre-trained score-NN + no update on score-NN (zero loss)
# export ALLENNLP_DEBUG=1
allennlp train_with_wandb \
model_configs/multilabel_classification/v2.5/gendata_nce_discrete_on_tasknn_reverse.jsonnet \
--include-package=structured_prediction_baselines \
--wandb_name="NCE_on_taskNN_test" \
--wandb_project="mlc" \
--wandb_entity="score-based-learning" \
--wandb_tags="bibtex,revNCE testing"
# --wandb_name=revNCE_test_with_pretrainedTaskNN_lossAsMetric \

# /mnt/nfs/scratch1/jaylee/repository/structured_prediction/model_configs/multilabel_classification/bibtex_revNCE_zeroScoreLoss.jsonnet \
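
A note on how the exports above reach the Jsonnet config: AllenNLP loads Jsonnet configs with the process environment passed in as external variables, so each std.extVar('...') call resolves to the corresponding exported string. Roughly, in Python terms (two of the exports shown; the parsing mirrors the config):

import os

ff_hidden = int(os.environ['ff_hidden'])                # 400, via std.parseJson(std.extVar('ff_hidden'))
ff_dropout = int(os.environ['ff_dropout_10x']) / 10.0   # 0.5, via the ff_dropout_10x / 10.0 expression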
@@ -12,11 +12,16 @@ def _normalize(y: torch.Tensor) -> torch.Tensor:


 class MultiLabelNCERankingLoss(NCERankingLoss):
-    def __init__(self, sign: Literal["-", "+"] = "-", **kwargs: Any):
-        super().__init__(**kwargs)
+    def __init__(self,
+                 sign: Literal["-", "+"] = "-",
+                 use_scorenn: bool = True,
+                 **kwargs: Any):
+        super().__init__(use_scorenn, **kwargs)
         self.sign = sign
         self.mul = -1 if sign == "-" else 1
         self.bce = torch.nn.BCELoss(reduction="none")
+        # when self.use_scorenn=True, the sign should always be +.
+        assert (sign == "+" if self.use_scorenn else True)

     def normalize(self, y: torch.Tensor) -> torch.Tensor:
         return _normalize(y)
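
The new assert is a compact guard: a '-' sign is only allowed when the score-nn is bypassed. An equivalent, more explicit version (a readability sketch, not the repository's code):

if self.use_scorenn and sign != "+":
    raise ValueError("sign must be '+' when use_scorenn=True")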
15 changes: 10 additions & 5 deletions structured_prediction_baselines/modules/loss/nce_loss.py
@@ -84,8 +84,9 @@ def _forward(


 class NCERankingLoss(NCELoss):
-    def __init__(self, **kwargs: Any) -> None:
+    def __init__(self, use_scorenn: bool = True, **kwargs: Any) -> None:
         super().__init__(**kwargs)
+        self.use_scorenn = use_scorenn
         self.cross_entropy = torch.nn.CrossEntropyLoss(reduction="none")

     def compute_loss(
@@ -104,10 +105,14 @@ def compute_loss(
         distance = self.distance(
             y, y_hat.expand_as(y)
         ) # (batch, 1+num_samples) # does the job of Pn
-        score = self.score_nn(
-            x, y, buffer
-        ) # type:ignore # (batch, 1+num_samples)
-        assert not distance.requires_grad
+        if self.use_scorenn:
+            score = self.score_nn(
+                x, y, buffer
+            ) # type:ignore # (batch, 1+num_samples)
+            assert not distance.requires_grad
+        else:
+            score = 0
+
         new_score = score - distance # (batch, 1+num_samples)
         ranking_loss = self.cross_entropy(
             new_score,
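For intuition on the compute_loss change: with use_scorenn=False the score term drops out, so new_score is just the negative distance and the cross-entropy ranks the ground-truth candidate against the num_samples discrete samples. A self-contained sketch of that ranking step (shapes follow the comments above; it assumes the ground-truth candidate occupies index 0):

import torch

batch, num_samples = 4, 10
new_score = torch.randn(batch, 1 + num_samples)         # score - distance, or just -distance when use_scorenn=False
labels = new_score.new_zeros(batch, dtype=torch.long)   # assumed: ground truth sits at index 0
ranking_loss = torch.nn.functional.cross_entropy(new_score, labels, reduction="none")  # (batch,)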
@@ -0,0 +1,50 @@
name: cal500_ranking_discrete_on_tasknn
description: "Train tasknn using cross-entropy and ranking loss. The score-nn will be trained using NCE with a - sign (score-ln Pn). The samples are taken as discrete samples from the tasknn output."
program: allennlp
command:
- ${program}
- train_with_wandb
- model_configs/multilabel_classification/v2.5/gendata_ranking_discrete_on_tasknn.jsonnet
- --include-package=structured_prediction_baselines
- --wandb_tags="task=mlc,model=dvn,sampler=inference_net_continuous_samples,dataset=cal500,inference_module=inference_net,inference_module=tasknn,sampler=tasknn_continuous_samples"
- ${args}
- --file-friendly-logging
method: bayes
metric:
goal: maximize
name: "validation/best_fixed_f1"

early_terminate:
type: hyperband
min_iter: 20

parameters:
env.dataset_name:
value: 'cal500'
env.cross_entropy_loss_weight:
value: 1.0
env.dvn_score_loss_weight:
distribution: log_uniform
min: -6.9
max: 2.3
env.ff_dropout_10x:
value: 2 # bibtex
env.ff_hidden:
value: 500 # bibtex
env.ff_linear_layers:
value: 5 # bibtex
env.ff_weight_decay:
value: 0.00001
env.global_score_hidden_dim:
value: 400 # bibtex
trainer.optimizer.optimizers.task_nn.lr:
distribution: log_uniform
min: -12.5
max: -4.5
env.task_nn_steps: # instead of trainer.num_steps.task_nn
value: 1
env.num_samples:
distribution: q_uniform
q: 20 # with min 80 / max 100 below, this samples 80 or 100
min: 80
max: 100
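
As a reading aid for the bounds above: wandb's log_uniform distribution samples the natural-log exponent uniformly, so the searched values fall roughly between exp(min) and exp(max). A quick sanity check (not part of the sweep itself):

import math

# env.dvn_score_loss_weight: log_uniform(-6.9, 2.3)
print(math.exp(-6.9), math.exp(2.3))     # ~0.001 .. ~10
# trainer.optimizer.optimizers.task_nn.lr: log_uniform(-12.5, -4.5)
print(math.exp(-12.5), math.exp(-4.5))   # ~3.7e-6 .. ~1.1e-2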