# -*- coding: utf-8 -*-

import math
import numpy
import numpy.linalg
import numpy.random

from wishart import Wishart
from gaussian import Gaussian
from student_t import StudentT



class GaussianPrior:
  """The conjugate prior for the multivariate Gaussian distribution (a Normal-Wishart: a Wishart over the precision, and a Gaussian over the mean whose precision is k times the drawn precision). Maintains the 4 values (n, k, mu and the inverse of lambda) and supports various operations of interest - initialisation of the prior, Bayesian update, drawing a Gaussian and calculating the probability of a data point coming from a Gaussian drawn from the distribution. Not a particularly efficient implementation, and it has no numerical protection against large data sets."""
  def __init__(self, dims):
    """Initialises with everything zeroed out, such that a prior must be added before anything interesting is done.

    dims -- number of dimensions of the Gaussians this prior is over."""
    self.invShape = numpy.zeros((dims,dims), dtype=numpy.float32) # The inverse of lambda in the equations.
    self.shape = None # Cached inverse of invShape - invShape is considered primary.
    self.mu = numpy.zeros(dims, dtype=numpy.float32) # Mean of the Gaussian over the mean.
    self.n = 0.0 # Degrees of freedom handed to the Wishart.
    self.k = 0.0 # Accumulated weight behind mu; scales the precision of the mean.

  def addPrior(self, mean, covariance, weight = None):
    """Adds a prior to the structure, as an estimate of the mean and covariance matrix with a weight which can be interpreted as how many samples that estimate is worth. Note the use of 'add' - you can call this after adding actual samples, or repeatedly. If weight is omitted it defaults to the number of dimensions; note that safe() requires the total weight n to strictly exceed the number of dimensions before draws etc can be done.

    mean -- vector of length dims (anything numpy.asarray accepts).
    covariance -- dims x dims matrix (anything numpy.asarray accepts).
    weight -- pseudo-sample count of the estimate; a weight of zero is a no-op."""
    if weight is None: weight = float(self.mu.shape[0])
    weight = float(weight)
    if weight == 0.0: return # Adds nothing; returning early also avoids 0/0 while the structure is still empty.

    mean = numpy.asarray(mean, dtype=numpy.float32)
    covariance = numpy.asarray(covariance, dtype=numpy.float32) # Plain nested lists would otherwise break the multiply below.
    delta = mean - self.mu

    self.invShape += weight * covariance # *weight converts to a scatter matrix.
    self.invShape += ((self.k*weight)/(self.k+weight)) * numpy.outer(delta,delta) # Correction for the disagreement between the two means.
    self.shape = None # Invalidate the cached inverse.
    self.mu += (weight/(self.k+weight)) * delta
    self.n += weight
    self.k += weight

  def addSample(self, sample):
    """Updates the prior given a single sample drawn from the Gaussian being estimated.

    sample -- vector of length dims (anything numpy.asarray accepts)."""
    sample = numpy.asarray(sample, dtype=numpy.float32)
    delta = sample - self.mu

    self.invShape += (self.k/(self.k+1.0)) * numpy.outer(delta,delta)
    self.shape = None # Invalidate the cached inverse.
    self.mu += delta / (self.k+1.0)
    self.n += 1.0
    self.k += 1.0

  def addSamples(self, samples):
    """Updates the prior given multiple samples drawn from the Gaussian being estimated. Expects a data matrix ([sample,position in sample]), or an object that numpy.asarray will interpret as such. The result matches calling addSample on each row in turn, up to floating point rounding. An empty batch leaves the prior unchanged."""
    samples = numpy.asarray(samples, dtype=numpy.float32)
    num = samples.shape[0]
    if num == 0: return # Nothing to add; also avoids 0/0 when the prior is still empty.
    num = float(num)

    # Calculate the mean and scatter matrices of the batch...
    mean = samples.mean(axis=0)
    centred = samples - mean
    scatter = numpy.dot(centred.T, centred)

    # Update parameters...
    delta = mean - self.mu

    self.invShape += scatter
    self.invShape += ((self.k*num)/(self.k+num)) * numpy.outer(delta,delta) # Correction for the batch mean differing from mu.
    self.shape = None # Invalidate the cached inverse.
    self.mu += (num/(self.k+num)) * delta
    self.n += num
    self.k += num


  def getN(self):
    """Returns n, the degrees of freedom used for the Wishart over the precision."""
    return self.n

  def getK(self):
    """Returns k, the weight behind mu that scales the precision of the mean."""
    return self.k

  def getMu(self):
    """Returns mu, the mean of the Gaussian over the mean. This is the internal array, not a copy."""
    return self.mu

  def getLambda(self):
    """Returns lambda, the Wishart scale matrix, computed as the inverse of invShape and cached until the next update. Raises numpy.linalg.LinAlgError if invShape is singular (e.g. nothing has been added yet)."""
    if self.shape is None: # 'is' - comparing a cached ndarray with '==' is elementwise and breaks the if.
      self.shape = numpy.linalg.inv(self.invShape)
    return self.shape

  def getInverseLambda(self):
    """Returns the inverse of lambda (the accumulated scatter matrix). This is the internal array, not a copy."""
    return self.invShape


  def safe(self):
    """Returns true if it is possible to sample the prior, work out the probability of samples or work out the probability of samples being drawn from a collapsed sample - basically a test that there is enough information."""
    return self.n>self.mu.shape[0] and self.k>0.0


  def prob(self, gauss):
    """Returns the probability of drawing the provided Gaussian from this prior.

    gauss -- a Gaussian-like object providing getPrecision() and getMean()."""
    d = self.mu.shape[0]
    wishart = Wishart(d)
    gaussian = Gaussian(d)

    wishart.setDof(self.n)
    wishart.setScale(self.getLambda())
    gaussian.setMean(self.mu)
    gaussian.setPrecision(self.k*gauss.getPrecision()) # The mean's precision is k times the Gaussian's own precision.

    return wishart.prob(gauss.getPrecision()) * gaussian.prob(gauss.getMean())

  def intProb(self):
    """Returns a multivariate student-t distribution object that gives the probability of drawing a sample from a Gaussian drawn from this prior, with the Gaussian integrated out. You may then call the prob method of this object on each sample obtained."""
    d = self.mu.shape[0]
    st = StudentT(d)

    dof = self.n-d+1.0
    st.setDOF(dof)
    st.setLoc(self.mu)
    mult = self.k*dof / (self.k+1.0) # Scales lambda into the predictive's inverse scale.
    st.setInvScale(mult * self.getLambda())

    return st

  def sample(self):
    """Returns a Gaussian, drawn from this prior: the precision is drawn from the Wishart first, then the mean is drawn from a Gaussian around mu with k times that precision."""
    d = self.mu.shape[0]
    wishart = Wishart(d)
    gaussian = Gaussian(d)
    ret = Gaussian(d)

    wishart.setDof(self.n)
    wishart.setScale(self.getLambda())
    ret.setPrecision(wishart.sample())

    gaussian.setPrecision(self.k*ret.getPrecision())
    gaussian.setMean(self.mu)
    ret.setMean(gaussian.sample())

    return ret
