module Language.Synthesis.MCMC (mhList) where
import           Control.Monad.Random            (Rand, RandomGen, getRandom,
                                                  getSplit, runRand)
import           Data.Functor                    ((<$>))
import           Language.Synthesis.Distribution (Distr)
import qualified Language.Synthesis.Distribution as Distr
-- | Run a Metropolis–Hastings chain and return the lazily produced,
-- unbounded list of visited states.
--
-- Each element is @(state, aux, logDensity)@. The last two components are
-- whatever @density@ returned for that state. The first element is always
-- the starting state.
--
-- The list is unfolded with a generator obtained from 'getSplit'. This lets
-- callers consume it lazily without threading the outer generator through
-- the whole chain.
mhList :: RandomGen g =>
          a                         -- ^ starting state of the chain
          -> (a -> (b, Double))     -- ^ auxiliary result and log density of a state
          -> (a -> Distr a)         -- ^ proposal (jump) distribution from a state
          -> Rand g [(a, b, Double)]
mhList startValue density jump = go (startValue, startAux, startDensity) <$> getSplit
  where (startAux, startDensity) = density startValue
        -- Unfold the chain: emit the current state, then step with the
        -- generator returned by the previous step.
        go orig g = let (next, g') = runRand (mhNext orig) g in orig : go next g'
        -- One MH step. Propose a state from the jump distribution and form
        -- the log acceptance ratio (target ratio plus proposal correction).
        -- Accept when the ratio is at least log u, where u comes from
        -- 'getRandom' (a Double, presumably uniform on the unit interval).
        mhNext (orig, origAux, origDensity) = do
            let forward = jump orig
            next <- Distr.sample forward
            let origToNext = Distr.logProbability forward next
                nextToOrig = Distr.logProbability (jump next) orig
                (nextAux, nextDensity) = density next
                score = (nextDensity - origDensity) + (nextToOrig - origToNext)
            acceptance <- getRandom
            return $ if score >= log acceptance
                        then (next, nextAux, nextDensity)
                        else (orig, origAux, origDensity)