| Copyright | (c) 2016 FP Complete Corporation | 
|---|---|
| License | MIT (see LICENSE) | 
| Maintainer | dominic@steinitz.org | 
| Safe Haskell | None | 
| Language | Haskell2010 | 
Data.Random.Distribution.Static.MultivariateNormal
Description
Sample from the multivariate normal distribution with a given mean vector \(\mu\) and covariance matrix \(\Sigma\). For example, the chart below shows samples from the bivariate normal distribution. The dimension \(n\) of the mean is statically checked to be compatible with the dimensions \(n \times n\) of the covariance matrix.
Example code to generate the chart:
{-# LANGUAGE DataKinds #-}
import qualified Graphics.Rendering.Chart as C
import Graphics.Rendering.Chart.Backend.Diagrams
import Data.Random.Distribution.Static.MultivariateNormal
import qualified Data.Random as R
import Data.Random.Source.PureMT
import Control.Monad.State
import Numeric.LinearAlgebra.Static
-- | Number of samples to draw. 'multiSamples' uses it as the replication
-- count and 'diagMS' passes it to 'chartPoint' for the chart title.
nSamples :: Int
nSamples = 10000
-- | Scale parameter for the first component; used by 'singleSample' when
-- building the covariance matrix.
sigma1 :: Double
sigma1 = 3.0

-- | Scale parameter for the second component; used by 'singleSample' when
-- building the covariance matrix.
sigma2 :: Double
sigma2 = 1.0

-- | Coefficient that multiplies @sigma1 * sigma2@ in the off-diagonal
-- entries of the covariance matrix built by 'singleSample'.
rho :: Double
rho = 0.5
-- | A single draw from a zero-mean bivariate normal distribution.
--
-- 'sigma1' and 'sigma2' are treated as standard deviations and 'rho' as the
-- correlation, so the covariance matrix is
--
-- > [ sigma1 * sigma1       , rho * sigma1 * sigma2 ]
-- > [ rho * sigma1 * sigma2 , sigma2 * sigma2       ]
--
-- The diagonal entries are variances and therefore have to be squared: with
-- the off-diagonal entries written as @rho * sigma1 * sigma2@, placing the
-- bare sigmas on the diagonal would yield an effective correlation of
-- @rho * sqrt (sigma1 * sigma2)@ instead of @rho@.
singleSample :: R.RVarT (State PureMT) (R 2)
singleSample = R.sample $ Normal mu covMat
  where
    -- Mean vector of the distribution.
    mu = vector [0.0, 0.0]
    -- Covariance between the two components.
    cov12 = rho * sigma1 * sigma2
    covMat = sym $ matrix [ sigma1 * sigma1, cov12
                          , cov12, sigma2 * sigma2 ]
-- | 'nSamples' draws of 'singleSample', obtained by running the sampler in
-- the 'State' monad from a fixed 'pureMT' seed and keeping only the results.
multiSamples :: [R 2]
multiSamples = evalState drawAll seed
  where
    seed = pureMT 3
    drawAll = replicateM nSamples (R.sample singleSample)
-- | The samples in 'multiSamples' converted to @(x, y)@ pairs for plotting.
pts :: [(Double, Double)]
pts = map toPair multiSamples
  where
    -- Peel the first two components off a sample with two successive
    -- 'headTail' calls; whatever remains after the second is discarded.
    toPair v =
      let (x, rest) = headTail v
          (y, _)    = headTail rest
      in (x, y)
-- | Build a renderable scatter chart of @pointVals@.
--
-- The points are drawn in opaque red under the legend title @"Sample"@, both
-- axes are generated with @C.scaledAxis def (-3,3)@, and @n@ is shown in the
-- chart title.
chartPoint pointVals n = C.toRenderable chartLayout
  where
    caption = "Sampling Bivariate Normal (" ++ show n ++ " samples)"

    -- Point plot: each setter is applied to the default plot, rightmost first.
    scatter =
      ( (C.plot_points_values .~ pointVals)
      . (C.plot_points_style . C.point_color .~ opaque red)
      . (C.plot_points_title .~ "Sample")
      ) def

    -- Layout: the axis generators are written out separately for x and y so
    -- that the two axis value types are not forced to coincide.
    chartLayout =
      ( (C.layout_title .~ caption)
      . (C.layout_y_axis . C.laxis_generate .~ C.scaledAxis def (-3,3))
      . (C.layout_x_axis . C.laxis_generate .~ C.scaledAxis def (-3,3))
      . (C.layout_plots .~ [C.toPlot scatter])
      ) def
-- | Render the chart of 'pts' with the diagrams backend and return the first
-- component of the 'runBackend' result.
--
-- NOTE(review): the environment is created with 600 and 500 while the chart
-- is rendered at @(500, 500)@ -- confirm the size mismatch is intended.
diagMS = renderIn <$> defaultEnv C.vectorAlignmentFns 600 500
  where
    renderIn denv =
      fst (runBackend denv (C.render (chartPoint pts nSamples) (500, 500)))