#!/usr/bin/env python
# -*- coding: utf-8 -*-
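"""Tune the language model decoder weights (lm_alpha, lm_beta) with Optuna.

Each trial samples one (lm_alpha, lm_beta) pair, re-runs evaluation over
--test_files, and minimizes WER, or CER when the external scorer is
UTF-8/character based.
"""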
from __future__ import absolute_import, print_function

import sys

import absl.app
import optuna
import tensorflow.compat.v1 as tfv1

from deepspeech_training.evaluate import evaluate
from deepspeech_training.train import create_model
from deepspeech_training.util.config import Config, initialize_globals
from deepspeech_training.util.evaluate_tools import wer_cer_batch
from deepspeech_training.util.flags import create_flags, FLAGS
from deepspeech_training.util.logging import log_error
from ds_ctcdecoder import Scorer


def character_based():
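    """Detect whether the scorer at --scorer_path runs in UTF-8 (character)
    mode; if so, CER is the metric to optimize instead of WER."""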
    is_character_based = False
    if FLAGS.scorer_path:
        scorer = Scorer(FLAGS.lm_alpha, FLAGS.lm_beta, FLAGS.scorer_path, Config.alphabet)
        is_character_based = scorer.is_utf8_mode()
    return is_character_based


def objective(trial):
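    """Optuna objective: sample lm_alpha and lm_beta uniformly from
    [0, lm_alpha_max] and [0, lm_beta_max], then return the resulting
    WER (or CER for character-based scorers) over all test files."""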
    FLAGS.lm_alpha = trial.suggest_uniform('lm_alpha', 0, FLAGS.lm_alpha_max)
    FLAGS.lm_beta = trial.suggest_uniform('lm_beta', 0, FLAGS.lm_beta_max)

    is_character_based = trial.study.user_attrs['is_character_based']

    samples = []
    for step, test_file in enumerate(FLAGS.test_files.split(',')):
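        # evaluate() builds a new graph on each call; reset the default graph
        # so repeated calls across files and trials don't accumulate stale ops.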
        tfv1.reset_default_graph()

        current_samples = evaluate([test_file], create_model)
        samples += current_samples

        # Report intermediate objective value.
        wer, cer = wer_cer_batch(current_samples)
        trial.report(cer if is_character_based else wer, step)

        # Handle pruning based on the intermediate value.
        if trial.should_prune():
            raise optuna.exceptions.TrialPruned()

    wer, cer = wer_cer_batch(samples)
    return cer if is_character_based else wer


def main(_):
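    """Entry point: validate flags, run the Optuna study, and report the
    best (lm_alpha, lm_beta) pair found."""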
    initialize_globals()

    if not FLAGS.test_files:
        log_error('You need to specify what files to use for evaluation via '
                  'the --test_files flag.')
        sys.exit(1)

    is_character_based = character_based()

    study = optuna.create_study()
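    # Stash the scorer mode on the study so each trial can read it via
    # trial.study.user_attrs instead of re-loading the scorer.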
    study.set_user_attr('is_character_based', is_character_based)
    study.optimize(objective, n_jobs=1, n_trials=FLAGS.n_trials)
    print('Best params: lm_alpha={} and lm_beta={} with {}={}'.format(study.best_params['lm_alpha'],
                                                                      study.best_params['lm_beta'],
                                                                      'CER' if is_character_based else 'WER',
                                                                      study.best_value))


if __name__ == '__main__':
    create_flags()
    absl.app.run(main)
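
# Example invocation (illustrative only): the script filename, checkpoint,
# scorer, and CSV paths below are placeholders, and the flag values are
# assumptions to adjust for your setup.
#
#   python lm_optimizer.py \
#       --checkpoint_dir /path/to/checkpoints \
#       --scorer_path /path/to/kenlm.scorer \
#       --test_files /path/to/test.csv \
#       --lm_alpha_max 5 \
#       --lm_beta_max 5 \
#       --n_trials 100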