{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE RankNTypes #-}
module Control.Monad.Bayes.Weighted
  ( Weighted,
    withWeight,
    weighted,
    extractWeight,
    unweighted,
    applyWeight,
    hoist,
    runWeighted,
  )
where
import Control.Monad.Bayes.Class
  ( MonadDistribution,
    MonadFactor (..),
    MonadMeasure,
    factor,
  )
import Control.Monad.State (MonadIO, MonadTrans, StateT (..), lift, mapStateT, modify, modify')
import Numeric.Log (Log)
-- | A computation in the base monad @m@ that carries an accumulated weight
-- of type @'Log' 'Double'@ as 'StateT' state.
--
-- Every instance below is obtained by newtype deriving, i.e. it is exactly
-- the corresponding instance of the wrapped @'StateT' ('Log' 'Double') m@.
newtype Weighted m a = Weighted (StateT (Log Double) m a)
  -- Core monadic structure.
  deriving newtype (Functor, Applicative, Monad)
  -- Lifting from IO and from the base monad.
  deriving newtype (MonadIO, MonadTrans)
  -- Sampling operations, passed through to @m@ via the StateT instance.
  deriving newtype (MonadDistribution)
-- | Scoring multiplies the accumulated weight by the given factor.
--
-- The state comes from "Control.Monad.State", which is the lazy 'StateT'.
-- A plain 'modify' would therefore leave an unevaluated @(* w)@ product in
-- the state for every 'score' call, and a long-running model would build up
-- a chain of thunks. 'modify'' forces the new weight at each step instead.
instance Monad m => MonadFactor (Weighted m) where
  score w = Weighted (modify' (* w))
-- | 'Weighted' over a sampling monad is a 'MonadMeasure'. The instance body
-- is empty; presumably 'MonadMeasure' only bundles its 'MonadDistribution'
-- and 'MonadFactor' superclasses (confirm against the class definition).
instance (MonadDistribution m) => MonadMeasure (Weighted m)
-- | Run a 'Weighted' computation with an initial weight of @1@, returning
-- the result together with the final accumulated weight.
weighted :: Weighted m a -> m (a, Log Double)
weighted (Weighted st) = runStateT st 1

-- | Alias of 'weighted'.
runWeighted :: Weighted m a -> m (a, Log Double)
runWeighted w = weighted w
-- | Run a 'Weighted' computation and keep only its result, discarding the
-- accumulated weight.
unweighted :: Functor m => Weighted m a -> m a
unweighted w = fst <$> weighted w
-- | Run a 'Weighted' computation and keep only the accumulated weight,
-- discarding its result.
extractWeight :: Functor m => Weighted m a -> m (Log Double)
extractWeight w = snd <$> weighted w
-- | Embed a base-monad action that reports its own weight.
--
-- The action is run in @m@, its weight is multiplied into the accumulated
-- weight, and its value is returned. As with 'score', the update uses
-- 'modify'' so that repeated use does not pile up unevaluated products in
-- the lazy 'StateT' state.
withWeight :: (Monad m) => m (a, Log Double) -> Weighted m a
withWeight m = Weighted $ do
  (x, w) <- lift m
  modify' (* w)
  return x
-- | Run a 'Weighted' computation inside a monad that supports scoring.
--
-- The weight accumulated by the computation is applied in a single
-- 'factor' call, and then the result is returned.
applyWeight :: MonadFactor m => Weighted m a -> m a
applyWeight w =
  weighted w >>= \(x, wt) ->
    factor wt >> return x
-- | Change the base monad of a 'Weighted' computation.
--
-- The natural transformation is applied, via 'mapStateT', to the underlying
-- state-passing action. The weight-threading itself is left unchanged.
hoist :: (forall x. m x -> n x) -> Weighted m a -> Weighted n a
hoist nat (Weighted st) = Weighted (mapStateT nat st)