Skip to content

kento1109/RCGAN

Repository files navigation

RCGAN

ラベル情報を考慮した時系列データをGANにより生成する。

学習

2019/01/16 修正(対象ファイルを引数で指定)

python experiment.py -inputs "inputs/sin_wave.npz"

デフォルト(引数を省略した)の場合、 inputs/sin_wave.npz のデータを入力とする。
*デフォルトを変更したい場合、experiment.pyの FILE_NAME を修正する。

入力の npz の形式は以下のような保存形式を想定

ndarr_x = np.array([0.0, 0.1], [0.2, 0.3], [0.4, 0.5], [0.6, 0.7]])  # original data
ndarr_y = np.array([0, 1, 2, 3])  # label

np.savez('test.npz', x=ndarr_x, y=ndarr_y)

rotMNIST形式の変換

SyntheticMedDataを直接、GANには読み込めないので、プログラムにより上記のような入力形式に変換する。

変換プログラムの実行方法

cd inputs
python make_rotMNIST.py

rotMNIST.zip (rotMNISTから300サンプルを抽出した圧縮ファイル)を解凍することで、動作確認できる。

2019/01/16 追加
rotMNIST.zip (rotMNISTから300サンプルを抽出した圧縮ファイル)の場合、実験時にエラーが発生する。
※rotMNISTはサイズが大きいので、少ないサンプル数しかgithub上に置けないが、その場合はサンプル数が少なすぎてエラーが発生する。
rotMNIST を生成したい場合、inputs/build_rotMNIST.pyを実行することで同様のサンプルが得られる。

実行時に必要なパラメータ

  • FILNAME :保存先
  • SAMPLES :読み込むサンプル数(サンプル数が大きくメモリに載らない場合、データセットを複数作成する。)
  • MAX_SEQ :データの系列の長さ(GANではサンプルの系列長を同じ長さに揃える必要がある。)
  • INPUT_DIM :データの入力次元数

※欠損値(NULL)があると、GANは勾配計算が正しくできないので、欠損値は事前に0埋めなどをしておく必要がある。

rotMNISTの生成結果

  • オリジナル画像

alt tag

  • RCGANにより生成された画像

alt tag

データ生成

学習したモデルを利用してサンプルデータを生成

 python generate_sample.py -n 500

-n は生成するサンプルデータの数
※保存先や学習したモデルの指定などは、generate_sample.py のパラメータを修正する。

評価(TSTR)

Train on synthetic, test on real 生成したデータで学習を行い、実データで評価を行う。

 python tstr.py

※学習に使用する入力データなどは、tstr.py のパラメータを修正する。

実験環境

keras (2.2.4)
tensorflow (1.8.0)

参考論文

参考リポジトリ

その他参考サイト

About

No description, website, or topics provided.

Resources

Stars

Watchers

Forks

Releases

No releases published

Packages

No packages published