#!/usr/bin/env python3

""" This script implements a command line interface for RMSE-based MPT fitting.

| Copyright 2017 Cognitive Computation Lab
| University of Freiburg
| Nicolas Riesterer <riestern@tf.uni-freiburg.de>

"""

import argparse
import time

import numpy as np
import pandas as pd

from mpt import cli
from mpt import likelihood as cl
from mpt import optimize as co
from mpt import parser as cp


def main():
    """ Script entry point. Parses the command line arguments and fits the
    model in accordance with the preferences specified by the user.

    The model parameters are estimated by minimizing the RMSE on the data
    aggregated (summed) over all rows of the data file. For the resulting
    parameter assignment, the full log-likelihood (factorial terms included),
    AIC, and BIC are computed and reported together with the RMSE, the
    parameter values, and the aggregated data vector.

    """

    # Parse the command line arguments
    args = cli.parse_commandlineargs()

    # Load the data; every row is one dataset, and the fit operates on the
    # column-wise sum over all of them
    data_raw = pd.read_csv(args['data_filepath'], sep=args['sep'])
    n_datasets = len(data_raw)
    dat = data_raw.values.sum(axis=0)

    # Load the model and flatten the per-subtree category formulae into one
    # list (presumably in the same order as the data columns — the optimizer
    # and likelihood below consume both side by side)
    subtrees, params = cp.parse(args['model_filepath'])
    cat_formulae = [f for tree in subtrees for f in tree]
    n_params = len(params)

    # Fit the model. perf_counter() is monotonic (unlike time.time()), and the
    # elapsed time is taken right after the optimization so the reported
    # duration covers only the fitting, not the criteria computed below.
    start = time.perf_counter()
    res, _ = co.fit_classical(
        fun=co.optim_rmse,
        cat_formulae=cat_formulae,
        param_names=params,
        data=dat,
        n_optim=args['n_optim'])
    elapsed = time.perf_counter() - start

    # Compute the correct criteria (without ignoring factorials)
    assignment = dict(zip(params, res.x))
    llik = cl.log_likelihood(
        cat_formulae, assignment, dat, ignore_factorials=False)
    aic = -2 * llik + 2 * n_params
    # BIC uses the total of the aggregated data as the sample size
    bic = -2 * llik + np.log(dat.sum()) * n_params
    # The optimized objective is the RMSE, so its final value is reported
    rmse = res.fun

    # Print the result
    print('Fitting done ({:.2f}s, n_optim={})'.format(
        elapsed, args['n_optim']))
    print('  Objective: RMSE on agg. data')
    print('  Model:', args['model_filepath'])
    print('  n_params:', n_params)
    print('  Dataset:', args['data_filepath'])
    print('  n_datasets:', n_datasets)
    print()
    print('  RMSE (on agg.):', rmse)
    print('  LogLik:', llik)
    print('  AIC:', aic)
    print('  BIC:', bic)
    print()
    print('Final assignment:')
    for param in sorted(params):
        print('  {} = {}'.format(param, assignment[param]))
    print()
    print('Data Vector:')
    print(' ', dat)

if __name__ == '__main__':
    # Only run the CLI when this file is executed directly, not on import.
    main()
