{-# LANGUAGE ImportQualifiedPost #-}
{-# LANGUAGE Trustworthy #-}

module TestDistribution
  ( passed1,
    passed2,
    passed3,
  )
where

import Control.Monad (replicateM)
import Control.Monad.Bayes.Class (mvNormal)
import Control.Monad.Bayes.Sampler.Strict
import Data.Matrix (fromList)
import Data.Vector qualified as V

-- | Check that the empirical covariance of draws from 'mvNormal' is close to
-- the off-diagonal entry of the requested covariance matrix.
--
-- 200000 two-dimensional vectors are drawn with mean @[0, 0]@ and covariance
-- matrix @[2 1; 1 2]@. The covariance of the two components is estimated as
-- @mean (x * y) - mean x * mean y@. This divides by N rather than N - 1.
-- The test passes when the estimate lies within @2e-2@ of the requested
-- value @1@.
--
-- Sampling runs through 'sampleIOfixed'. Presumably this uses a fixed seed,
-- making the outcome reproducible; confirm against monad-bayes.
passed1 :: IO Bool
passed1 = sampleIOfixed $ do
  draws <- replicateM sampleCount (mvNormal centre covariance)
  let xs = map (V.! 0) draws
      ys = map (V.! 1) draws
      average zs = sum zs / fromIntegral sampleCount
      empiricalCov = average (zipWith (*) xs ys) - average xs * average ys
  pure (abs (offDiag - empiricalCov) < 2e-2)
  where
    sampleCount :: Int
    sampleCount = 200000
    diag, offDiag :: Double
    diag = 2.0
    offDiag = 1.0
    centre :: V.Vector Double
    centre = V.fromList [0.0, 0.0]
    covariance = fromList 2 2 [diag, offDiag, offDiag, diag]

-- | Check that the empirical means of draws from 'mvNormal' are close to the
-- requested mean vector.
--
-- 100000 two-dimensional vectors are drawn with mean @[0, 0]@ and covariance
-- matrix @[2 1; 1 2]@. The test passes when the average of each component
-- lies within @1e-2@ of zero.
--
-- Sampling runs through 'sampleIOfixed'. Presumably this uses a fixed seed,
-- making the outcome reproducible; confirm against monad-bayes.
passed2 :: IO Bool
passed2 = sampleIOfixed $ do
  draws <- replicateM sampleCount (mvNormal centre covariance)
  let componentMean i = sum (map (V.! i) draws) / fromIntegral sampleCount
      nearZero m = abs m < 1e-2
  pure (all nearZero [componentMean 0, componentMean 1])
  where
    sampleCount :: Int
    sampleCount = 100000
    diag, offDiag :: Double
    diag = 2.0
    offDiag = 1.0
    centre :: V.Vector Double
    centre = V.fromList [0.0, 0.0]
    covariance = fromList 2 2 [diag, offDiag, offDiag, diag]

-- | Check that the empirical variances of draws from 'mvNormal' are close to
-- the diagonal entries of the requested covariance matrix.
--
-- 200000 two-dimensional vectors are drawn with mean @[0, 0]@ and covariance
-- matrix @[2 1; 1 2]@. Each component's variance is estimated as
-- @mean (z * z) - mean z * mean z@, which divides by N. Both estimates are
-- compared against the shared diagonal value @2@.
--
-- NOTE(review): the first component must lie within @1e-2@, but the second
-- only within @2e-2@. This asymmetry is kept as-is; presumably it was tuned
-- to the fixed-seed draw. Confirm that it is intentional.
--
-- Sampling runs through 'sampleIOfixed'. Presumably this uses a fixed seed,
-- making the outcome reproducible; confirm against monad-bayes.
passed3 :: IO Bool
passed3 = sampleIOfixed $ do
  draws <- replicateM sampleCount (mvNormal centre covariance)
  let component i = map (V.! i) draws
      average zs = sum zs / fromIntegral sampleCount
      variance zs =
        let m = average zs
         in average (map (\z -> z * z) zs) - m * m
      matches tol v = abs (v - diag) < tol
  pure $
    matches 1e-2 (variance (component 0))
      && matches 2e-2 (variance (component 1))
  where
    sampleCount :: Int
    sampleCount = 200000
    diag, offDiag :: Double
    diag = 2.0
    offDiag = 1.0
    centre :: V.Vector Double
    centre = V.fromList [0.0, 0.0]
    covariance = fromList 2 2 [diag, offDiag, offDiag, diag]