Skip to content

Commit

Permalink
fix darts search process (PaddlePaddle#353)
Browse files: browse the repository at this point in the history
  • Loading branch information
baiyfbupt authored Jun 15, 2020
1 parent a075e69 commit e6cffba
Showing 1 changed file with 36 additions and 23 deletions.
59 changes: 36 additions & 23 deletions paddleslim/nas/darts/train_search.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -100,17 +100,32 @@ def __init__(self,
def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
epoch):
objs = AvgrageMeter()
ce_losses = AvgrageMeter()
kd_losses = AvgrageMeter()
e_losses = AvgrageMeter()
top1 = AvgrageMeter()
top5 = AvgrageMeter()
self.model.train()

step_id = 0
for train_data, valid_data in zip(train_loader(), valid_loader()):
for step_id, (
train_data,
valid_data) in enumerate(zip(train_loader(), valid_loader())):
train_image, train_label = train_data
valid_image, valid_label = valid_data
train_image = to_variable(train_image)
train_label = to_variable(train_label)
train_label.stop_gradient = True
valid_image = to_variable(valid_image)
valid_label = to_variable(valid_label)
valid_label.stop_gradient = True
n = train_image.shape[0]

if epoch >= self.epochs_no_archopt:
architect.step(train_data, valid_data)
architect.step(train_image, train_label, valid_image,
valid_label)

loss, ce_loss, kd_loss, e_loss = self.model.loss(train_data)
logits = self.model(train_image)
prec1 = fluid.layers.accuracy(input=logits, label=train_label, k=1)
prec5 = fluid.layers.accuracy(input=logits, label=train_label, k=5)
loss = fluid.layers.reduce_mean(
fluid.layers.softmax_with_cross_entropy(logits, train_label))

if self.use_data_parallel:
loss = self.model.scale_loss(loss)
Expand All @@ -122,30 +137,26 @@ def train_one_epoch(self, train_loader, valid_loader, architect, optimizer,
optimizer.minimize(loss)
self.model.clear_gradients()

batch_size = train_data[0].shape[0]
objs.update(loss.numpy(), batch_size)
ce_losses.update(ce_loss.numpy(), batch_size)
kd_losses.update(kd_loss.numpy(), batch_size)
e_losses.update(e_loss.numpy(), batch_size)
objs.update(loss.numpy(), n)
top1.update(prec1.numpy(), n)
top5.update(prec5.numpy(), n)

if step_id % self.log_freq == 0:
#logger.info("Train Epoch {}, Step {}, loss {:.6f}; ce: {:.6f}; kd: {:.6f}; e: {:.6f}".format(
# epoch, step_id, objs.avg[0], ce_losses.avg[0], kd_losses.avg[0], e_losses.avg[0]))
logger.info(
"Train Epoch {}, Step {}, loss {}; ce: {}; kd: {}; e: {}".
format(epoch, step_id,
loss.numpy(),
ce_loss.numpy(), kd_loss.numpy(), e_loss.numpy()))
step_id += 1
return objs.avg[0]
"Train Epoch {}, Step {}, loss {:.6f}, acc_1 {:.6f}, acc_5 {:.6f}".
format(epoch, step_id, objs.avg[0], top1.avg[0], top5.avg[
0]))
return top1.avg[0]

def valid_one_epoch(self, valid_loader, epoch):
objs = AvgrageMeter()
top1 = AvgrageMeter()
top5 = AvgrageMeter()
self.model.eval()

for step_id, valid_data in enumerate(valid_loader):
for step_id, (image, label) in enumerate(valid_loader):
image = to_variable(image)
label = to_variable(label)
n = image.shape[0]
Expand Down Expand Up @@ -235,12 +246,14 @@ def train(self):
genotype = get_genotype(base_model)
logger.info('genotype = %s', genotype)

self.train_one_epoch(train_loader, valid_loader, architect,
optimizer, epoch)
train_top1 = self.train_one_epoch(train_loader, valid_loader,
architect, optimizer, epoch)
logger.info("Epoch {}, train_acc {:.6f}".format(epoch, train_top1))

if epoch == self.num_epochs - 1:
# valid_top1 = self.valid_one_epoch(valid_loader, epoch)
logger.info("Epoch {}, valid_acc {:.6f}".format(epoch, 1))
valid_top1 = self.valid_one_epoch(valid_loader, epoch)
logger.info("Epoch {}, valid_acc {:.6f}".format(epoch,
valid_top1))
if save_parameters:
fluid.save_dygraph(
self.model.state_dict(),
Expand Down

0 comments on commit e6cffba

Please sign in to comment.