import os
import shutil
from glob import glob

import numpy as np
from scipy.special import logsumexp
from scipy.stats import norm, rankdata
from tqdm import tqdm

# Directory holding all your EoS-files. It may contain other files too:
# anything that cannot be read as an EoS table is reported and skipped.
old_eos_dir = '[YOUR_DIR]'

# The output directory for the ordered and renamed EoS is chosen via the
# `new_eos_dir` argument of main() (default 'sorted_eos'; a relative path is
# resolved inside old_eos_dir).

def eos_check(eos_file):
    '''Load an EoS-file and identify its radius, mass and tidal-deformability columns.

    The first three columns must hold radius [km], mass [solar masses] and the
    dimensionless tidal deformability, but may appear in any order. The order
    is repaired with the plausibility tests mass_wrong, radius_wrong and
    lambdat_wrong:

    1. The mass column is searched in column 1, then column 0, then column 2.
    2. Of the two remaining columns, the earlier one is taken as the radius.
       If it is not a plausible radius, the other one is used instead.
    3. The column left over must be a plausible tidal deformability.

    Parameters
    ----------
    eos_file: str
        Path to a whitespace-separated numeric table.

    Returns
    -------
    (radius, mass, lambdat): tuple of np.ndarray
        The three columns with rows sorted by ascending mass, if the file passes.
    str
        Otherwise a message starting with ': ' explaining why the file is omitted.
    '''
    radius_msg = (': Radii need to be provided in km, but there is no plausible '
                  'TOV-radius among the first three columns, therefore omitted.')
    lambdat_msg = (': There is no plausible dimensionless tidal deformability '
                   'at TOV-limit among the first three columns, therefore omitted.')

    try:
        # ndmin=2 keeps single-row tables two-dimensional, so after the
        # transpose every entry of `data` is exactly one column of the file.
        data = np.loadtxt(eos_file, ndmin=2).T
    except ValueError:
        return ': does not contain float-only columns of equal length, therefore omitted.'

    if len(data) < 3:
        return ': contains less than three columns, therefore omitted.'
    if data.shape[1] == 0:
        # The plausibility tests index into the columns and need at least one row.
        return ': contains no data rows, therefore omitted.'

    col0, col1, col2 = data[:3, :]

    # 1) Locate the mass column (priority: column 1, then 0, then 2).
    if not mass_wrong(col1):
        radius, mass, lambdat = col0, col1, col2
    elif not mass_wrong(col0):
        radius, mass, lambdat = col1, col0, col2
    elif not mass_wrong(col2):
        radius, mass, lambdat = col0, col2, col1
    else:
        return (': Masses need to be provided in solar masses, but there is no '
                'plausible TOV-mass among the first three columns, therefore omitted.')

    # 2) If the radius candidate is implausible, try the other remaining column.
    if radius_wrong(radius):
        if radius_wrong(lambdat):
            return radius_msg
        radius, lambdat = lambdat, radius

    # 3) The remaining column must be a plausible tidal deformability. This is
    #    checked on every path, whichever column held the mass.
    if lambdat_wrong(lambdat):
        return lambdat_msg

    order = mass.argsort()
    return (radius[order], mass[order], lambdat[order])

def mass_wrong(mass):
    '''Return True unless the largest value of ``mass`` lies within [1, 3].

    Used to decide whether a column can hold masses in solar masses: its
    maximum (the TOV-mass) must be between 1 and 3.
    '''
    heaviest = np.sort(mass)[-1]
    return bool(heaviest < 1 or heaviest > 3)

def radius_wrong(radius):
    '''Return True unless the smallest value of ``radius`` lies within [5, 20].

    Used to decide whether a column can hold radii in km.
    '''
    smallest = np.sort(radius)[0]
    return bool(smallest < 5 or smallest > 20)

def lambdat_wrong(lambdat):
    '''Return True if even the smallest value of ``lambdat`` exceeds 300.

    Used to decide whether a column can hold the dimensionless tidal
    deformability, whose minimum is expected to be at most 300.
    '''
    lowest = np.sort(lambdat)[0]
    return bool(lowest > 300)


