"""Analyse bayesienne de melanges de lois"""
import sys
import os
import scipy, numpy
cwd = os.getcwd()
sys.path.insert(0, cwd+os.sep+".."+os.sep)


def z_posteriori(x, theta, prob_cond):
    """
    Calcul de la loi a posteriori des etats
    sachant les observations et les parametres d'emission
    
    :Parameters:
        `x (list)` - liste des donnees
        `theta (list)` - liste des parametres d'emission
        `prob_cond (lambda x, theta:)` - fonction de densite conditionnelle
            lambda x: prob_cond(x, theta) = p(x|theta)

    :Returns:
        Retourne un tableau de probas conditionnelles : 
        Ligne i = observations x[i]
        Colonne k = etats k            
    """
    K = len(theta)  # Nombre d'etats
    n = len(x)  # Taille d'echantillon
    post = scipy.zeros((n, K))
    # Matrice des probas conditionnelles : 
    # Lignes = observations x
    # Colonnes = etats
    for k in range(K):
        post[:, k] = prob_cond(x, theta[k])
    # Normalisation des lignes par leur somme
    vect = post.sum(1)
    post = post / vect[:, None]
    return(post)


def gibbs(nit, x, lambda0, prob_cond, x_sim, rtheta_post, hyper_param):
    """
    Calcul de la loi a posteriori des parametres d'emission
    sachant les observations
    
    :Parameters:
        `nit (int)` - nombre d'iterations
        `x (list)` - donnees (liste ou format type scipy.array)
        `lambda0 (list)` - liste des parametres d'emission initiaux
        `prob_cond (lambda x, theta:)` - fonction de densite conditionnelle
            lambda x: prob_cond(x, theta) = p(x|theta)
        `x_sim (lambda n, theta:)` - fonction de simulation d'un echantillon
            de taille n suivant la loi p(x|theta)
        `rthetat_post (lambda x, hyper_param:)` - fonction de simulation
            suivant la loi a posteriori des parametres (sachant x et les hyper-
            parametres
        `hyper_param (list)` - liste des hyper-parametres

    :Returns:
        Retourne un quintuplet  (z, xpred, tri_prop, lambdap, tri) avec:
        `z` (scipy.array) - valeurs des etats simulees
            (lignes: iterations, colonnes: individus)
        `xpred` (scipy.array) - simulations suivant la loi predictive a
            posteriori (meme format que x)
        `tri_prop` (scipy.array) - frequence des etats au cours des iterations
            (lignes: iterations, colonnes: etats)
        `lambdap` (liste) - parametres simules au cours des iterations
        `tri`  (scipy.array) - permutations des etats realisees au cours
            des iterations pour conserver l'ordre des parametres
            (lignes: iterations, colonnes: etats)

    :Remark:
        - Les etats sont numerotes de 0 a len(lambda0)-1
        - x peut etre une liste de liste dans le cas multidimensionnel, dans 
        ce cas il doit etre formatte comme un scipy.array: les individus sont
        associes aux elements de la liste englobante, les dimensions a ceux
        des listes englobees.

    :SeeAlso:
        z_posteriori
        
    """
    dx = scipy.array(x)
    K = len(lambda0)  # nombre d'etats
    n = len(x)  # taille d'echantillon
    xpred = []  # tirages suivant la loi predictive
    propnaive = scipy.array([]) # 
    tri = []  # permutations des index des clusters permettant d'avoir
            # des parametres tries par ordre lexicographique (permutation 
            # des labels
    tri_prop = []  # liste des proportions estimees
    lambdac = [[]] * K  # parametres a l'iteration courante
    lambdap = []  # liste des parametres simules
    z = []  # liste des clusters simules
    sx = scipy.array([0] * K) # nombre de points dans chaque etat
    z0 = scipy.array([-1.] * n) # cluster de chaque point
    # Tirage des etats initiaux independamment suivant
    # une loi uniforme (pas de classe vide)
    while (sx.prod() == 0):
        z0 = scipy.random.choice(K,n,p=scipy.array([1.]*K) / K)
        sx = scipy.bincount(z0)
    # Gibbs sampling pour melanges
    for it in range(nit):
        if (it > 0) and (it % 500) == 0:
            print("Iteration courante : ", it+1)
        z0 = scipy.array([-1.] * n)
        zpost = z_posteriori(dx, lambda0, prob_cond)
        # On ne veut pas de syk = 0
        sx = scipy.array([0] * K) 
        while (sx.prod() == 0):
            # simulation de p(z|y,theta)
            z0 = list(map(lambda prob: scipy.random.choice(K,1,p=prob)[0], 
                          zpost))
            z0 = scipy.array(z0)
            sx = scipy.bincount(z0)
        # simulation de p(theta|x,z)        
        for k in range(K):
            lambdac[k] = rtheta_post(dx[z0==k], hyper_param)
        lambdaca = scipy.array(lambdac)
        # estimation naive des proportions du melange 
        # (pas de loi a priori de type Dirichlet ici)
        tpropnaive = sx / sum(sx)
        # Tri de lambda0 par ordre lexicographique
        # On triera a la fin seulement
        s = scipy.lexsort(lambdaca[:,::-1].transpose())
        if it == 0:
            tri = s
            tri_prop = list(tpropnaive)
        else:
            tri = scipy.vstack((tri, s))
            tri_prop =  scipy.vstack((tri_prop, tpropnaive))
        # simulation de xpred (loi predictive):
        # simulation du z de xpred
        zpred = scipy.random.choice(K, 1, p=tpropnaive)
        assert(len(zpred)==1)
        zpred = zpred[0]
        if it == 0:
            xpred = x_sim(1, lambdac[zpred])
            z = z0
        else:
            xpred = scipy.vstack((xpred, x_sim(1, lambdac[zpred])))
            z = scipy.vstack((z, z0))
        lambdap += [list(lambdac)]
    # Reordonne les etiquettes de classe (label-switching)
    for it in range(nit):
        permm = scipy.zeros((K, K), dtype=numpy.int)  # matrice de permutation des parametres
        for j in range(K):
            permm[tri[it, j], j] = 1
        perml = lambda s: tri[it, s]
        z[it, :] = list(map(perml, z[it,:]))
        # for j in range(n):
        #    z[it, j] = tri[it, z[it, j]]
        tri_prop[it] = permm.dot(tri_prop[it])
        lambdac = list(lambdap[it])
        for j in range(K):
            lambdap[it][tri[it, j]] = lambdac[j]
    return(z, xpred, tri_prop, lambdap, tri)

