-
Notifications
You must be signed in to change notification settings - Fork 22
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #49 from captain-pool/distillation_issue_48
Knowledge Distillation. Initial Commit
- Loading branch information
Showing
10 changed files
with
142 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
# Config File for Distillation | ||
|
||
student_network: "name1" | ||
student_networks_config: | ||
name1: | ||
param1: 1 | ||
param2: "two" | ||
paramN: "N" | ||
name2: | ||
param1: 1 | ||
param2: "two" | ||
paramN: "N" | ||
name3: | ||
param1: 1 | ||
param2: "two" | ||
paramN: "N" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
from libs import models | ||
from libs.models.abstract import Registry |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
from tensorflow.python import keras | ||
class Registry(type): | ||
models = {} | ||
def __init__(cls, name, bases, attrs): | ||
if name.lower() != "models": | ||
Registry.models[cls.__name__.lower()] = cls | ||
|
||
# Abstract class for Auto Registration of Kernels | ||
class Models(keras.models.Model, metaclass=Registry): | ||
def __init__(self, *args, **kwargs): | ||
super(Models, self).__init__() | ||
self.init(*args, **kwargs) | ||
def init(self, *args, **kwargs): | ||
pass |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import os | ||
import sys | ||
# Fetching Generator from ESRGAN | ||
sys.path.insert(0, "../E2_ESRGAN") | ||
from lib.model import RRDBNet |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
import os | ||
import yaml | ||
|
||
def singleton(cls): | ||
instances = {} | ||
def getinstance(*args, **kwargs): | ||
distill_config = kwargs.get("student", "") | ||
key = cls.__name__ | ||
if distill_config: | ||
key = "%s_student" % (cls.__name__) | ||
if key not in instances: | ||
instances[key] = cls(*args, **kwargs) | ||
return instances[key] | ||
return getinstance | ||
|
||
@singleton | ||
class Settings(object): | ||
def __init__(self, filename="config.yaml", student=False): | ||
self.__path = os.path.abspath(filename) | ||
|
||
@property | ||
def path(self): | ||
return os.path.dirname(self.__path) | ||
|
||
def __getitem__(self, index): | ||
with open(self.__path, "r") as file_: | ||
return yaml.load(file_.read(), Loader=yaml.FullLoader)[index] | ||
|
||
def get(self, index, default=None): | ||
with open(self.__path, "r") as file_: | ||
return yaml.load(file_.read(), Loader=yaml.FullLoader).get(index, default) | ||
|
||
|
||
class Stats(object): | ||
def __init__(self, filename="stats.yaml"): | ||
if os.path.exists(filename): | ||
with open(filename, "r") as file_: | ||
self.__data = yaml.load(file_.read(), Loader=yaml.FullLoader) | ||
else: | ||
self.__data = {} | ||
self.file = filename | ||
|
||
def get(self, index, default=None): | ||
self.__data.get(index, default) | ||
|
||
def __getitem__(self, index): | ||
return self.__data[index] | ||
|
||
def __setitem__(self, index, data): | ||
self.__data[index] = data | ||
with open(self.file, "w") as file_: | ||
yaml.dump(self.__data, file_, default_flow_style=False) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
import tensorflow as tf |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,5 @@ | ||
import os | ||
import sys | ||
# Fetching Ra Loss from project ESRGAN | ||
sys.path.insert(0, os.path.abspath("..")) | ||
from E2_ESRGAN.lib.utils import RelativisticAverageLoss |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,45 @@ | ||
from absl import logging | ||
import argparse | ||
from libs.models import teacher | ||
from libs import model | ||
from libs import utils | ||
from libs import settings | ||
import tensorflow as tf | ||
""" | ||
Compressing GANs using Knowledge Distillation. | ||
Teacher GAN: ESRGAN (https://github.com/captain-pool/E2_ESRGAN) | ||
Citation: | ||
@article{DBLP:journals/corr/abs-1902-00159, | ||
author = {Angeline Aguinaldo and | ||
Ping{-}Yeh Chiang and | ||
Alexander Gain and | ||
Ameya Patil and | ||
Kolten Pearson and | ||
Soheil Feizi}, | ||
title = {Compressing GANs using Knowledge Distillation}, | ||
journal = {CoRR}, | ||
volume = {abs/1902.00159}, | ||
year = {2019}, | ||
url = {http://arxiv.org/abs/1902.00159}, | ||
archivePrefix = {arXiv}, | ||
eprint = {1902.00159}, | ||
timestamp = {Tue, 21 May 2019 18:03:39 +0200}, | ||
biburl = {https://dblp.org/rec/bib/journals/corr/abs-1902-00159}, | ||
bibsource = {dblp computer science bibliography, https://dblp.org} | ||
} | ||
""" | ||
def main(**kwargs): | ||
student_settings = settings.Settings("../E2_ESRGAN/config.yaml", student=True) | ||
teacher_settings = settings.Settings("config/config.yaml", student=False) | ||
|
||
if __name__ == "__main__": | ||
parser = argparse.ArgumentParser() | ||
parser.add_argument("--logdir", default=None, help="Path to log directory") | ||
parser.add_argument("--modeldir", default=None, help="directory to store checkpoints and SavedModel") | ||
parser.add_argument("--verbose", "-v", action="count", default=0, help="Increases Verbosity. Repeat to increase more") | ||
FLAGS, unparsed = parser.parse_known_args() | ||
log_levels = [logging.WARNING, logging.INFO, logging.DEBUG] | ||
log_level = log_levels[min(FLAGS.verbose, len(log_levels)-1)] | ||
logging.set_verbosity(log_level) | ||
main(**vars(FLAGS)) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1 @@ | ||
#! /bin/bash |