def check_and_collect_eos(eos_file, new_eos_dir, check=True):
    '''Verify whether an EoS-file can be handled by our bilby-patches and try to fix it.

    On success the (radius, mass, lambdat) table is written tab-separated to
    ``<new_eos_dir>/mod<eos_file>``; sort_and_rename_eos later renames it.

    Parameters
    ==========
    eos_file: str
        The input-file. Its name is also used to build the output file name.
    new_eos_dir: str
        Directory into which the (possibly repaired) table is written.
    check: bool, optional
        If True (default), validate and repair the columns with eos_check.
        If False, the first three columns are read as radius, mass, lambdat
        without any validation or sorting.

    Returns
    =======
    eos_file: str
        The unchanged input file name.

    weight_param: float / None
        The parameter used to obtain EoS weights. None if the check failed.
    '''
    if check:
        check_result = eos_check(eos_file)
        # eos_check signals failure by returning an explanatory message.
        if isinstance(check_result, str):
            print(eos_file, check_result)
            return eos_file, None
        radius, mass, lambdat = check_result
    else:
        radius, mass, lambdat = np.loadtxt(eos_file, unpack=True, usecols=[0, 1, 2])

    weight_param = get_weight_param(radius, mass, lambdat)
    # Save the EoS data to a temporary 'mod'-prefixed file; it is renamed in the sorting step.
    save_p = os.path.join(new_eos_dir, 'mod' + eos_file)
    np.savetxt(save_p, np.c_[radius, mass, lambdat], delimiter='\t', fmt='%.6g')
    return eos_file, weight_param

def get_weight_param(radius, mass, lambdat):
    '''Generic function to prepare weights. Currently only uses the TOV-mass.

    Parameters
    ----------
    radius, mass, lambdat: array-like
        Columns of one EoS table; only ``mass`` is used at present.

    Returns
    -------
    float
        The largest mass in the table (the TOV-mass). np.max is used rather
        than the last entry, so the result is right even when the table is not
        sorted by mass (check_and_collect_eos does not sort with check=False).
    '''
    return np.max(mass)
    

def get_logweight(weight_params, pulsars=None):
    '''Function to ascribe weights to EoS.
    --- Feel free to edit as you need (or pass your own ``pulsars``) ---
    By default only evaluates agreement of the TOV-limit with known high-mass pulsars.

    For each pulsar with Gaussian mass measurement N(m, sigma), the log-probability
    that its mass lies below the TOV-mass, log Phi((M_TOV - m) / sigma), is
    added up.

    Parameters
    ----------
    weight_params: float or array-like of float
        Highest mass that is supported by each EoS (TOV-mass).
    pulsars: iterable of (mass, std) pairs, optional
        Pulsar mass measurements in solar masses. Defaults to
        PSR J0740+6620 (2.08 +- 0.07), PSR J0348+0432 (2.01 +- 0.04) and
        PSR J1614-2230 (1.908 +- 0.016).

    Returns
    -------
    logweight: float or np.ndarray
        Ascribed weight(s) in log scale, same shape as ``weight_params``.
    '''
    tovmass = np.asarray(weight_params, dtype=float)
    if pulsars is None:
        pulsars = (
            (2.08, 0.07),    # PSR J0740+6620
            (2.01, 0.04),    # PSR J0348+0432
            (1.908, 0.016),  # PSR J1614-2230
        )

    logweight = np.zeros(tovmass.shape)
    for psr_mass, psr_std in pulsars:
        logweight = logweight + norm.logcdf(tovmass, loc=psr_mass, scale=psr_std)
    return logweight

