#!/usr/bin/env python
# encoding: utf-8
"""
"""

import os, re, time, sqlite3
import matplotlib.pyplot as plt
#from matplotlib.font_manager import FontProperties
from scipy import stats


# Tab-separated input files: BOLD is loaded into the "bolddb" table and
# PRED into the "preddb" table (see insert_bold_data / insert_pred_data).
BOLD = './db/_bold.dat'
PRED = './db/_pred.dat'

def get_levels(c, db, column):
    """Return the distinct values found in one column of a table.

    Args:
        c: database cursor used to run the query.
        db: name of the table to read (quoted as an SQL identifier).
        column: name of the column whose distinct values are wanted.

    Returns:
        A list with one entry per distinct value, in the order the
        database returns them.
    """
    query = 'SELECT DISTINCT "%s" FROM "%s"' % (column, db)
    c.execute(query)
    return [record[0] for record in c.fetchall()]
    
    
def convert_to_trialtime(premise, premisetime):
    """Shift a premise-relative time onto the trial time axis.

    Premise 1 keeps its time unchanged; each later premise (2 to 5) is
    offset by a further 4000 per premise.

    Args:
        premise: premise number; values 1 to 5 are recognised.
        premisetime: time measured from the start of that premise.

    Returns:
        The shifted time, or the sentinel 99999 for any other premise.
    """
    if premise == 1:
        return premisetime

    # (premise number, offset added to premisetime)
    offsets = ((2, 4000), (3, 8000), (4, 12000), (5, 16000))
    for number, offset in offsets:
        if premise == number:
            return premisetime + offset

    return 99999

def convert_to_trialtime2(trialtime):
    """Snap a trial time to its 1000-wide bin label.

    Bins are labelled 0, 1000, ..., 20000.  Bin ``b`` covers the interval
    ``(b - 500, b + 500]``; everything up to and including 500 falls into
    bin 0.

    Args:
        trialtime: time on the trial axis.

    Returns:
        The bin label, or the sentinel 99999 when ``trialtime`` lies above
        20500 (or compares as neither below nor equal to any bin edge).
    """
    # Walk the upper bin edges in ascending order: 500, 1500, ..., 20500.
    for upper_edge in range(500, 20501, 1000):
        if trialtime <= upper_edge:
            return upper_edge - 500
    return 99999
    
    
def db_dump(c, filename, variables):
    """Write a header line and all rows of an iterable to a tab-separated file.

    Args:
        c: iterable of rows (for example a cursor positioned on a SELECT);
            every value of a row is converted with ``str``.
        filename: path of the file to (over)write.
        variables: iterable of column-name strings written as the header.

    Returns:
        None.  A confirmation line is printed once the file is written.
    """
    # "with" guarantees the handle is closed even if a write raises.
    with open(filename, 'w') as out:
        out.write('\t'.join(variables) + '\n')
        for row in c:
            out.write('\t'.join([str(value) for value in row]) + '\n')

    # Single-argument call form: prints the same text on Python 2 and 3.
    print('"' + filename + '"' + ' written.')
    
    
def insert_data(filename, db, c):
    """Load a tab-separated file (minus its header line) into a table.

    Every data line becomes one row.  The zero-based line index is put in
    front as the first value, followed by the line's fields.  Runs of
    consecutive tabs count as a single separator, leading tabs, double
    quotes and commas are dropped, and surrounding whitespace (other than
    tabs) is removed.

    Values are bound as SQL parameters rather than spliced into the
    statement text, so fields containing quote characters are stored
    literally and the statement does not depend on SQLite accepting
    double-quoted string literals.

    Args:
        filename: path of the tab-separated input file; line 1 is skipped.
        db: name of the target table (quoted as an SQL identifier).
        c: database cursor used for the INSERT statements.

    Raises:
        sqlite3.Error: if a line does not fit the table (for example a
            wrong number of values).
    """
    # Read everything first so the handle is closed before touching the DB.
    with open(filename, 'r') as f:
        data_lines = f.readlines()[1:]

    for lid, line in enumerate(data_lines):
        # Tabs are deliberately not stripped on the right: a trailing tab
        # still yields a final empty field.
        text = line.lstrip('\t",').strip(' \n\r\x0b\x0c')
        values = [str(lid)] + re.split(r'\t+', text)
        placeholders = ', '.join('?' * len(values))
        c.execute('INSERT INTO "%s" VALUES (%s)' % (db, placeholders), values)
    
