-
Notifications
You must be signed in to change notification settings - Fork 29
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Single path tan loading #6
base: master
Are you sure you want to change the base?
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely needs to be tested as well. For next rev, can use the MNIST small example script
keras/utils.py
Outdated
|
||
# Load TFs | ||
# Assume they are present in config['train_module'] as list called tfs | ||
tfs = import_module(config['train_module']).tfs |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This seems like way too strong of an assumption. Any other ideas on how to load in TFs?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed in latest commit- we just pickle them using the cloud
lib (a good logging / reproducibility step regardless), then just load them in; much simpler. And all handled in the train_scripts.py
file
keras/utils.py
Outdated
with open(config_path, 'r') as f: | ||
def load_pretrained_tan(path): | ||
# Load config dictionary from run log | ||
with open(os.path.join(path, 'run_log.json'), 'r') as f: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Is this where this file is always located? I think it's usually in logs
, just like this for the pretrained dir
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
keras/tanda_keras.py
Outdated
@@ -36,7 +37,7 @@ class TANDAImageDataGenerator(ImageDataGenerator): | |||
""" | |||
|
|||
def __init__(self, | |||
tan, | |||
tan_path, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Definitely need to allow user to pass in TAN object. Can use isinstance(tan, str)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
experiments/cifar10/train.py
Outdated
@@ -45,6 +47,12 @@ | |||
if FLAGS.n_folds > 0: | |||
X_train, Y_train = select_fold(X_train, Y_train) | |||
|
|||
# Make sure dims and current module name is included in the run log | |||
# Note: this is currently kind of hackey, should clean up... | |||
FLAGS.__flags['train_module'] = re.sub(r'\/', '.', |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Breaks if any experiment subdir contains "tanda" right?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also shouldnt this be somewhere in train_scripts.py
?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Fixed
@henryre ready for re-review! Can be tested when training new TANs |
Ok will look soon! |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Let's run a basic test here to make sure these changes work
experiments/cifar10/train.py
Outdated
@@ -3,7 +3,10 @@ | |||
from __future__ import print_function | |||
from __future__ import unicode_literals | |||
|
|||
from dataset import load_cifar10_data | |||
import sys |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove these
experiments/cifar10/train.py
Outdated
import sys | ||
import re | ||
|
||
from .dataset import load_cifar10_data |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Revert to from dataset
experiments/train_scripts.py
Outdated
@@ -8,6 +8,8 @@ | |||
import re | |||
import tensorflow as tf | |||
import tensorflow.contrib.slim as slim | |||
import sys | |||
import cloud |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Needs to be added to package requirement
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Also why this instead of e.g. dill?
keras/utils.py
Outdated
@@ -4,6 +4,7 @@ | |||
from __future__ import unicode_literals | |||
|
|||
import json | |||
import pickle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: from six import cPickle
pretrained/cifar10/logs/run_log.json
Outdated
@@ -0,0 +1,61 @@ | |||
{ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove
experiments/train_scripts.py
Outdated
@@ -8,6 +8,8 @@ | |||
import re | |||
import tensorflow as tf | |||
import tensorflow.contrib.slim as slim | |||
import sys |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Remove
@henryre Addressed all inline comments |
experiments/cifar10/train.py
Outdated
@@ -3,6 +3,8 @@ | |||
from __future__ import print_function | |||
from __future__ import unicode_literals | |||
|
|||
import re |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Still need to remove
keras/utils.py
Outdated
@@ -4,6 +4,7 @@ | |||
from __future__ import unicode_literals | |||
|
|||
import json | |||
from six import pickle |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
It's cPickle
not pickle
python-package-requirement.txt
Outdated
@@ -6,3 +6,4 @@ scikit-image>=0.13 | |||
scipy>=0.18 | |||
six | |||
tensorflow>=1.2 | |||
cloud |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Nit: Alpha order
@henryre changes made |
Still needs to be tested