Skip to content

Commit

Permalink
Merge pull request #73 from decile-team/selcon_sahasra
Browse files Browse the repository at this point in the history
Selcon sahasra
  • Loading branch information
sahasrarjn authored Mar 22, 2022
2 parents 41b449e + 9363ff8 commit db43c30
Show file tree
Hide file tree
Showing 24 changed files with 2,271 additions and 76 deletions.
3 changes: 2 additions & 1 deletion .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ __pycache__
/data/
/log/
/RayLogs/
.vscode
*.tar.gz
.vscode
2 changes: 1 addition & 1 deletion configs/SL/config_glister-warm_cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@
selection_type='Supervised',
greedy='Stochastic'),

train_args=dict(num_epochs=300,
train_args=dict(num_epochs=100,
device="cuda",
print_every=10,
results_dir='results/',
Expand Down
50 changes: 50 additions & 0 deletions configs/SL/config_selcon_lawschool.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
# Learning setting
config = dict(setting="SL",

dataset=dict(name="LawSchool_selcon",
datadir="../data",
feature="dss",
type="pre-defined"),

dataloader=dict(shuffle=True,
batch_size=100,
pin_memory=False),

model=dict(architecture='RegressionNet',
type='pre-defined',
input_dim=10,
numclasses=10),

ckpt=dict(is_load=False,
is_save=True,
dir='results/',
save_every=20),

loss=dict(type='MeanSquaredLoss',
use_sigmoid=False),

optimizer=dict(type="adam",
lr=0.01),

scheduler=dict(type="StepLR", # added this new scheduler type
step_size=1,
gamma=0.1),

dss_args=dict(type="SELCON",
fraction=0.01,
select_every=35,
kappa=0,
delta=0.04,
linear_layer=False,
lam=1e-5,
batch_sampler='sequential',
selection_type='Supervised'),

train_args=dict(num_epochs=200,
device_selcon="cuda",
print_every=1,
results_dir='results/',
print_args=["val_loss", "tst_loss", "trn_loss", "time"],
return_args=[]
)
)
3 changes: 2 additions & 1 deletion cords/selectionstrategies/SL/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,4 +5,5 @@
from .submodularselectionstrategy import SubmodularSelectionStrategy
from .gradmatchstrategy import GradMatchStrategy
from .fixedweightstrategy import FixedWeightStrategy
from .adapweightsstrategy import AdapWeightsStrategy
from .selconstrategy import SELCONstrategy
from .adapweightsstrategy import AdapWeightsStrategy
2 changes: 2 additions & 0 deletions cords/selectionstrategies/SL/glisterstrategy.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ def _update_grads_val(self, grads_curr=None, first_init=False):
if first_init:
for batch_idx, (inputs, targets) in enumerate(valloader):
inputs, targets = inputs.to(self.device), targets.to(self.device, non_blocking=True)
print(inputs.shape, targets.shape)
exit(0)
if batch_idx == 0:
out, l1 = self.model(inputs, last=True, freeze=True)
loss = self.loss(out, targets).sum()
Expand Down
Loading

0 comments on commit db43c30

Please sign in to comment.