def insert_bold_data(bc, bconn):
    """Create ``bolddb``, load the BOLD file and fill the derived time columns.

    After loading the rows from ``BOLD``, ``trialtime`` is computed from
    (premise, premisetime) and ``trialtime2`` from ``trialtime``.  Each of
    the three stages is committed separately.

    Args:
        bc: cursor on the BOLD database.
        bconn: connection the cursor belongs to; used for the commits.
    """
    schema = '''CREATE TABLE IF NOT EXISTS bolddb
        (id INTEGER PRIMARY KEY,
        roi TEXT,
        premisetime INTEGER,
        intensity REAL,
        stderr REAL,
        tasktype TEXT,
        premise INTEGER,
        trialtime INTEGER,
        trialtime2 INTEGER,
        module TEXT,
        region TEXT,
        hemisphere TEXT,
        roidef TEXT,
        contrast TEXT,
        mapping TEXT)'''
    bc.execute(schema)

    # Stage 1: raw rows from the data file.
    insert_data(BOLD, "bolddb", bc)
    bconn.commit()

    # Stage 2: trialtime, one update tuple per stored row.
    bc.execute('SELECT premise, premisetime FROM bolddb')
    trialtime_updates = [
        (convert_to_trialtime(premise, premisetime), premise, premisetime)
        for premise, premisetime in bc
    ]
    bc.executemany(
        'UPDATE bolddb SET trialtime=? WHERE premise=? AND premisetime=?',
        trialtime_updates)
    bconn.commit()

    # Stage 3: trialtime2, the binned version of trialtime.
    bc.execute('SELECT trialtime FROM bolddb')
    bin_updates = [
        (convert_to_trialtime2(value), value)
        for (value,) in bc
    ]
    bc.executemany(
        'UPDATE bolddb SET trialtime2=? WHERE trialtime=?',
        bin_updates)
    bconn.commit()
    
def insert_pred_data(pc, pconn):
    """Create ``preddb``, load the prediction file and fill ``trialtime2``.

    The rows from ``PRED`` are inserted and committed first; afterwards
    ``trialtime2`` is derived from the stored ``trialtime`` values and
    committed as well.

    Args:
        pc: cursor on the prediction database.
        pconn: connection the cursor belongs to; used for the commits.
    """
    schema = '''CREATE TABLE IF NOT EXISTS preddb
        (id INTEGER PRIMARY KEY,
        cancelled_id TEXT,
        run INT,
        trial TEXT,
        time REAL,
        goal REAL,
        imaginal REAL,
        manual REAL,
        production REAL,
        retrieval REAL,
        visual REAL,
        visuallocation REAL,
        mintime REAL,
        trialtime INTEGER,
        trialtime2 INTEGER,
        model TEXT)'''
    pc.execute(schema)

    # Stage 1: raw rows from the data file.
    insert_data(PRED, "preddb", pc)
    pconn.commit()

    # Stage 2: trialtime2, the binned version of trialtime.
    pc.execute('SELECT trialtime FROM preddb')
    bin_updates = [
        (convert_to_trialtime2(value), value)
        for (value,) in pc
    ]
    pc.executemany(
        'UPDATE preddb SET trialtime2=? WHERE trialtime=?',
        bin_updates)
    pconn.commit()


