# -*- coding: utf-8 -*-

import math
import numpy
import numpy.linalg
import scipy.special



class StudentT:
  """A feature incomplete multivariate student-t distribution object - at this time just supports calculating the (log) probability of a sample.

  The scale matrix and its inverse are held lazily: whichever was set last is
  authoritative, and the other is computed on demand and cached. The log of the
  normalising constant is also cached and invalidated whenever the degrees of
  freedom or the scale change.
  """
  def __init__(self, dims):
    """dims is the number of dimensions - initialises it to default values with the degrees of freedom set to 1, the location as the zero vector and the identity matrix for the scale."""
    self.dof = 1.0
    self.loc = numpy.zeros(dims, dtype=numpy.float32)
    self.scale = numpy.identity(dims, dtype=numpy.float32)
    self.invScale = None  # computed from self.scale on first use by getInvScale()
    self.norm = None  # cached log normalising constant; None means "recompute"

  def setDOF(self, dof):
    """Sets the degrees of freedom and invalidates the cached normaliser."""
    self.dof = dof
    self.norm = None

  def setLoc(self, loc):
    """Sets the location vector; it must have the same shape as the current one."""
    l = numpy.array(loc, dtype=numpy.float32)
    assert(l.shape==self.loc.shape)
    self.loc = l

  def setScale(self, scale):
    """Sets the (dims x dims) scale matrix; the inverse is recomputed lazily."""
    s = numpy.array(scale, dtype=numpy.float32)
    assert(s.shape==(self.loc.shape[0],self.loc.shape[0]))
    self.scale = s
    self.invScale = None
    self.norm = None

  def setInvScale(self, invScale):
    """Sets the (dims x dims) inverse scale matrix; the scale is recomputed lazily."""
    i = numpy.array(invScale, dtype=numpy.float32)
    assert(i.shape==(self.loc.shape[0],self.loc.shape[0]))
    self.scale = None
    self.invScale = i
    self.norm = None

  def getDOF(self):
    """Returns the degrees of freedom."""
    return self.dof

  def getLoc(self):
    """Returns the location vector."""
    return self.loc

  def getScale(self):
    """Returns the scale matrix, computing (and caching) it from the inverse scale if needed."""
    # 'is None' rather than '== None': comparing a numpy array with None is
    # elementwise, and the truth value of the resulting array is ambiguous.
    if self.scale is None:
      self.scale = numpy.linalg.inv(self.invScale)
    return self.scale

  def getInvScale(self):
    """Returns the inverse scale matrix, computing (and caching) it from the scale if needed."""
    if self.invScale is None:
      self.invScale = numpy.linalg.inv(self.scale)
    return self.invScale


  def getLogNorm(self):
    """Returns the log of the normalising constant of the distribution. Typically for internal use only.

    Raises ValueError if the inverse scale matrix does not have a positive determinant.
    """
    if self.norm is None:
      d = self.loc.shape[0]
      # slogdet avoids the overflow/underflow that det() hits as dims grows.
      sign, logDet = numpy.linalg.slogdet(numpy.asarray(self.getInvScale(), dtype=numpy.float64))
      if sign <= 0:
        raise ValueError('inverse scale matrix must have a positive determinant')
      norm = scipy.special.gammaln(0.5*(self.dof+d))
      norm -= scipy.special.gammaln(0.5*self.dof)
      norm -= math.log(self.dof*math.pi)*(0.5*d)
      # log |Scale|^{-1/2} == 0.5 * log |Scale^{-1}|
      norm += 0.5*logDet
      self.norm = norm
    return self.norm

  def logProb(self, x):
    """Given a vector x evaluates the log of the density function at that point."""
    d = self.loc.shape[0]
    delta = x - self.loc
    # Squared Mahalanobis distance (x - loc)^T Scale^{-1} (x - loc).
    mahal = numpy.dot(delta,numpy.dot(self.getInvScale(),delta))
    return self.getLogNorm() - 0.5*(self.dof+d)*math.log1p(mahal/self.dof)

  def prob(self, x):
    """Given a vector x evaluates the density function at that point."""
    return math.exp(self.logProb(x))


  def __str__(self):
    return '{dof:%f,location:%s,scale:%s}'%(self.getDOF(),str(self.getLoc()),str(self.getScale()))
