-
Notifications
You must be signed in to change notification settings - Fork 47
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[OSPP] Combine GAN and Self-taught Learning to solve small sample problem #40
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Haven't seen how EntGAN is intergrated into Ianvs. If there are changes in sedna lib codes, please place sedna lib in Ianvs example directory.
- There are so many model files uploaded into github. I suggest that all these files be linked in readme and configured in yaml by setting parameters.
- Some files named ".DS_Store" should be removed.
@@ -0,0 +1,33 @@ | |||
# Integrate GAN and Self-taught Learning into ianvs Lifelong Learning to Handle Unknown Tasks |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
.DS_Store file should be removed.
@@ -0,0 +1,33 @@ | |||
# Integrate GAN and Self-taught Learning into ianvs Lifelong Learning to Handle Unknown Tasks | |||
|
|||
## Motivation |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Proposal md file should be removd in this pr.
- batch_size: 3 | ||
- lr: 1.0e-4 | ||
- name: "1" | ||
- cityscapes_data_path: "/home/nailtu/data/cityscapes" |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Absolute path should be revised as relative path.
self.var_p1 = Variable(self.input_p1,requires_grad=True) | ||
|
||
def forward_train(self): # run forward pass | ||
# print(self.net.module.scaling_layer.shift) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Annotated codes should be removed.
os.makedirs(saved_model_folder, exist_ok=True) | ||
os.makedirs(saved_image_folder, exist_ok=True) | ||
|
||
# for f in os.listdir('./'): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove annotated codes.
|
||
from skimage import io | ||
|
||
# print(os.getcwd()) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
For cloud-native service, use log to output instead of print. print is not feasible in service.
optimizerG = optim.Adam(netG.parameters(), lr=nlr, betas=(nbeta1, 0.999)) | ||
optimizerD = optim.Adam(netD.parameters(), lr=nlr, betas=(nbeta1, 0.999)) | ||
|
||
# if checkpoint != 'None': |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove annotated codes.
|
||
|
||
if __name__ == "__main__": | ||
# parser = argparse.ArgumentParser(description='region gan') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove annotated codes.
@@ -0,0 +1,18 @@ | |||
epoch,d_loss,g_loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
What are these csv data used for? Remove the csv file if it is no use.
label_img = cv2.imread(label_img_path, -1) | ||
label_img = cv2.resize(label_img, (self.new_img_w, self.new_img_h), | ||
interpolation=cv2.INTER_NEAREST) | ||
# flip = np.random.randint(low=0, high=2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove annotated codes.
epoch_losses_train = [] | ||
epoch_losses_val = [] | ||
for epoch in range(num_epochs): | ||
print ("###########################") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Use log instead of print
|
||
class PerceptualLoss(torch.nn.Module): | ||
def __init__(self, model='net-lin', net='alex', colorspace='rgb', spatial=False, use_gpu=True, gpu_ids=[0]): # VGG using our perceptually-learned weights (LPIPS metric) | ||
# def __init__(self, model='net', net='vgg', use_gpu=True): # "default" way of using VGG as a perceptual loss |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
if this comment of code is still useful, please write a more detailed comment
if not useful, please remove it
self.rankLoss = self.rankLoss.to(device=gpu_ids[0]) # just put this on GPU0 | ||
|
||
if(printNet): | ||
print('---------- Networks initialized -------------') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
print -> log
for param_group in self.optimizer_net.param_groups: | ||
param_group['lr'] = lr | ||
|
||
print('update lr [%s] decay: %f -> %f' % (type,self.old_lr, lr)) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
print -> log
netG = Generator(ngf=ngf, nz=nz, im_size=im_size).to(device) | ||
weights_init(netG) | ||
weights = torch.load(os.getcwd() + '/train_results/test1/models/50000.pth') | ||
# print(weights['g']) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove the annotation.
feat_last = self.se_8_64(feat_8, feat_last) | ||
|
||
# rf_0 = torch.cat([self.rf_big_1(feat_last).view(-1),self.rf_big_2(feat_last).view(-1)]) | ||
# rff_big = torch.sigmoid(self.rf_factor_big) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
not useful code should be deleted instead of using comment
rf_0 = self.rf_big(feat_last).view(-1) | ||
|
||
feat_small = self.down_from_small(imgs[1]) | ||
# rf_1 = torch.cat([self.rf_small_1(feat_small).view(-1),self.rf_small_2(feat_small).view(-1)]) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete abandoned code
# shutil.copy(f, task_name+'/'+f) | ||
|
||
# with open( os.path.join(saved_model_folder, '../args.txt'), 'w') as f: | ||
# json.dump(args.__dict__, f, indent=2) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete abandoned code
# optimizerG.load_state_dict(ckpt['opt_g']) | ||
# optimizerD.load_state_dict(ckpt['opt_d']) | ||
# current_iteration = int(checkpoint.split('_')[-1].split('.')[0]) | ||
# del ckpt |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete abandoned code
# parser.add_argument('--start_iter', type=int, default=0, help='the iteration to start training') | ||
# parser.add_argument('--batch_size', type=int, default=8, help='mini batch number of images') | ||
# parser.add_argument('--im_size', type=int, default=1024, help='image resolution') | ||
# parser.add_argument('--ckpt', type=str, default='None', help='checkpoint weight path if have one') |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete abandoned code
if it is useful, more detailed comment is needed
@@ -0,0 +1,76 @@ | |||
# Differentiable Augmentation for Data-Efficient GAN Training | |||
# Shengyu Zhao, Zhijian Liu, Ji Lin, Jun-Yan Zhu, and Song Han | |||
# https://arxiv.org/pdf/2006.10738 |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
If this is an copy of thirdparty source code, please add the original link of repository and make sure the open source license is permit to use this code copy.
# end_y = start_y + 256 | ||
# | ||
# img = img[start_y:end_y, start_x:end_x] | ||
# label_img = label_img[start_y:end_y, start_x:end_x] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete abandoned code
plt.savefig("%s/epoch_losses_train.png" % network.model_dir) | ||
plt.close(1) | ||
|
||
print ("####") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
print -> log
epoch_losses_val.append(epoch_loss) | ||
with open("%s/epoch_losses_val.pkl" % network.model_dir, "wb") as file: | ||
pickle.dump(epoch_losses_val, file) | ||
print ("val loss: %g" % epoch_loss) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
print -> log
def load_yaml(path): | ||
with open(path) as f: | ||
data = yaml.load(f, Loader=yaml.FullLoader) | ||
# print(data) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
delete abandoned code
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
- Change the target branch to lifelong-feature-n.
- Integrate the codes in GANwithSelf-taughtLearning into folder curb-detection/lifelong_learning_bench. Configure algorithm modules of unseen_task_processing in rfnet_algorithm.yaml.
- Haven't seen how this project is integrated into Ianvs. For example, the Ianvs yaml files are not configured yet.
- All the model files such as pth files should not be uploaded to github but linked in readme for users to download. Then they can be configured in config.yaml or other configuration file.
- All the outputs using "print" change to the way of logging.
3022d8c
to
f87ed4a
Compare
Signed-off-by: nailtu <nail.tu@outlook.com>
/lgtm |
/approve |
[OSPP] Combine GAN and Self-taught Learning to solve small sample problem
This PR is related with #34.
This PR includes a proposal and pending codes.
Total work is excepted to be done for a few days.