def plot_bold_pred (bc, pc):
    """Correlate observed BOLD intensities with model predictions per ROI.

    For every (region, hemisphere, roidef, contrast) tuple in the list
    below, the matching module column of ``preddb`` is averaged per
    ``trialtime2`` bin (via ``pc``) and the mean intensity and stderr per
    bin are read from ``bolddb`` (via ``bc``).  The Pearson correlation of
    the two series is printed to stdout.

    NOTE: the subplot creation, the plotting calls and the ``savefig`` call
    are wrapped in triple-quoted string literals, so they are not executed.
    A figure is created and its margins are adjusted, but nothing is drawn
    or written to disk.

    Args:
        bc: cursor on the database holding the ``bolddb`` table.
        pc: cursor on the database holding the ``preddb`` table.

    Returns:
        None.
    """
    
    ## 11.6 x 8.2 = DIN A4 (Landscape)
    #fig = plt.figure(figsize=(11.6, 8.2), dpi=150)
    
    ## 8.2 x 11.6 = DIN A4 (Portrait)
    #fig = plt.figure(figsize=(8.2, 11.6), dpi=150)
    
    fig = plt.figure(figsize=(8.2, 2.32), dpi=300)

    
    # adjust and show figure
    plt.subplots_adjust(left=.05, bottom=.275, right=.95, top=.925, wspace=.5, hspace=.5)
    
    # Subplot index; it is only advanced by the disabled plotting code below.
    pos = 1

    # Each entry is (region, hemisphere, roidef, contrast); entries prefixed
    # with ### are switched off.
    for correlation in [
            ('lipfc', 'right', 'actr','NA'),
            ###('lipfc', 'left', 'actr','NA'),        ###
            ###('pfc','right', 'fang', 'between'),    ###
            ###('pfc','left', 'fang', 'between'),     ###
            ('ppc', 'right', 'fang','between'),
            ###('ppc', 'right', 'actr','NA'),          ###
            ###('ppc', 'left', 'actr','NA'),           ###
            ('caudate', 'right', 'actr','NA'),
            ###('caudate', 'left', 'actr','NA'),       ###
            ###('caudate', 'right', 'fang','between'), ###
            ###('caudate', 'left', 'fang','between'),  ###
            ('apfc','right', 'fang', 'between'),
            ###('acc','right', 'fang', 'between'),     ###
            ###('acc','right', 'actr', 'NA'),          ###
            ###('acc','left', 'actr', 'NA'),           ###
            ###('fusiform', 'right', 'actr', 'NA'),    ###
            ###('fusiform', 'left', 'actr', 'NA'),      ###
            ('motor', 'right', 'actr', 'NA'),         ###
            ('motor', 'left', 'actr', 'NA')         ###
            ]:
        region, hemisphere, roidef, contrast = correlation
        
        # Pick the preddb column (module) that is compared with this region.
        # NOTE(review): there is no else branch, so a region outside these
        # lists would leave `module` unset on the first iteration or carry
        # over the value from the previous one.
        if region in ['acc', 'apfc']:
            module = 'goal'
        elif region in ['pfc'] and hemisphere == 'right':
            module = 'retrieval'
        elif region in ['pfc'] and hemisphere == 'left':
            module = 'retrieval'
        elif region in ['lipfc']:
            module = 'retrieval'
        elif region in ['caudate']:
            module = 'production'
        elif region in ['fusiform', 'otc']:
            module = 'visual'
        elif region in ['motor']:
            module = 'manual'
        elif region in['ppc']:
            module = 'imaginal'
        
        # Mean prediction of the chosen module per trialtime2 bin, restricted
        # to bins below 12000.
        pc.execute('''SELECT AVG("%s"), model FROM preddb
            WHERE trialtime2 < 12000
            GROUP BY trialtime2
            ORDER BY trialtime2''' % module)
        
        # Mean observed intensity and stderr per roi/trialtime2/tasktype group
        # for tasktype "R", premise <= 3 and the current ROI definition.
        bc.execute('''SELECT AVG(intensity), AVG(STDERR), trialtime2, module, region, mapping
            FROM bolddb
            WHERE tasktype="R"
                AND module="%s"
                AND region="%s"
                AND hemisphere="%s"
                AND roidef="%s"
                AND contrast="%s"
                AND premise <= 3
            GROUP BY roi, trialtime2, tasktype
            ORDER BY trialtime2''' % (module, region, hemisphere, roidef, contrast))
        
        bcrows = bc.fetchall()
        pcrows = pc.fetchall()
        
        # Skip ROI definitions for which the BOLD table has no matching rows.
        if len(bcrows) > 0:
            
            # Last selected column of the first row of each result set.
            # `mapping` is not used afterwards; `model` only appears in the
            # disabled savefig code at the end of this function.
            mapping = bcrows[0][-1]
            model = pcrows[0][-1]
            
            #########################################################
            """
            ax1 = fig.add_subplot(1, 4, pos)
            print pos,
            pos = pos + 1
            """
            #########################################################
            
            mean_intensities   = [ x[0] for x in bcrows ]
            mean_stderr        = [ x[1] for x in bcrows ]
            mean_predictions   = [ x[0] for x in pcrows ]
            
            # z-scored copies as plain lists (map(None, seq) is a Python 2
            # idiom); they are only referenced by the disabled plotting code.
            zi   = map(None, stats.zscore(mean_intensities))
            zse  = map(None, stats.zscore(mean_stderr))
            zp   = map(None, stats.zscore(mean_predictions))
            
            # Correlation is computed on the raw means, not the z-scores.
            # NOTE(review): presumably both queries yield the same number of
            # bins -- TODO confirm, pearsonr needs equally long inputs.
            r, p = stats.pearsonr(mean_intensities, mean_predictions)
            
            print module, region, hemisphere, contrast, roidef, ": ", "r=%.2f, p=%.3f" % (r, p)
            
            # Format r without its leading zero: 0.53 -> ".53", -0.53 -> "-.53".
            # NOTE(review): a rounded value of 1.0 would come out as ".0".
            if r < 0:
                rr = "-" + str(round(r, 2))[2:]
            else:
                rr = str(round(r, 2))[1:]
                
            # Text label for the result; only consumed by the disabled plt.text call.
            if p < 0.001:
                result = 'r=' + rr + ', p<.001'
            elif p < 0.01:
                result = 'r=' + rr + ', p<.01'
            elif p < 0.05:
                result = 'r=' + rr + ', p<.05'
            else:
                result = 'r=' + rr + ', p=' + str(round(p, 2))[1:]

            # x positions 0, 1000, ..., 11000 for the disabled plotting code.
            x = range(0, 12000, 1000)
            
            # Display names for the (disabled) plot title.  Rebinding the loop
            # variables here happens after the queries and the print above.
            if roidef == 'actr':
                roidef = 'ACT-R'
            elif roidef == 'fang':
                roidef = 'Fangmeier'
            
            if region == 'lipfc':
                region = 'LIPFC'
            elif region == 'ppc':
                region = 'PPC'
            elif region == 'apfc':
                region = 'APFC'
            elif region == 'caudate':
                region = 'Caudate'
            elif region == 'fusiform':
                region = 'Fusiform'



            
            if module == 'imaginal':
                module = 'Imaginal'
            elif module == 'retrieval':
                module = 'Retrieval'
            elif module == 'goal':
                module = 'Goal'
            elif module == 'production':
                module = 'Production'
            elif module == 'visual':
                module = 'Visual'
            
            # The string literal below holds the disabled plotting and
            # savefig code; as a bare expression statement it has no effect.
            #########################################################
            """
            fs = 6 # fontsize
            ax1.plot(x, zi, label= 'Human', marker='.', ls='-', color='k')
            ax1.errorbar(x, zi, yerr=zse, label= 'Human', marker='.', fmt=None, ecolor='#b3b3b3', capsize=3)
            ax1.plot(x, zp, label= 'Model', marker='.', ls='-', color='r')
            plt.title(module + '/' + region + ' ' + '(' +hemisphere + '), ' + roidef + '', fontsize=fs)
            plt.xticks(fontsize=fs, rotation=90)
            plt.yticks(fontsize=fs)
            plt.xlabel('Time (ms)', fontsize=fs)
            plt.ylabel('BOLD', fontsize=fs)
            plt.ylim(-4, 4),
            plt.xlim(0, 11000),
            plt.text(
            0.5, -0.4,
            result,
            ha='center',
            va='bottom',
            transform = ax1.transAxes,
            backgroundcolor='w',
            color='k',
            fontsize=fs+3,
            style='italic')
    #plt.gcf().text( 0.5, 0.95, model, horizontalalignment='center', fontproperties=FontProperties(size=6))
    
    #outfilename = time.strftime("%Y-%m-%d-%H%M%S")
    outfilename = time.strftime(model + "-%Y-%m-%d-%H%M%S")
    plt.savefig('./pdf/' + outfilename + '.pdf', format='pdf')
    """
    #########################################################



