Skip to content

Latest commit

 

History

History
119 lines (84 loc) · 3.93 KB

README.md

File metadata and controls

119 lines (84 loc) · 3.93 KB

RCGAN

ラベル情報を考慮した時系列データをGANにより生成する。

学習

2019/01/16 修正(対象ファイルを引数で指定)

python experiment.py -inputs "inputs/sin_wave.npz"

デフォルト(引数を省略した)の場合、 inputs/sin_wave.npz のデータを入力とする。
*デフォルトを変更したい場合、experiment.pyの FILE_NAME を修正する。

入力の npz の形式は以下のような保存形式を想定

ndarr_x = np.array([0.0, 0.1], [0.2, 0.3], [0.4, 0.5], [0.6, 0.7]])  # original data
ndarr_y = np.array([0, 1, 2, 3])  # label

np.savez('test.npz', x=ndarr_x, y=ndarr_y)

rotMNIST形式の変換

SyntheticMedDataを直接、GANには読み込めないので、プログラムにより上記のような入力形式に変換する。

変換プログラムの実行方法

cd inputs
python make_rotMNIST.py

rotMNIST.zip (rotMNISTから300サンプルを抽出した圧縮ファイル)を解凍することで、動作確認できる。

2019/01/16 追加
rotMNIST.zip (rotMNISTから300サンプルを抽出した圧縮ファイル)の場合、実験時にエラーが発生する。
※rotMNISTはサイズが大きいので、少ないサンプル数しかgithub上に置けないが、その場合はサンプル数が少なすぎてエラーが発生する。
rotMNIST を生成したい場合、inputs/build_rotMNIST.pyを実行することで同様のサンプルが得られる。

実行時に必要なパラメータ

  • FILNAME :保存先
  • SAMPLES :読み込むサンプル数(サンプル数が大きくメモリに載らない場合、データセットを複数作成する。)
  • MAX_SEQ :データの系列の長さ(GANではサンプルの系列長を同じ長さに揃える必要がある。)
  • INPUT_DIM :データの入力次元数

※欠損値(NULL)があると、GANは勾配計算が正しくできないので、欠損値は事前に0埋めなどをしておく必要がある。

rotMNISTの生成結果

  • オリジナル画像

alt tag

  • RCGANにより生成された画像

alt tag

データ生成

学習したモデルを利用してサンプルデータを生成

 python generate_sample.py -n 500

-n は生成するサンプルデータの数
※保存先や学習したモデルの指定などは、generate_sample.py のパラメータを修正する。

評価(TSTR)

Train on synthetic, test on real 生成したデータで学習を行い、実データで評価を行う。

 python tstr.py

※学習に使用する入力データなどは、tstr.py のパラメータを修正する。

実験環境

keras (2.2.4)
tensorflow (1.8.0)

参考論文

参考リポジトリ

その他参考サイト