Open sourcing PixelDA code
dmrd committed Jul 21, 2017
1 parent 2a5f2a9 commit aed6922
Showing 20 changed files with 3,517 additions and 20 deletions.
84 changes: 69 additions & 15 deletions domain_adaptation/README.md
@@ -1,13 +1,23 @@
# Domain Separation Networks
## Introduction
This is the code used for two domain adaptation papers.

The `domain_separation` directory contains code for the "Domain Separation
Networks" paper by Bousmalis K., Trigeorgis G., et al. which was presented at
NIPS 2016. The paper can be found here: https://arxiv.org/abs/1608.06019.

The `pixel_domain_adaptation` directory contains the code used for the
"Unsupervised Pixel-Level Domain Adaptation with Generative Adversarial
Networks" paper by Bousmalis K., et al. (presented at CVPR 2017). The paper can
be found here: https://arxiv.org/abs/1612.05424. PixelDA aims to perform domain
adaptation by transferring the visual style of the target domain (which has few
or no labels) to a source domain (which has many labels). This is accomplished
using a Generative Adversarial Network (GAN).

## Contact
The domain separation code was open-sourced by
[Konstantinos Bousmalis](https://github.com/bousmalis) (konstantinos@google.com),
while the pixel-level domain adaptation code was open-sourced by
[David Dohan](https://github.com/dmrd) (ddohan@google.com).

## Installation
You will need to have the following installed on your machine before trying out the DSN code.
@@ -16,26 +26,70 @@ You will need to have the following installed on your machine before trying out
* Bazel: https://bazel.build/

## Important Note
We are working to open source the pose estimation dataset. For now, the MNIST to
MNIST-M dataset is available. Check back here in a few weeks or wait for a
relevant announcement from [@bousmalis](https://twitter.com/bousmalis).

## Initial setup
In order to run the MNIST to MNIST-M experiments, you will need to set the
data directory:

```
$ export DSN_DATA_DIR=/your/dir
```

Add models and models/slim to your `$PYTHONPATH` (assumes $PWD is /models):

```
$ export PYTHONPATH=$PYTHONPATH:$PWD:$PWD/slim
```

## Getting the datasets

You can fetch the MNIST data by running:

```
$ bazel run slim:download_and_convert_data -- --dataset_dir $DSN_DATA_DIR --dataset_name=mnist
```

The MNIST-M dataset is available online [here](http://bit.ly/2nrlUAJ). Once it is downloaded and extracted into your data directory, create TFRecord files by running:
```
$ bazel run domain_adaptation/datasets:download_and_convert_mnist_m -- --dataset_dir $DSN_DATA_DIR
```
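
To sanity-check the conversion, the snippet below is a minimal sketch (not part of the released code). It assumes the feature keys written by slim's `dataset_utils.image_to_tfexample` (`image/encoded`, `image/format`, `image/class/label`) and a TFRecord path under `$DSN_DATA_DIR`:

```python
import os

import tensorflow as tf

# Hypothetical path; adjust to wherever the conversion step wrote its output.
path = os.path.join(os.environ['DSN_DATA_DIR'], 'mnist_m_train.tfrecord')

for i, record in enumerate(tf.python_io.tf_record_iterator(path)):
  example = tf.train.Example.FromString(record)
  features = example.features.feature
  label = features['image/class/label'].int64_list.value[0]
  encoded = features['image/encoded'].bytes_list.value[0]
  print('record %d: label=%d, %d bytes of PNG data' % (i, label, len(encoded)))
  if i >= 4:  # Only inspect the first few records.
    break
```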



## Running PixelDA from MNIST to MNIST-M
You can run PixelDA as follows (using TensorBoard to examine the results):

```
$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_train -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m
```

And run evaluation as follows:
```
$ bazel run domain_adaptation/pixel_domain_adaptation:pixelda_eval -- --dataset_dir $DSN_DATA_DIR --source_dataset mnist --target_dataset mnist_m --target_split_name test
```

The MNIST-M results in the paper were run with the following hparams flag:
```
--hparams arch=resnet,domain_loss_weight=0.135603587834,num_training_examples=16000000,style_transfer_loss_weight=0.0113173311334,task_loss_in_g_weight=0.0100959947002,task_tower=mnist,task_tower_in_g_step=true
```
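
The training binary defines its own hyperparameter defaults; purely to illustrate the flag format, a comma-separated string like the one above can be parsed with `tf.contrib.training.HParams`. The default values below are placeholders for illustration, not the ones used in the paper:

```python
import tensorflow as tf

# Placeholder defaults; parse() only overrides keys that are already defined.
hparams = tf.contrib.training.HParams(
    arch='resnet',
    domain_loss_weight=1.0,
    num_training_examples=0,
    style_transfer_loss_weight=1.0,
    task_loss_in_g_weight=1.0,
    task_tower='mnist',
    task_tower_in_g_step=False)

hparams.parse('arch=resnet,domain_loss_weight=0.135603587834,'
              'num_training_examples=16000000,'
              'style_transfer_loss_weight=0.0113173311334,'
              'task_loss_in_g_weight=0.0100959947002,'
              'task_tower=mnist,task_tower_in_g_step=true')

print(hparams.values())
```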

### A note on terminology/language of the code

The components of the network can be grouped into two parts, corresponding to
the elements that are jointly optimized: the generator component and the
discriminator component.

The generator component takes either an image or noise vector and produces an
output image.

The discriminator component takes the generated images and the target images
and attempts to discriminate between them.
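
The actual model definitions live in `domain_adaptation/pixel_domain_adaptation`. Purely as a hypothetical illustration of this split (not the released architecture; the function names, image shapes, noise dimension, and learning rates below are made up), a generator can condition on a source image plus a noise vector, a discriminator can produce a real/fake logit, and the two groups of variables can be optimized with separate steps:

```python
import tensorflow as tf


def generator(source_images, noise):
  """Toy generator: maps a source image plus a noise vector to an output image."""
  with tf.variable_scope('generator'):
    height = source_images.shape[1].value
    width = source_images.shape[2].value
    noise_map = tf.layers.dense(noise, height * width, activation=tf.nn.relu)
    noise_map = tf.reshape(noise_map, [-1, height, width, 1])
    net = tf.concat([source_images, noise_map], axis=3)
    net = tf.layers.conv2d(net, 64, 3, padding='same', activation=tf.nn.relu)
    return tf.layers.conv2d(net, 3, 3, padding='same', activation=tf.nn.tanh)


def discriminator(images, reuse=False):
  """Toy discriminator: scores whether images look like real target images."""
  with tf.variable_scope('discriminator', reuse=reuse):
    net = tf.layers.conv2d(images, 64, 3, strides=2, padding='same',
                           activation=tf.nn.relu)
    net = tf.contrib.layers.flatten(net)
    return tf.layers.dense(net, 1)  # Unnormalized real/fake logit.


source = tf.placeholder(tf.float32, [None, 32, 32, 3])  # Labeled source images.
target = tf.placeholder(tf.float32, [None, 32, 32, 3])  # Unlabeled target images.
noise = tf.placeholder(tf.float32, [None, 10])

generated = generator(source, noise)
real_logit = discriminator(target)
fake_logit = discriminator(generated, reuse=True)

# Discriminator: push real target images towards 1 and generated images towards 0.
d_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(real_logit), logits=real_logit) +
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.zeros_like(fake_logit), logits=fake_logit))
# Generator: fool the discriminator into scoring generated images as real.
g_loss = tf.reduce_mean(
    tf.nn.sigmoid_cross_entropy_with_logits(
        labels=tf.ones_like(fake_logit), logits=fake_logit))

g_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'generator')
d_vars = tf.get_collection(tf.GraphKeys.TRAINABLE_VARIABLES, 'discriminator')
g_step = tf.train.AdamOptimizer(1e-4).minimize(g_loss, var_list=g_vars)
d_step = tf.train.AdamOptimizer(1e-4).minimize(d_loss, var_list=d_vars)
```

In the actual PixelDA setup a task classifier (the `task_tower` referenced in the hparams above) is also trained on the labeled source and adapted images; this sketch omits it for brevity.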

## Running DSN code for adapting MNIST to MNIST-M

Then you need to build the binaries with Bazel:

10 changes: 10 additions & 0 deletions domain_adaptation/datasets/BUILD
@@ -26,10 +26,20 @@ py_library(
    ],
)

py_binary(
    name = "download_and_convert_mnist_m",
    srcs = ["download_and_convert_mnist_m.py"],
    deps = [
        "//slim:dataset_utils",
    ],
)

py_binary(
    name = "mnist_m",
    srcs = ["mnist_m.py"],
    deps = [
        "//slim:dataset_utils",
    ],
)
5 changes: 3 additions & 2 deletions domain_adaptation/datasets/dataset_factory.py
@@ -1,4 +1,4 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
@@ -11,13 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""A factory-pattern class which returns image/label pairs."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

# Dependency imports
import tensorflow as tf

from slim.datasets import mnist
237 changes: 237 additions & 0 deletions domain_adaptation/datasets/download_and_convert_mnist_m.py
@@ -0,0 +1,237 @@
# Copyright 2017 Google Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

r"""Downloads and converts MNIST-M data to TFRecords of TF-Example protos.
This module downloads the MNIST-M data, uncompresses it, reads the files
that make up the MNIST-M data and creates two TFRecord datasets: one for train
and one for test. Each TFRecord dataset is comprised of a set of TF-Example
protocol buffers, each of which contain a single image and label.
The script should take about a minute to run.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import sys

# Dependency imports
import numpy as np
from six.moves import urllib
import tensorflow as tf

from slim.datasets import dataset_utils

tf.app.flags.DEFINE_string(
    'dataset_dir', None,
    'The directory where the output TFRecords and temporary files are saved.')

FLAGS = tf.app.flags.FLAGS

_IMAGE_SIZE = 32
_NUM_CHANNELS = 3

# The number of images in the training set.
_NUM_TRAIN_SAMPLES = 59001

# The number of images to be kept from the training set for the validation set.
_NUM_VALIDATION = 1000

# The number of images in the test set.
_NUM_TEST_SAMPLES = 9001

# Seed for repeatability.
_RANDOM_SEED = 0

# The names of the classes.
_CLASS_NAMES = [
    'zero',
    'one',
    'two',
    'three',
    'four',
    'five',
    'six',
    'seven',
    'eight',
    'nine',
]


class ImageReader(object):
  """Helper class that provides TensorFlow image coding utilities."""

  def __init__(self):
    # Initializes function that decodes RGB PNG data.
    self._decode_png_data = tf.placeholder(dtype=tf.string)
    self._decode_png = tf.image.decode_png(self._decode_png_data, channels=3)

  def read_image_dims(self, sess, image_data):
    image = self.decode_png(sess, image_data)
    return image.shape[0], image.shape[1]

  def decode_png(self, sess, image_data):
    image = sess.run(
        self._decode_png, feed_dict={self._decode_png_data: image_data})
    assert len(image.shape) == 3
    assert image.shape[2] == 3
    return image


def _convert_dataset(split_name, filenames, filename_to_class_id, dataset_dir):
  """Converts the given filenames to a TFRecord dataset.

  Args:
    split_name: The name of the dataset split: 'train', 'valid' or 'test'.
    filenames: A list of absolute paths to png images.
    filename_to_class_id: A dictionary from filenames (strings) to class ids
      (integers).
    dataset_dir: The directory where the converted datasets are stored.
  """
  print('Converting the {} split.'.format(split_name))
  # Train and validation splits are both in the train directory.
  if split_name in ['train', 'valid']:
    png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train')
  elif split_name == 'test':
    png_directory = os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test')

  with tf.Graph().as_default():
    image_reader = ImageReader()

    with tf.Session('') as sess:
      output_filename = _get_output_filename(dataset_dir, split_name)

      with tf.python_io.TFRecordWriter(output_filename) as tfrecord_writer:
        for filename in filenames:
          # Read the filename:
          image_data = tf.gfile.FastGFile(
              os.path.join(png_directory, filename), 'r').read()
          height, width = image_reader.read_image_dims(sess, image_data)

          class_id = filename_to_class_id[filename]
          example = dataset_utils.image_to_tfexample(image_data, 'png', height,
                                                     width, class_id)
          tfrecord_writer.write(example.SerializeToString())

  sys.stdout.write('\n')
  sys.stdout.flush()


def _extract_labels(label_filename):
  """Extracts the labels into a dict of filenames to int labels.

  Args:
    label_filename: The filename of the MNIST-M labels.

  Returns:
    A dictionary of filenames to int labels.
  """
  print('Extracting labels from: ', label_filename)
  label_file = tf.gfile.FastGFile(label_filename, 'r').readlines()
  label_lines = [line.rstrip('\n').split() for line in label_file]
  labels = {}
  for line in label_lines:
    assert len(line) == 2
    labels[line[0]] = int(line[1])
  return labels


def _get_output_filename(dataset_dir, split_name):
  """Creates the output filename.

  Args:
    dataset_dir: The directory where the temporary files are stored.
    split_name: The name of the train/test split.

  Returns:
    An absolute file path.
  """
  return '%s/mnist_m_%s.tfrecord' % (dataset_dir, split_name)


def _get_filenames(dataset_dir):
  """Returns a list of image filenames found in `dataset_dir`.

  Args:
    dataset_dir: A directory containing a set of PNG-encoded MNIST-M images.

  Returns:
    A list of image filenames, relative to `dataset_dir`.
  """
  photo_filenames = []
  for filename in os.listdir(dataset_dir):
    photo_filenames.append(filename)
  return photo_filenames


def run(dataset_dir):
  """Runs the download and conversion operation.

  Args:
    dataset_dir: The dataset directory where the dataset is stored.
  """
  if not tf.gfile.Exists(dataset_dir):
    tf.gfile.MakeDirs(dataset_dir)

  train_filename = _get_output_filename(dataset_dir, 'train')
  testing_filename = _get_output_filename(dataset_dir, 'test')

  if tf.gfile.Exists(train_filename) and tf.gfile.Exists(testing_filename):
    print('Dataset files already exist. Exiting without re-creating them.')
    return

  # TODO(konstantinos): Add download and cleanup functionality

  train_validation_filenames = _get_filenames(
      os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train'))
  test_filenames = _get_filenames(
      os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test'))

  # Divide into train and validation:
  random.seed(_RANDOM_SEED)
  random.shuffle(train_validation_filenames)
  train_filenames = train_validation_filenames[_NUM_VALIDATION:]
  validation_filenames = train_validation_filenames[:_NUM_VALIDATION]

  train_validation_filenames_to_class_ids = _extract_labels(
      os.path.join(dataset_dir, 'mnist_m', 'mnist_m_train_labels.txt'))
  test_filenames_to_class_ids = _extract_labels(
      os.path.join(dataset_dir, 'mnist_m', 'mnist_m_test_labels.txt'))

  # Convert the train, validation, and test sets.
  _convert_dataset('train', train_filenames,
                   train_validation_filenames_to_class_ids, dataset_dir)
  _convert_dataset('valid', validation_filenames,
                   train_validation_filenames_to_class_ids, dataset_dir)
  _convert_dataset('test', test_filenames, test_filenames_to_class_ids,
                   dataset_dir)

  # Finally, write the labels file:
  labels_to_class_names = dict(zip(range(len(_CLASS_NAMES)), _CLASS_NAMES))
  dataset_utils.write_label_file(labels_to_class_names, dataset_dir)

  print('\nFinished converting the MNIST-M dataset!')


def main(_):
  assert FLAGS.dataset_dir
  run(FLAGS.dataset_dir)


if __name__ == '__main__':
  tf.app.run()