def sort_and_rename_eos(eos_list, weight_array, eos_dir, sort_by_param=False):
    '''Rename EoS so bilby can handle them.
    See https://enlil.gw.physik.uni-potsdam.de/dokuwiki/doku.php?id=pbilby_gen for more explanations.

    Each temporary file ``<eos_dir>/mod<name>`` is renamed to
    ``<eos_dir>/<rank>.dat``, with ranks running from 1 to len(eos_list).

    Parameters
    ----------
    eos_list: list or 1-D array
        Contains names of EoS to be sorted.

    weight_array: list or 1-D array
        Must be same size as eos_list. Contains the weights by which we must order.

    eos_dir: str
        Directory where the (temporary 'mod'-prefixed) EoS are stored.

    sort_by_param: bool or 1-D array, optional
        If a bool (default False), the EoS are ranked by their weights. Otherwise
        a 1-D array of the same size as eos_list to rank by instead, which is
        useful when weights contain ties. Since eos_weights.dat always lists the
        weights in ascending order, this parameter must order the EoS the same
        way the weights do.

    Returns
    -------
    Nothing is returned, but generates bilby-readable EoS. Moreover, it generates
    a file 'eos_weights.dat' containing the sorted weights and 'new_eos_names.txt'
    mapping old to new filenames.
    '''
    if isinstance(sort_by_param, bool):
        ranks = rankdata(weight_array, method='ordinal').astype(int)
    else:
        ranks = rankdata(sort_by_param, method='ordinal').astype(int)
    weight_array = np.sort(weight_array)
    n_eos = len(ranks)

    print('Sort and rename {} identified EoS: '.format(n_eos))
    for old_name, rank in tqdm(zip(eos_list, ranks), total=n_eos):
        # os.replace instead of a shell `mv`: names containing spaces or shell
        # metacharacters are handled safely, and failures raise instead of
        # being silently ignored.
        os.replace(os.path.join(eos_dir, 'mod{}'.format(old_name)),
                   os.path.join(eos_dir, '{}.dat'.format(rank)))

    new_names = np.char.add(ranks.astype(str), '.dat')
    # dtype=str keeps both columns textual (also for an empty eos_list).
    np.savetxt(os.path.join(eos_dir, 'new_eos_names.txt'),
               np.c_[np.asarray(eos_list, dtype=str), new_names],
               fmt='%s', delimiter='\t', header='old_filename \t new_filename')
    np.savetxt(os.path.join(eos_dir, 'eos_weights.dat'), np.c_[weight_array])



def main(old_eos_dir, new_eos_dir='sorted_eos', check=True, zip=True):
    '''Check, weight, sort and rename all EoS-files found in ``old_eos_dir``.

    Parameters
    ----------
    old_eos_dir: str
        Directory holding the input EoS-files. The working directory is
        changed to it for the rest of the run.
    new_eos_dir: str, optional
        Output directory (a relative path is resolved inside old_eos_dir);
        created if it does not exist.
    check: bool, optional
        Passed on to check_and_collect_eos; if False, files are used unvalidated.
    zip: bool, optional
        If True, additionally pack the output directory into '<new_eos_dir>.zip'.
        (The name shadows the builtin but is kept for backward compatibility.)
    '''
    os.chdir(old_eos_dir)
    # Make sure the output directory exists.
    os.makedirs(new_eos_dir, exist_ok=True)
    good_eos = []
    weight_params = []

    # Iterate through the directory.
    print('Check {} for EoS-files'.format(old_eos_dir))
    for eos_file in tqdm(glob('*')):
        # Skip everything that is not a regular file (e.g. new_eos_dir itself).
        if not os.path.isfile(eos_file):
            continue

        # Check if the file can be treated as an EoS; None marks a rejected file.
        eos_file, weight_par = check_and_collect_eos(eos_file, new_eos_dir, check)
        if weight_par is None:
            continue
        good_eos.append(eos_file)
        weight_params.append(weight_par)

    if not good_eos:
        print('No usable EoS-files found in {}, nothing to sort.'.format(old_eos_dir))
        return

    logweights = get_logweight(weight_params)
    # Normalise in log space so the weights sum to one without under/overflow.
    logweights -= logsumexp(logweights)
    weight = np.exp(logweights)

    sort_and_rename_eos(good_eos, weight, new_eos_dir, sort_by_param=weight_params)
    if zip:
        # Equivalent of `zip -r <dir>.zip <dir>`, without shelling out or
        # requiring an external `zip` binary.
        archive_dir = os.path.normpath(new_eos_dir)
        parent, base = os.path.split(archive_dir)
        shutil.make_archive(archive_dir, 'zip', root_dir=parent or None, base_dir=base)

if __name__ == '__main__':
    # Run the pipeline only when executed as a script. Importing this module
    # (e.g. to reuse eos_check) must not chdir into old_eos_dir and rename files.
    main(old_eos_dir, zip=True)