def main():
    """Build the databases if needed, dump both tables and run the analysis.

    The BOLD and prediction data are only loaded when the corresponding
    SQLite file did not exist before this run.  Both tables are then dumped
    to tab-separated files and ``plot_bold_pred`` is called.  Cursors and
    connections are closed even when one of the steps raises.
    """
    bold_db_path = './db/_bolddat.sqlite'
    pred_db_path = './db/_preddat.sqlite'

    # Must be checked before connecting: sqlite3.connect creates the file.
    bdbexists = os.path.exists(bold_db_path)
    pdbexists = os.path.exists(pred_db_path)

    bconn = sqlite3.connect(bold_db_path)
    try:
        pconn = sqlite3.connect(pred_db_path)
        try:
            bc = bconn.cursor()
            pc = pconn.cursor()
            try:
                if not bdbexists:
                    insert_bold_data(bc, bconn)
                if not pdbexists:
                    insert_pred_data(pc, pconn)

                bc.execute('SELECT * FROM bolddb')
                pc.execute('SELECT * FROM preddb')

                # NOTE(review): these are the same paths as BOLD and PRED,
                # so the input files are overwritten by the dump.  The BOLD
                # header lists 14 names while bolddb has 15 columns
                # (including id) -- confirm this is intended.
                db_dump(bc, './db/_bold.dat', ['ROI', 'PREMISETIME', 'INTENSITY', 'STDERR', 'TASKTYPE', 'PREMISE', 'TRIALTIME', 'TRIALTIME2', 'MODULE', 'REGION', 'HEMISPHERE', 'ROIDEF', 'CONTRAST','MAPPING'])
                db_dump(pc, './db/_pred.dat', ['ID', 'ID2', 'RUN', 'TRIAL', 'TIME', 'GOAL', 'IMAGINAL', 'MANUAL', 'PRODUCTION', 'RETRIEVAL', 'VISUAL', 'VISUALLOCATION', 'MINTIME', 'TRIALTIME', 'TRIALTIME2', 'MODEL'])

                plot_bold_pred(bc, pc)
            finally:
                # close cursors
                bc.close()
                pc.close()
        finally:
            pconn.close()
    finally:
        bconn.close()

# Run the pipeline only when this file is executed as a script, not on import.
if __name__ == '__main__':
    main()

