# -*- coding: utf-8 -*-

import math
import numpy
import numpy.linalg
import numpy.random



class Gaussian:
  """A basic multivariate Gaussian class. Has extensive caching to avoid duplicate calculation.

  The distribution is stored as a mean vector plus either a precision matrix
  (inverse covariance) or a covariance matrix. The other matrix, the
  normalising constant and the Cholesky factor of the covariance are computed
  lazily on first use and cached. Setting the precision or covariance
  discards those cached values. All arrays are stored as float32.
  """
  def __init__(self, dims):
    """dims is the number of dimensions. Initialises with mu at the origin and the identity matrix for the precision/covariance."""
    self.mean = numpy.zeros(dims, dtype=numpy.float32)
    self.precision = numpy.identity(dims, dtype=numpy.float32)
    # Derived quantities; None means "not computed yet".
    self.covariance = None
    self.norm = None
    self.cholesky = None

  def _matrixShape(self):
    """Returns the (dims, dims) shape every precision/covariance matrix must have."""
    dims = self.mean.shape[0]
    return (dims, dims)

  def _clearDerived(self):
    """Drops the cached normalising constant and Cholesky factor."""
    self.norm = None
    self.cholesky = None

  def setMean(self, mean):
    """Replaces the mean; it must have the same shape as the current mean.

    None of the cached values depend on the mean, so nothing is invalidated.
    """
    newMean = numpy.array(mean, dtype=numpy.float32)
    assert(newMean.shape==self.mean.shape)
    self.mean = newMean

  def setPrecision(self, precision):
    """Replaces the precision matrix (must be dims x dims), discarding the covariance and cached values."""
    newPrecision = numpy.array(precision, dtype=numpy.float32)
    # Validate the *new* matrix: the stored precision may be None (after
    # setCovariance) and is not what is being set anyway.
    assert(newPrecision.shape==self._matrixShape())
    self.precision = newPrecision
    self.covariance = None
    self._clearDerived()

  def setCovariance(self, covariance):
    """Replaces the covariance matrix (must be dims x dims), discarding the precision and cached values."""
    newCovariance = numpy.array(covariance, dtype=numpy.float32)
    # Validate the *new* matrix: the stored covariance starts out as None.
    assert(newCovariance.shape==self._matrixShape())
    self.covariance = newCovariance
    self.precision = None
    self._clearDerived()

  def getMean(self):
    """Returns the mean vector."""
    return self.mean

  def getPrecision(self):
    """Returns the precision matrix, inverting the covariance on first use if only it is known."""
    # 'is None' is required: '== None' on a numpy array is elementwise and
    # cannot be used as a truth value when dims > 1.
    if self.precision is None:
      self.precision = numpy.linalg.inv(self.covariance)
    return self.precision

  def getCovariance(self):
    """Returns the covariance matrix, inverting the precision on first use if only it is known."""
    if self.covariance is None:
      self.covariance = numpy.linalg.inv(self.precision)
    return self.covariance

  def getNorm(self):
    """Returns the normalising constant of the distribution. Typically for internal use only.

    This is (2*pi)^(-d/2) * sqrt(det(precision)), i.e. 1/sqrt(det(2*pi*covariance)).
    """
    if self.norm is None:
      dims = self.mean.shape[0]
      self.norm = math.pow(2.0*math.pi, -0.5*dims) * math.sqrt(numpy.linalg.det(self.getPrecision()))
    return self.norm

  def prob(self, x):
    """Given a vector x evaluates the density function at that point."""
    offset = x - self.mean
    # Squared Mahalanobis distance offset^T * precision * offset.
    val = numpy.dot(offset,numpy.dot(self.getPrecision(),offset))
    return self.getNorm() * math.exp(-0.5 * val)

  def sample(self):
    """Draws and returns a sample from the distribution.

    Uses mean + L*z with z standard normal and L the (cached) lower Cholesky
    factor of the covariance, so that L*L^T equals the covariance.
    """
    if self.cholesky is None:
      self.cholesky = numpy.linalg.cholesky(self.getCovariance())
    z = numpy.random.normal(size=self.mean.shape)
    return self.mean + numpy.dot(self.cholesky,z)
