import argparse
import sys
import os

from smnist_base import NextFrameSmnist, sminst_input_func, smnist_model_fn, get_feature_columns
from smnist_params import smnist_hparams

import tensorflow as tf
import numpy as np

# Training entry script: parses command-line options, builds the hparams,
# creates a distributed tf.estimator.Estimator around smnist_model_fn and
# runs tf.estimator.train_and_evaluate.

# --- Command-line interface -------------------------------------------------
parser = argparse.ArgumentParser(description="Command to train smnist")
parser.add_argument(
    "--num_iterations",
    default=180000,
    type=int,
    help="The number of training steps",
)
parser.add_argument(
    "--model_dir",
    default="model_dir/",
    help="The director to store summaries and checkpoints",
)
parser.add_argument(
    "--hparams",
    default="",
    type=str,
    help="Hparams to override",
)

args = parser.parse_args()
num_iterations = args.num_iterations
model_dir = args.model_dir

# --- Hyper-parameters -------------------------------------------------------
# Start from the project defaults, apply the overrides given in --hparams,
# then record the model directory on the hparams object itself.
hparams = smnist_hparams()
hparams.parse(args.hparams)
hparams.add_hparam("model_dir", args.model_dir)
# Not used further in this script; kept as a module-level name.
feature_columns = get_feature_columns(hparams)

# --- Estimator --------------------------------------------------------------
tf.logging.set_verbosity(tf.logging.INFO)
strategy = tf.contrib.distribute.MirroredStrategy()
run_config = tf.estimator.RunConfig(
    model_dir=model_dir,
    train_distribute=strategy,
    save_checkpoints_steps=4000,
    save_summary_steps=200,
)
estimator = tf.estimator.Estimator(
    smnist_model_fn,
    model_dir=model_dir,
    config=run_config,
    params={"hparams": hparams},
)

# --- Train / eval specs -----------------------------------------------------
train_spec = tf.estimator.TrainSpec(
    input_fn=lambda: sminst_input_func(hparams, tf.estimator.ModeKeys.TRAIN),
    max_steps=num_iterations,
)
# The evaluation input function is built with the EVAL mode key. It was
# previously built with ModeKeys.TRAIN, which made the evaluation pass use the
# training input pipeline.
# NOTE(review): assumes sminst_input_func handles ModeKeys.EVAL -- confirm
# against its definition in smnist_base.
eval_spec = tf.estimator.EvalSpec(
    input_fn=lambda: sminst_input_func(hparams, tf.estimator.ModeKeys.EVAL),
)

tf.estimator.train_and_evaluate(estimator, train_spec, eval_spec)