Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Single path tan loading #6

Open
wants to merge 7 commits into
base: master
Choose a base branch
from
Open

Conversation

ajratner
Copy link
Contributor

@ajratner ajratner commented Dec 6, 2017

Still needs to be tested

@ajratner ajratner requested a review from henryre December 6, 2017 02:39
Copy link
Member

@henryre henryre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely needs to be tested as well. For next rev, can use the MNIST small example script

keras/utils.py Outdated

# Load TFs
# Assume they are present in config['train_module'] as list called tfs
tfs = import_module(config['train_module']).tfs
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This seems like way too strong of an assumption. Any other ideas on how to load in TFs?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed in latest commit- we just pickle them using the cloud lib (a good logging / reproducibility step regardless), then just load them in; much simpler. And all handled in the train_scripts.py file

keras/utils.py Outdated
with open(config_path, 'r') as f:
def load_pretrained_tan(path):
# Load config dictionary from run log
with open(os.path.join(path, 'run_log.json'), 'r') as f:
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is this where this file is always located? I think it's usually in logs, just like this for the pretrained dir

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

@@ -36,7 +37,7 @@ class TANDAImageDataGenerator(ImageDataGenerator):
"""

def __init__(self,
tan,
tan_path,
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Definitely need to allow user to pass in TAN object. Can use isinstance(tan, str)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

@@ -45,6 +47,12 @@
if FLAGS.n_folds > 0:
X_train, Y_train = select_fold(X_train, Y_train)

# Make sure dims and current module name is included in the run log
# Note: this is currently kind of hackey, should clean up...
FLAGS.__flags['train_module'] = re.sub(r'\/', '.',
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Breaks if any experiment subdir contains "tanda" right?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also shouldnt this be somewhere in train_scripts.py?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Fixed

@ajratner
Copy link
Contributor Author

ajratner commented Dec 7, 2017

@henryre ready for re-review! Can be tested when training new TANs

@henryre
Copy link
Member

henryre commented Dec 7, 2017

Ok will look soon!

Copy link
Member

@henryre henryre left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's run a basic test here to make sure these changes work

@@ -3,7 +3,10 @@
from __future__ import print_function
from __future__ import unicode_literals

from dataset import load_cifar10_data
import sys
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove these

import sys
import re

from .dataset import load_cifar10_data
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Revert to from dataset

@@ -8,6 +8,8 @@
import re
import tensorflow as tf
import tensorflow.contrib.slim as slim
import sys
import cloud
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Needs to be added to package requirement

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also why this instead of e.g. dill?

keras/utils.py Outdated
@@ -4,6 +4,7 @@
from __future__ import unicode_literals

import json
import pickle
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: from six import cPickle

@@ -0,0 +1,61 @@
{
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove

@@ -8,6 +8,8 @@
import re
import tensorflow as tf
import tensorflow.contrib.slim as slim
import sys
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Remove

@ajratner
Copy link
Contributor Author

@henryre Addressed all inline comments

@@ -3,6 +3,8 @@
from __future__ import print_function
from __future__ import unicode_literals

import re
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Still need to remove

keras/utils.py Outdated
@@ -4,6 +4,7 @@
from __future__ import unicode_literals

import json
from six import pickle
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's cPickle not pickle

@@ -6,3 +6,4 @@ scikit-image>=0.13
scipy>=0.18
six
tensorflow>=1.2
cloud
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Alpha order

@ajratner
Copy link
Contributor Author

@henryre changes made

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants