{-# LANGUAGE CPP
           , GADTs
           , EmptyCase
           , KindSignatures
           , DataKinds
           , PolyKinds
           , TypeOperators
           , ScopedTypeVariables
           , Rank2Types
           , MultiParamTypeClasses
           , TypeSynonymInstances
           , FlexibleInstances
           , FlexibleContexts
           , UndecidableInstances
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2016.06.29
-- |
-- Module      :  Language.Hakaru.Disintegrate
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- Disintegration via lazy partial evaluation.
--
-- N.B., the forward direction of disintegration is /not/ just
-- partial evaluation! In the version discussed in the paper we
-- must also ensure that no heap-bound variables occur in the result,
-- which requires using HNFs rather than WHNFs. That condition is
-- sound, but a bit too strict; we could generalize this to handle
-- cases where there may be heap-bound variables remaining in neutral
-- terms, provided we (a) don't end up trying to go both forward
-- and backward on the same variable, and (more importantly) (b)
-- end up with the proper Jacobian. The paper version is rigged to
-- ensure that whenever we recurse into two subexpressions (e.g.,
-- the arguments to addition) one of them has a Jacobian of zero,
-- thus when going from @x*y@ to @dx*y + x*dy@ one of the terms
-- cancels out.
--
-- /Developer's Note:/ To help keep the code clean, we use the
-- worker\/wrapper transform. However, due to complexities in
-- typechecking GADTs, this often confuses GHC if you don't give
-- just the right type signature on definitions. This confusion
-- shows up whenever you get error messages about an \"ambiguous\"
-- choice of 'ABT' instance, or certain types of \"couldn't match
-- @a@ with @a1@\" error messages. To eliminate these issues we use
-- @-XScopedTypeVariables@. In particular, the @abt@ type variable
-- must be bound by the wrapper (i.e., the top-level definition),
-- and the workers should just refer to that same type variable
-- rather than quantifying over another @abt@ type. In addition,
-- whatever other type variables there are (e.g., the @xs@ and @a@
-- of an @abt xs a@ argument) should be polymorphic in the workers
-- and should /not/ reuse the other analogous type variables bound
-- by the wrapper.
--
-- /Developer's Note:/ In general, we'd like to emit weights and
-- guards \"as early as possible\"; however, determining when that
-- actually is can be tricky. If we emit them way-too-early then
-- we'll get hygiene errors because bindings for the variables they
-- use have not yet been emitted. We can fix these hygiene errors
-- by calling 'atomize', to ensure that all the necessary bindings
-- have already been emitted. But that may still emit things too
-- early, because emitting the variable-binding statements now means
-- that we can't go forward\/backward on them later on; which may
-- cause us to bot unnecessarily. One way to avoid this bot issue
-- is to emit guards\/weights later than necessary, by actually
-- pushing them onto the context (and then emitting them whenever
-- we residualize the context). This guarantees we don't emit too
-- early; but the tradeoff is we may end up generating duplicate
-- code by emitting too late. One possible (currently unimplemented)
-- solution to that code duplication issue is to let these statements
-- be emitted too late, but then have a post-processing step to
-- lift guards\/weights up as high as they can go. To avoid problems
-- about testing programs\/expressions for equality, we can use a
-- hash-consing trick so we keep track of the identity of guard\/weight
-- statements, then we can simply compare those identities and only
-- after the lifting do we replace the identity hash with the actual
-- statement.
----------------------------------------------------------------
module Language.Hakaru.Disintegrate
    ( lam_
    -- * the Hakaru API
    , disintegrateWithVar
    , disintegrate, disintegrateInCtx
    , densityWithVar
    , density, densityInCtx
    , observe, observeInCtx
    , determine
    
    -- * Implementation details
    , perform
    , atomize
    , constrainValue
    , constrainOutcome
    ) where

#if __GLASGOW_HASKELL__ < 710
import           Data.Functor         ((<$>))
import           Data.Foldable        (Foldable, foldMap)
import           Data.Traversable     (Traversable)
import           Control.Applicative  (Applicative(..))
#endif
import           Control.Applicative  (Alternative(..))
import           Control.Monad        ((<=<), guard)
import           Data.Functor.Compose (Compose(..))
import qualified Data.Traversable     as T
import           Data.List.NonEmpty   (NonEmpty(..))
import qualified Data.List.NonEmpty   as L
import qualified Data.Text            as Text
import qualified Data.IntMap          as IM
import           Data.Sequence        (Seq)
import qualified Data.Sequence        as S
import           Data.Proxy           (KProxy(..))
import           Data.Maybe           (fromMaybe, fromJust)

import Language.Hakaru.Syntax.IClasses
import Data.Number.Natural
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing
import qualified Language.Hakaru.Types.Coercion as C
import Language.Hakaru.Types.HClasses
import Language.Hakaru.Syntax.TypeOf
import Language.Hakaru.Syntax.AST
import Language.Hakaru.Syntax.Datum
import Language.Hakaru.Syntax.DatumCase (DatumEvaluator, MatchResult(..), matchBranches)
import Language.Hakaru.Syntax.ABT
import Language.Hakaru.Syntax.Transform (TransformCtx(..), minimalCtx)
import Language.Hakaru.Evaluation.Types
import Language.Hakaru.Evaluation.Lazy
import Language.Hakaru.Evaluation.DisintegrationMonad
import qualified Language.Hakaru.Syntax.Prelude as P
import qualified Language.Hakaru.Expect         as E

#ifdef __TRACE_DISINTEGRATE__
import qualified Text.PrettyPrint     as PP
import Language.Hakaru.Pretty.Haskell
import Debug.Trace                    (trace, traceM)
#endif


----------------------------------------------------------------

-- | Build a Hakaru lambda abstraction: bind the variable @x@ in
-- @body@ and wrap the resulting binding form in the 'Lam_'
-- constructor.
lam_ :: (ABT Term abt) => Variable a -> abt '[] b -> abt '[] (a ':-> b)
lam_ x body = syn $ Lam_ :$ (bind x body :* End)


-- | Disintegrate a measure over pairs with respect to the Lebesgue
-- measure on the first component. That is, for each measure kernel
-- @n <- disintegrate m@ we have that @m == bindx lebesgue n@.
--
-- The first argument is the 'TransformCtx' in which @m@ lives. The
-- second and third arguments give the hint and type of the
-- lambda-bound variable in the result. If you want to automatically
-- fill those in, then see 'disintegrateInCtx' (or 'disintegrate',
-- which additionally uses the minimal context).
--
-- The result is a list with one element per alternative produced by
-- 'runDisInCtx'.
--
-- N.B., the resulting functions from @a@ to @'HMeasure b@ are
-- indeed measurable, thus it is safe\/appropriate to use Hakaru's
-- @(':->)@ rather than Haskell's @(->)@.
--
-- BUG: Actually, disintegration is with respect to the /Borel/
-- measure on the first component of the pair! Alas, we don't really
-- have a clean way of describing this since we've no primitive
-- 'MeasureOp' for Borel measures.
--
-- /Developer's Note:/ This function fills the role that the old
-- @runDisintegrate@ did (as opposed to the old function called
-- @disintegrate@). [Once people are familiar enough with the new
-- code base and no longer recall what the old code base was doing,
-- this note should be deleted.]
disintegrateWithVar
    :: (ABT Term abt)
    => TransformCtx
    -> Text.Text
    -> Sing a
    -> abt '[] ('HMeasure (HPair a b))
    -> [abt '[] (a ':-> 'HMeasure b)]
disintegrateWithVar ctx hint typ m =
    -- NOTE(review): assumes 'nextFreeOrBind' yields an ID unused by
    -- any variable of @m@, so that @x@ is fresh with respect to @m@.
    let x = Variable hint (nextFreeOrBind m) typ
    -- Every resulting body is wrapped in a lambda binding @x@. Both @m@
    -- and @var x@ are handed to 'runDisInCtx' as the terms the run
    -- must be aware of.
    in map (lam_ x) . flip (runDisInCtx ctx) [Some2 m, Some2 (var x)] $ do
        -- Emit code for the measure and get its pair-valued outcome.
        ab <- perform m
#ifdef __TRACE_DISINTEGRATE__
        ss <- getStatements
        trace ("-- disintegrate: finished perform\n"
            ++ show (pretty_Statements ss PP.$+$ PP.sep(prettyPrec_ 11 ab))
            ++ "\n") $ return ()
#endif
        -- Split the pair into its two components.
        (a,b) <- emitUnpair ab
#ifdef __TRACE_DISINTEGRATE__
        trace ("-- disintegrate: finished emitUnpair: "
            ++ show (pretty a, pretty b)) $ return ()
#endif
        -- Constrain the first component to the lambda-bound variable.
        -- This is the "with respect to the first component" part.
        constrainValue (var x) a
#ifdef __TRACE_DISINTEGRATE__
        ss <- getStatements
        extras <- getExtras
        traceM ("-- disintegrate: finished constrainValue\n"
                ++ show (pretty_Statements ss) ++ "\n"
                ++ show (prettyExtras extras)
               )
#endif
        -- The second component becomes the kernel's result.
        return b


-- | A variant of 'disintegrateWithVar' which uses an empty name hint
-- and computes the type of the lambda-bound variable via 'typeOf'.
-- That type is the first component of the pair type underneath the
-- measure.
disintegrateInCtx
    :: (ABT Term abt)
    => TransformCtx
    -> abt '[] ('HMeasure (HPair a b))
    -> [abt '[] (a ':-> 'HMeasure b)]
disintegrateInCtx ctx m = disintegrateWithVar ctx Text.empty firstType m
    where
    -- TODO: change the exception thrown from 'typeOf' so that we know it comes from here
    firstType = fst (sUnPair (sUnMeasure (typeOf m)))

-- | A variant of 'disintegrateInCtx' which runs in 'minimalCtx'.
-- Calling this function is only really valid on top-level programs,
-- or on subprograms in which the enclosing program doesn't bind any
-- variables.
disintegrate
    :: (ABT Term abt)
    => abt '[] ('HMeasure (HPair a b))
    -> [abt '[] (a ':-> 'HMeasure b)]
disintegrate m = disintegrateInCtx minimalCtx m

-- | Return the density function for a given measure.
--
-- The first argument is the 'TransformCtx' in which the measure
-- lives. The second and third arguments give the hint and type of
-- the lambda-bound variable in the result. If you want to
-- automatically fill those in, then see 'densityInCtx' (or 'density',
-- which additionally uses the minimal context).
--
-- Each result comes from one alternative of 'observeInCtx'. The
-- measure is observed at the lambda-bound variable, and the observed
-- measure is mapped to an 'HProb' by 'E.total'.
--
-- TODO: is the resulting function guaranteed to be measurable? If
-- so, update this documentation to reflect that fact; if not, then
-- we should make it into a Haskell function instead.
densityWithVar
    :: (ABT Term abt)
    => TransformCtx
    -> Text.Text
    -> Sing a
    -> abt '[] ('HMeasure a)
    -> [abt '[] (a ':-> 'HProb)]
densityWithVar ctx hint typ m =
    -- Number the variable past every free and bound variable ID of @m@.
    let x = Variable hint (nextFree m `max` nextBind m) typ
    in (lam_ x . E.total) <$> observeInCtx ctx m (var x)


-- | A variant of 'densityWithVar' which uses an empty name hint and
-- computes the type of the lambda-bound variable via 'typeOf'. That
-- type is the type underneath the measure.
densityInCtx
    :: (ABT Term abt)
    => TransformCtx
    -> abt '[] ('HMeasure a)
    -> [abt '[] (a ':-> 'HProb)]
densityInCtx ctx m = densityWithVar ctx Text.empty elemType m
    where
    elemType = sUnMeasure (typeOf m)

-- | A variant of 'densityInCtx' which runs in 'minimalCtx'. As with
-- 'disintegrate', this is only really valid on top-level programs,
-- or on subprograms in which the enclosing program doesn't bind any
-- variables.
density
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> [abt '[] (a ':-> 'HProb)]
density m = densityInCtx minimalCtx m

-- | Constrain a measure such that it must return the observed
-- value. In other words, the resulting measure returns the observed
-- value with weight according to its density in the original
-- measure, and gives all other values weight zero.
--
-- Both the measure and the observed value are handed to
-- 'runDisInCtx' as the terms the run must be aware of. Each element
-- of the result is one alternative produced by that run.
observeInCtx
    :: (ABT Term abt)
    => TransformCtx
    -> abt '[] ('HMeasure a)
    -> abt '[] a
    -> [abt '[] ('HMeasure a)]
observeInCtx ctx m x = runDisInCtx ctx constrained [Some2 m, Some2 x]
    where
    -- Force the outcome of @m@ to be @x@, then return @x@ itself.
    constrained = do
        constrainOutcome x m
        return x

-- | A variant of 'observeInCtx' which runs in 'minimalCtx'. As with
-- 'disintegrate', this is only really valid on top-level programs,
-- or on subprograms in which the enclosing program doesn't bind any
-- variables.
observe
    :: (ABT Term abt)
    => abt '[] ('HMeasure a)
    -> abt '[] a
    -> [abt '[] ('HMeasure a)]
observe m x = observeInCtx minimalCtx m x

-- | Arbitrarily choose one of the possible alternatives. Currently
-- this is simply the first one, or 'Nothing' when there are none. In
-- the future, this function should be replaced by a better one that
-- takes some sort of strategy for deciding which alternative to
-- choose.
determine :: (ABT Term abt) => [abt '[] a] -> Maybe (abt '[] a)
determine candidates =
    case candidates of
        m : _ -> Just m
        []    -> Nothing


----------------------------------------------------------------
----------------------------------------------------------------

-- | The generic 'evaluate' specialised to the 'Dis' monad, with
-- 'perform' as its 'MeasureEvaluator'.
--
-- N.B., forward disintegration is not identical to partial evaluation,
-- as noted at the top of the file. For correctness we need to ensure
-- the result is emissible (i.e., has no heap-bound variables). More
-- specifically, we need to ensure emissibility in the places where
-- we call 'emitMBind'.
evaluate_ :: (ABT Term abt) => TermEvaluator abt (Dis abt)
evaluate_ e = evaluate perform e

-- | Evaluate a datum-typed term with 'evaluate_'. Then use
-- 'viewWhnfDatum' to expose the resulting WHNF as a 'Datum', if it
-- is one.
evaluateDatum :: (ABT Term abt) => DatumEvaluator (abt '[]) (Dis abt)
evaluateDatum e = fmap viewWhnfDatum (evaluate_ e)

-- | Simulate performing 'HMeasure' actions by simply emitting code
-- for those actions, returning the bound variable.
--
-- This is the function called @(|>>)@ in the disintegration paper.
perform :: forall abt. (ABT Term abt) => MeasureEvaluator abt (Dis abt)
perform :: MeasureEvaluator abt (Dis abt)
perform = \abt '[] ('HMeasure a)
e0 ->
#ifdef __TRACE_DISINTEGRATE__
    getStatements >>= \ss ->
    getExtras >>= \extras ->
    getIndices >>= \inds ->
    trace ("\n-- perform --\n"
        ++ "at " ++ show (ppInds inds) ++ "\n"
        ++ show (prettyExtras extras) ++ "\n"
        ++ show (pretty_Statements_withTerm ss e0)
        ++ "\n") $
#endif
    abt '[] ('HMeasure a)
-> (Variable ('HMeasure a) -> Dis abt (Whnf abt a))
-> (Term abt ('HMeasure a) -> Dis abt (Whnf abt a))
-> Dis abt (Whnf abt a)
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (a :: k) r.
ABT syn abt =>
abt '[] a -> (Variable a -> r) -> (syn abt a -> r) -> r
caseVarSyn abt '[] ('HMeasure a)
e0 Variable ('HMeasure a) -> Dis abt (Whnf abt a)
forall (a :: Hakaru).
Variable ('HMeasure a) -> Dis abt (Whnf abt a)
performVar Term abt ('HMeasure a) -> Dis abt (Whnf abt a)
forall (a :: Hakaru).
Term abt ('HMeasure a) -> Dis abt (Whnf abt a)
performTerm
    where
    performTerm :: forall a. Term abt ('HMeasure a) -> Dis abt (Whnf abt a)
    performTerm :: Term abt ('HMeasure a) -> Dis abt (Whnf abt a)
performTerm (SCon args ('HMeasure a)
Dirac :$ abt vars a
e1 :* SArgs abt args
End)       = abt '[] a -> Dis abt (Whnf abt a)
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
TermEvaluator abt (Dis abt)
evaluate_ abt vars a
abt '[] a
e1
    performTerm (MeasureOp_ MeasureOp typs a
o :$ SArgs abt args
es)       = MeasureOp typs a -> SArgs abt args -> Dis abt (Whnf abt a)
forall (typs :: [Hakaru]) (args :: [([Hakaru], Hakaru)])
       (a :: Hakaru).
(typs ~ UnLCs args, args ~ LCs typs) =>
MeasureOp typs a -> SArgs abt args -> Dis abt (Whnf abt a)
performMeasureOp MeasureOp typs a
o SArgs abt args
es
    performTerm (SCon args ('HMeasure a)
MBind :$ abt vars a
e1 :* abt vars a
e2 :* SArgs abt args
End) =
        abt '[a] a
-> (Variable a -> abt '[] a -> Dis abt (Whnf abt a))
-> Dis abt (Whnf abt a)
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (x :: k) (xs :: [k]) (a :: k) r.
ABT syn abt =>
abt (x : xs) a -> (Variable x -> abt xs a -> r) -> r
caseBind abt vars a
abt '[a] a
e2 ((Variable a -> abt '[] a -> Dis abt (Whnf abt a))
 -> Dis abt (Whnf abt a))
-> (Variable a -> abt '[] a -> Dis abt (Whnf abt a))
-> Dis abt (Whnf abt a)
forall a b. (a -> b) -> a -> b
$ \Variable a
x abt '[] a
e2' -> do
            [Index (abt '[])]
inds <- Dis abt [Index (abt '[])]
forall (abt :: [Hakaru] -> Hakaru -> *) (m :: * -> *)
       (p :: Purity).
EvaluationMonad abt m p =>
m [Index (abt '[])]
getIndices
            Statement abt Variable 'Impure -> abt '[] a -> Dis abt (abt '[] a)
forall (abt :: [Hakaru] -> Hakaru -> *) (m :: * -> *) (p :: Purity)
       (xs :: [Hakaru]) (a :: Hakaru).
(ABT Term abt, EvaluationMonad abt m p) =>
Statement abt Variable p -> abt xs a -> m (abt xs a)
push (Variable a
-> Lazy abt ('HMeasure a)
-> [Index (abt '[])]
-> Statement abt Variable 'Impure
forall (abt :: [Hakaru] -> Hakaru -> *) (v :: Hakaru -> *)
       (a :: Hakaru).
v a
-> Lazy abt ('HMeasure a)
-> [Index (abt '[])]
-> Statement abt v 'Impure
SBind Variable a
x (abt '[] a -> Lazy abt a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
abt '[] a -> Lazy abt a
Thunk abt vars a
abt '[] a
e1) [Index (abt '[])]
inds) abt '[] a
e2' Dis abt (abt '[] a)
-> (abt '[] a -> Dis abt (Whnf abt a)) -> Dis abt (Whnf abt a)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= abt '[] a -> Dis abt (Whnf abt a)
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
MeasureEvaluator abt (Dis abt)
perform

    performTerm (SCon args ('HMeasure a)
Plate :$ abt vars a
e1 :* abt vars a
e2 :* SArgs abt args
End) =  do
      abt '[] ('HArray a)
x1 <- abt '[] 'HNat
-> abt '[ 'HNat] ('HMeasure a) -> Dis abt (abt '[] ('HArray a))
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] 'HNat
-> abt '[ 'HNat] ('HMeasure a) -> Dis abt (abt '[] ('HArray a))
pushPlate abt vars a
abt '[] 'HNat
e1 abt vars a
abt '[ 'HNat] ('HMeasure a)
e2
      Whnf abt ('HArray a) -> Dis abt (Whnf abt ('HArray a))
forall (m :: * -> *) a. Monad m => a -> m a
return (Whnf abt ('HArray a) -> Dis abt (Whnf abt ('HArray a)))
-> Whnf abt ('HArray a) -> Dis abt (Whnf abt ('HArray a))
forall a b. (a -> b) -> a -> b
$ Maybe (Whnf abt ('HArray a)) -> Whnf abt ('HArray a)
forall a. HasCallStack => Maybe a -> a
fromJust (abt '[] ('HArray a) -> Maybe (Whnf abt ('HArray a))
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> Maybe (Whnf abt a)
toWhnf abt '[] ('HArray a)
x1)

    performTerm (Superpose_ NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a))
pes) = do
        [Index (abt '[])]
inds <- Dis abt [Index (abt '[])]
forall (abt :: [Hakaru] -> Hakaru -> *) (m :: * -> *)
       (p :: Purity).
EvaluationMonad abt m p =>
m [Index (abt '[])]
getIndices
        if Bool -> Bool
not ([Index (abt '[])] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [Index (abt '[])]
inds) Bool -> Bool -> Bool
&& NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a)) -> Int
forall a. NonEmpty a -> Int
L.length NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a))
pes Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1 then Dis abt (Whnf abt a)
forall (abt :: [Hakaru] -> Hakaru -> *) a.
ABT Term abt =>
Dis abt a
bot else
          (forall (r :: Hakaru).
 NonEmpty (abt '[] ('HMeasure r)) -> abt '[] ('HMeasure r))
-> NonEmpty (Dis abt (Whnf abt a)) -> Dis abt (Whnf abt a)
forall (abt :: [Hakaru] -> Hakaru -> *) (t :: * -> *) a.
(ABT Term abt, Traversable t) =>
(forall (r :: Hakaru).
 t (abt '[] ('HMeasure r)) -> abt '[] ('HMeasure r))
-> t (Dis abt a) -> Dis abt a
emitFork_ (NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure r))
-> abt '[] ('HMeasure r)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a))
-> abt '[] ('HMeasure a)
P.superpose (NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure r))
 -> abt '[] ('HMeasure r))
-> (NonEmpty (abt '[] ('HMeasure r))
    -> NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure r)))
-> NonEmpty (abt '[] ('HMeasure r))
-> abt '[] ('HMeasure r)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (abt '[] ('HMeasure r) -> (abt '[] 'HProb, abt '[] ('HMeasure r)))
-> NonEmpty (abt '[] ('HMeasure r))
-> NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure r))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap ((,) abt '[] 'HProb
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HSemiring_ a) =>
abt '[] a
P.one))
                    (((abt '[] 'HProb, abt '[] ('HMeasure a)) -> Dis abt (Whnf abt a))
-> NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a))
-> NonEmpty (Dis abt (Whnf abt a))
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
fmap (\(abt '[] 'HProb
p,abt '[] ('HMeasure a)
e) -> Statement abt Variable 'Impure
-> abt '[] ('HMeasure a) -> Dis abt (abt '[] ('HMeasure a))
forall (abt :: [Hakaru] -> Hakaru -> *) (m :: * -> *) (p :: Purity)
       (xs :: [Hakaru]) (a :: Hakaru).
(ABT Term abt, EvaluationMonad abt m p) =>
Statement abt Variable p -> abt xs a -> m (abt xs a)
push (Lazy abt 'HProb
-> [Index (abt '[])] -> Statement abt Variable 'Impure
forall (abt :: [Hakaru] -> Hakaru -> *) (v :: Hakaru -> *).
Lazy abt 'HProb -> [Index (abt '[])] -> Statement abt v 'Impure
SWeight (abt '[] 'HProb -> Lazy abt 'HProb
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
abt '[] a -> Lazy abt a
Thunk abt '[] 'HProb
p) [Index (abt '[])]
inds) abt '[] ('HMeasure a)
e Dis abt (abt '[] ('HMeasure a))
-> (abt '[] ('HMeasure a) -> Dis abt (Whnf abt a))
-> Dis abt (Whnf abt a)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= abt '[] ('HMeasure a) -> Dis abt (Whnf abt a)
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
MeasureEvaluator abt (Dis abt)
perform)
                          NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a))
pes)

    -- Avoid falling through to the @performWhnf <=< evaluate_@ case
    performTerm (SCon args ('HMeasure a)
Let_ :$ abt vars a
e1 :* abt vars a
e2 :* SArgs abt args
End) =
        abt '[a] a
-> (Variable a -> abt '[] a -> Dis abt (Whnf abt a))
-> Dis abt (Whnf abt a)
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (x :: k) (xs :: [k]) (a :: k) r.
ABT syn abt =>
abt (x : xs) a -> (Variable x -> abt xs a -> r) -> r
caseBind abt vars a
abt '[a] a
e2 ((Variable a -> abt '[] a -> Dis abt (Whnf abt a))
 -> Dis abt (Whnf abt a))
-> (Variable a -> abt '[] a -> Dis abt (Whnf abt a))
-> Dis abt (Whnf abt a)
forall a b. (a -> b) -> a -> b
$ \Variable a
x abt '[] a
e2' -> do
            [Index (abt '[])]
inds <- Dis abt [Index (abt '[])]
forall (abt :: [Hakaru] -> Hakaru -> *) (m :: * -> *)
       (p :: Purity).
EvaluationMonad abt m p =>
m [Index (abt '[])]
getIndices
            Statement abt Variable 'Impure -> abt '[] a -> Dis abt (abt '[] a)
forall (abt :: [Hakaru] -> Hakaru -> *) (m :: * -> *) (p :: Purity)
       (xs :: [Hakaru]) (a :: Hakaru).
(ABT Term abt, EvaluationMonad abt m p) =>
Statement abt Variable p -> abt xs a -> m (abt xs a)
push (Variable a
-> Lazy abt a
-> [Index (abt '[])]
-> Statement abt Variable 'Impure
forall (abt :: [Hakaru] -> Hakaru -> *) (p :: Purity)
       (v :: Hakaru -> *) (a :: Hakaru).
v a -> Lazy abt a -> [Index (abt '[])] -> Statement abt v p
SLet Variable a
x (abt '[] a -> Lazy abt a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
abt '[] a -> Lazy abt a
Thunk abt vars a
abt '[] a
e1) [Index (abt '[])]
inds) abt '[] a
e2' Dis abt (abt '[] a)
-> (abt '[] a -> Dis abt (Whnf abt a)) -> Dis abt (Whnf abt a)
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= abt '[] a -> Dis abt (Whnf abt a)
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
MeasureEvaluator abt (Dis abt)
perform

    -- TODO: we could optimize this by calling some @evaluateTerm@
    -- directly, rather than calling 'syn' to rebuild @e0@ from
    -- @t0@ and then calling 'evaluate_' (which will just use
    -- 'caseVarSyn' to get the @t0@ back out from the @e0@).
    performTerm Term abt ('HMeasure a)
t0 = do
        Whnf abt ('HMeasure a)
w <- abt '[] ('HMeasure a) -> Dis abt (Whnf abt ('HMeasure a))
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
TermEvaluator abt (Dis abt)
evaluate_ (Term abt ('HMeasure a) -> abt '[] ('HMeasure a)
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (a :: k).
ABT syn abt =>
syn abt a -> abt '[] a
syn Term abt ('HMeasure a)
t0)
#ifdef __TRACE_DISINTEGRATE__
        trace ("-- perform: finished evaluate, with:\n" ++ show (PP.sep(prettyPrec_ 11 w))) $ return ()
#endif
        Whnf abt ('HMeasure a) -> Dis abt (Whnf abt a)
forall (a :: Hakaru).
Whnf abt ('HMeasure a) -> Dis abt (Whnf abt a)
performWhnf Whnf abt ('HMeasure a)
w


    -- Perform a measure-typed variable: look it up\/evaluate it with
    -- 'evaluateVar' (using 'perform' for measure-typed heap bindings
    -- and 'evaluate_' for ordinary terms), then perform the resulting
    -- WHNF via 'performWhnf'.
    performVar :: forall a. Variable ('HMeasure a) -> Dis abt (Whnf abt a)
    performVar = performWhnf <=< evaluateVar perform evaluate_

    -- Perform a measure that has already been evaluated to WHNF.
    --
    -- * A head is turned back into a term with 'fromHead' and
    --   performed again.
    --
    -- * A neutral measure cannot be reduced further, so we 'atomize'
    --   it (which, per its documentation, removes heap-bound
    --   variables so the term is emissible), emit a monadic bind of
    --   it with 'emitMBind', and return the freshly bound variable as
    --   a neutral term.
    performWhnf
        :: forall a. Whnf abt ('HMeasure a) -> Dis abt (Whnf abt a)
    performWhnf (Head_   v) = perform $ fromHead v
    performWhnf (Neutral e) =
        (Neutral . var) <$> (emitMBind . fromWhnf =<< atomize e)


    -- Perform a primitive measure operation applied to its arguments.
    --
    -- We first try 'nice', which keeps the operation intact in the
    -- output program; if that fails we fall back (via '<|>') to
    -- 'complete', which rewrites the operation into lower-level
    -- primitives so that we 'bot' as rarely as possible.
    --
    -- TODO: right now we do the simplest thing. However, for better
    -- coverage and cleaner produced code we'll need to handle each
    -- of the ops separately. (For example, see how 'Uniform' is
    -- handled in the old code; how it has two options for what to
    -- do.)
    performMeasureOp
        :: forall typs args a
        .  (typs ~ UnLCs args, args ~ LCs typs)
        => MeasureOp typs a
        -> SArgs abt args
        -> Dis abt (Whnf abt a)
    performMeasureOp = \o es -> nice o es <|> complete o es
        where
        -- Try to generate nice pretty output: make every argument
        -- emissible with 'atomizeCore', then emit a bind of the
        -- measure operation itself.
        nice
            :: MeasureOp typs a
            -> SArgs abt args
            -> Dis abt (Whnf abt a)
        nice o es = do
            es' <- traverse21 atomizeCore es
            x   <- emitMBind2 $ syn (MeasureOp_ o :$ es')
            -- 'emitMBind2' already hands back a term (not a
            -- 'Variable'), so no 'var' wrapper is needed here.
            return (Neutral x)

        -- Try to be as complete as possible (i.e., 'bot' as little as
        -- possible), no matter how ugly the output code gets. Each
        -- supported op is re-expressed as a draw from 'P.lebesgue'
        -- weighted by the op's density (plus a support guard where
        -- needed); every other op gives up with 'bot'.
        complete
            :: MeasureOp typs a
            -> SArgs abt args
            -> Dis abt (Whnf abt a)
        complete Normal = \(mu :* sd :* End) -> do
            x <- var <$> emitMBind P.lebesgue
            pushWeight (P.densityNormal mu sd x)
            return (Neutral x)
        complete Uniform = \(lo :* hi :* End) -> do
            x <- var <$> emitMBind P.lebesgue
            -- Restrict @x@ to the open interval (lo, hi) before
            -- weighting by the uniform density.
            pushGuard (lo P.< x P.&& x P.< hi)
            pushWeight (P.densityUniform lo hi x)
            return (Neutral x)
        complete _ = \_ -> bot
                               

-- | The goal of this function is to ensure the correctness criterion
-- that given any term to be emitted, the resulting term is
-- semantically equivalent but contains no heap-bound variables.
-- That correctness criterion is necessary to ensure hygiene\/scoping.
--
-- Concretely, the term is first evaluated to WHNF with 'evaluate_';
-- then each immediate subterm (of the head, or of the top-level
-- 'Term' constructor of a neutral term) is processed with
-- 'atomizeCore', and the WHNF is rebuilt around the results.
--
-- This particular implementation calls 'evaluate' recursively,
-- giving us something similar to full-beta reduction. However,
-- that is considered an implementation detail rather than part of
-- the specification of what the function should do. Also, it's a
-- gross hack and prolly a big part of why we keep running into
-- infinite looping issues.
--
-- This name is taken from the old finally tagless code, where
-- \"atomic\" terms are (among other things) emissible; i.e., contain
-- no heap-bound variables.
--
-- BUG: this function infinitely loops in certain circumstances
-- (namely when dealing with neutral terms)
atomize :: (ABT Term abt) => TermEvaluator abt (Dis abt)
atomize e =
#ifdef __TRACE_DISINTEGRATE__
    trace ("\n-- atomize --\n" ++ show (pretty e)) $
#endif
    do whnf <- evaluate_ e
       case whnf of
         Head_ v ->
             Head_ <$> traverse21 atomizeCore v
         -- N.B., bind as @e'@ so we don't shadow the argument @e@.
         Neutral e' ->
             Neutral . unviewABT <$>
                 traverse12 (traverse21 atomizeCore) (viewABT e')


-- | A variant of 'atomize' which is polymorphic in the locally
-- bound variables @xs@ (whereas 'atomize' requires @xs ~ '[]@).
-- We factored this out because we often want this more polymorphic
-- variant when using our indexed @TraversableMN@ classes.
--
-- If no free variable of the term is heap-bound (according to
-- 'getHeapVars'), the term is returned unchanged. Otherwise the
-- binders are peeled off with 'caseBinds', the body is 'atomize'd,
-- and the binders are restored with 'binds_'.
atomizeCore :: (ABT Term abt) => abt xs a -> Dis abt (abt xs a)
atomizeCore e = do
    -- HACK: this check for 'disjointVarSet' is sufficient to catch
    -- the particular infinite loops we were encountering, but it
    -- will not catch all of them. If the call to 'evaluate_' in
    -- 'atomize' returns a neutral term which contains heap-bound
    -- variables, then we'll still loop forever since we don't
    -- traverse\/fmap over the top-level term constructor of neutral
    -- terms.
    heapVars <- getHeapVars
    freeVars <- extFreeVars e
    if disjointVarSet heapVars freeVars
        then return e
        else
            let (ys, e') = caseBinds e
            in
#ifdef __TRACE_DISINTEGRATE__
               trace ("\n-- atomizeCore --\n" ++ show (pretty e')) $
#endif
               (binds_ ys . fromWhnf) <$> atomize e'
    where
    -- True when the two variable sets share no variable IDs.
    -- TODO: does @IM.null . IM.intersection@ fuse correctly?
    disjointVarSet xs ys =
        IM.null (IM.intersection (unVarSet xs) (unVarSet ys))

-- | The union of 'statementVars' over every statement of the current
-- heap ('statements' of the 'ListContext' handed to the 'Dis'
-- continuation). The index list is ignored and the heap is passed
-- on unchanged.
--
-- HACK: if we really want to go through with this approach, then
-- we should memoize the set of heap-bound variables in the
-- 'ListContext' itself rather than recomputing it every time!
getHeapVars :: Dis abt (VarSet ('KProxy :: KProxy Hakaru))
getHeapVars =
    Dis $ \_inds k heap -> k (foldMap statementVars (statements heap)) heap

----------------------------------------------------------------
-- | Given an emissible term @v0@ (the first argument) and another
-- term @e0@ (the second argument), compute the constraints such
-- that @e0@ must evaluate to @v0@. This is the function called
-- @(<|)@ in the disintegration paper, though notably we swap the
-- argument order so that the \"value\" is the first argument.
--
-- N.B., this function assumes (and does not verify) that the first
-- argument is emissible. So callers (including recursive calls)
-- must guarantee this invariant, by calling 'atomize' as necessary.
--
-- TODO: capture the emissibility requirement on the first argument
-- in the types, to help avoid accidentally passing the arguments
-- in the wrong order!
constrainValue :: (ABT Term abt) => abt '[] a -> abt '[] a -> Dis abt ()
constrainValue :: abt '[] a -> abt '[] a -> Dis abt ()
constrainValue abt '[] a
v0 abt '[] a
e0 =
#ifdef __TRACE_DISINTEGRATE__
    getStatements >>= \ss ->
    getExtras >>= \extras ->
    getIndices >>= \inds ->
    trace ("\n-- constrainValue: " ++ show (pretty v0) ++ "\n"           
        ++ show (pretty_Statements_withTerm ss e0) ++ "\n"
        ++ "at " ++ show (ppInds inds) ++ "\n"
        ++ show (prettyExtras extras) ++ "\n"
          ) $
#endif
    abt '[] a
-> (Variable a -> Dis abt ())
-> (Term abt a -> Dis abt ())
-> Dis abt ()
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (a :: k) r.
ABT syn abt =>
abt '[] a -> (Variable a -> r) -> (syn abt a -> r) -> r
caseVarSyn abt '[] a
e0 (abt '[] a -> Variable a -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> Variable a -> Dis abt ()
constrainVariable abt '[] a
v0) ((Term abt a -> Dis abt ()) -> Dis abt ())
-> (Term abt a -> Dis abt ()) -> Dis abt ()
forall a b. (a -> b) -> a -> b
$ \Term abt a
t ->
        case Term abt a
t of
        -- There's a bunch of stuff we don't even bother trying to handle
        Empty_   Sing ('HArray a)
_               -> [Char] -> Dis abt ()
forall a. HasCallStack => [Char] -> a
error [Char]
"TODO: disintegrate empty arrays"
        Array_   abt '[] 'HNat
n abt '[ 'HNat] a
e             ->
            abt '[ 'HNat] a
-> (Variable 'HNat -> abt '[] a -> Dis abt ()) -> Dis abt ()
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (x :: k) (xs :: [k]) (a :: k) r.
ABT syn abt =>
abt (x : xs) a -> (Variable x -> abt xs a -> r) -> r
caseBind abt '[ 'HNat] a
e ((Variable 'HNat -> abt '[] a -> Dis abt ()) -> Dis abt ())
-> (Variable 'HNat -> abt '[] a -> Dis abt ()) -> Dis abt ()
forall a b. (a -> b) -> a -> b
$ \Variable 'HNat
x abt '[] a
body -> do Index (abt '[])
j <- abt '[] 'HNat -> Dis abt (Index (abt '[]))
forall (abt :: [Hakaru] -> Hakaru -> *) (m :: * -> *)
       (p :: Purity).
EvaluationMonad abt m p =>
abt '[] 'HNat -> m (Index (abt '[]))
freshInd abt '[] 'HNat
n
                                       let x' :: Variable 'HNat
x'    = Index (abt '[]) -> Variable 'HNat
forall (ast :: Hakaru -> *). Index ast -> Variable 'HNat
indVar Index (abt '[])
j
                                       abt '[] a
body' <- Variable 'HNat -> abt '[] 'HNat -> abt '[] a -> Dis abt (abt '[] a)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru)
       (xs :: [Hakaru]) (b :: Hakaru) (m :: * -> *) (p :: Purity).
EvaluationMonad abt m p =>
Variable a -> abt '[] a -> abt xs b -> m (abt xs b)
extSubst Variable 'HNat
x (Variable 'HNat -> abt '[] 'HNat
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (a :: k).
ABT syn abt =>
Variable a -> abt '[] a
var Variable 'HNat
x') abt '[] a
body
                                       [Index (abt '[])]
inds  <- Dis abt [Index (abt '[])]
forall (abt :: [Hakaru] -> Hakaru -> *) (m :: * -> *)
       (p :: Purity).
EvaluationMonad abt m p =>
m [Index (abt '[])]
getIndices
                                       [Index (abt '[])] -> Dis abt () -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) a.
[Index (abt '[])] -> Dis abt a -> Dis abt a
withIndices (Index (abt '[]) -> [Index (abt '[])] -> [Index (abt '[])]
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
Index (abt '[]) -> [Index (abt '[])] -> [Index (abt '[])]
extendIndices Index (abt '[])
j [Index (abt '[])]
inds) (Dis abt () -> Dis abt ()) -> Dis abt () -> Dis abt ()
forall a b. (a -> b) -> a -> b
$
                                                   abt '[] a -> abt '[] a -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] a -> Dis abt ()
constrainValue (abt '[] a
abt '[] ('HArray a)
v0 abt '[] ('HArray a) -> abt '[] 'HNat -> abt '[] a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] ('HArray a) -> abt '[] 'HNat -> abt '[] a
P.! (Variable 'HNat -> abt '[] 'HNat
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (a :: k).
ABT syn abt =>
Variable a -> abt '[] a
var Variable 'HNat
x')) abt '[] a
body'
                                                   -- TODO use meta-index
        ArrayLiteral_ [abt '[] a]
_          -> [Char] -> Dis abt ()
forall a. HasCallStack => [Char] -> a
error [Char]
"TODO: disintegrate literal arrays"
        ArrayOp_ ArrayOp typs a
o :$ SArgs abt args
args       -> abt '[] a -> ArrayOp typs a -> SArgs abt args -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (typs :: [Hakaru])
       (args :: [([Hakaru], Hakaru)]) (a :: Hakaru).
(ABT Term abt, typs ~ UnLCs args, args ~ LCs typs) =>
abt '[] a -> ArrayOp typs a -> SArgs abt args -> Dis abt ()
constrainValueArrayOp abt '[] a
v0 ArrayOp typs a
o SArgs abt args
args
        SCon args a
Lam_  :$ abt vars a
_  :* SArgs abt args
End       -> [Char] -> Dis abt ()
forall a. HasCallStack => [Char] -> a
error [Char]
"TODO: disintegrate lambdas"
        SCon args a
App_  :$ abt vars a
_  :* abt vars a
_ :* SArgs abt args
End  -> [Char] -> Dis abt ()
forall a. HasCallStack => [Char] -> a
error [Char]
"TODO: disintegrate lambdas"
        SCon args a
Integrate :$ abt vars a
_ :* abt vars a
_ :* abt vars a
_ :* SArgs abt args
End ->
            [Char] -> Dis abt ()
forall a. HasCallStack => [Char] -> a
error [Char]
"TODO: disintegrate integration"
        Summate HDiscrete a
_ HSemiring a
_ :$ abt vars a
_ :* abt vars a
_ :* abt vars a
_ :* SArgs abt args
End ->
            [Char] -> Dis abt ()
forall a. HasCallStack => [Char] -> a
error [Char]
"TODO: disintegrate integration"


        -- N.B., the semantically correct definition is:
        --
        -- > Literal_ v
        -- >     | "dirac v has a density wrt the ambient measure" -> ...
        -- >     | otherwise -> bot
        --
        -- For the case where the ambient measure is Lebesgue, dirac
        -- doesn't have a density, so we return 'bot'. However, we
        -- will need to generalize this when we start handling other
        -- ambient measures.
        Literal_ Literal a
v               -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) a.
ABT Term abt =>
Dis abt a
bot -- unsolvable. (kinda; see note)
        Datum_   Datum (abt '[]) (HData' t)
d               -> abt '[] a -> Datum (abt '[]) a -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> Datum (abt '[]) a -> Dis abt ()
constrainDatum abt '[] a
v0 Datum (abt '[]) a
Datum (abt '[]) (HData' t)
d
        SCon args a
Dirac :$ abt vars a
_ :* SArgs abt args
End        -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) a.
ABT Term abt =>
Dis abt a
bot -- giving up.
        SCon args a
MBind :$ abt vars a
_ :* abt vars a
_ :* SArgs abt args
End   -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) a.
ABT Term abt =>
Dis abt a
bot -- giving up.
        MeasureOp_ MeasureOp typs a
o :$ SArgs abt args
es       -> abt '[] ('HMeasure a)
-> MeasureOp typs a -> SArgs abt args -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (typs :: [Hakaru])
       (args :: [([Hakaru], Hakaru)]) (a :: Hakaru).
(ABT Term abt, typs ~ UnLCs args, args ~ LCs typs) =>
abt '[] ('HMeasure a)
-> MeasureOp typs a -> SArgs abt args -> Dis abt ()
constrainValueMeasureOp abt '[] a
abt '[] ('HMeasure a)
v0 MeasureOp typs a
o SArgs abt args
es
        Superpose_ NonEmpty (abt '[] 'HProb, abt '[] ('HMeasure a))
pes           -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) a.
ABT Term abt =>
Dis abt a
bot -- giving up.
        Reject_ Sing ('HMeasure a)
_                -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) a.
ABT Term abt =>
Dis abt a
bot -- giving up.
        SCon args a
Let_ :$ abt vars a
e1 :* abt vars a
e2 :* SArgs abt args
End ->
            abt '[a] a -> (Variable a -> abt '[] a -> Dis abt ()) -> Dis abt ()
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (x :: k) (xs :: [k]) (a :: k) r.
ABT syn abt =>
abt (x : xs) a -> (Variable x -> abt xs a -> r) -> r
caseBind abt vars a
abt '[a] a
e2 ((Variable a -> abt '[] a -> Dis abt ()) -> Dis abt ())
-> (Variable a -> abt '[] a -> Dis abt ()) -> Dis abt ()
forall a b. (a -> b) -> a -> b
$ \Variable a
x abt '[] a
e2' ->
                Statement abt Variable 'Impure -> abt '[] a -> Dis abt (abt '[] a)
forall (abt :: [Hakaru] -> Hakaru -> *) (m :: * -> *) (p :: Purity)
       (xs :: [Hakaru]) (a :: Hakaru).
(ABT Term abt, EvaluationMonad abt m p) =>
Statement abt Variable p -> abt xs a -> m (abt xs a)
push (Variable a
-> Lazy abt a
-> [Index (abt '[])]
-> Statement abt Variable 'Impure
forall (abt :: [Hakaru] -> Hakaru -> *) (p :: Purity)
       (v :: Hakaru -> *) (a :: Hakaru).
v a -> Lazy abt a -> [Index (abt '[])] -> Statement abt v p
SLet Variable a
x (abt '[] a -> Lazy abt a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
abt '[] a -> Lazy abt a
Thunk abt vars a
abt '[] a
e1) []) abt '[] a
e2' Dis abt (abt '[] a) -> (abt '[] a -> Dis abt ()) -> Dis abt ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= abt '[] a -> abt '[] a -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] a -> Dis abt ()
constrainValue abt '[] a
v0

        CoerceTo_   Coercion a a
c :$ abt vars a
e1 :* SArgs abt args
End ->
            -- TODO: we need to insert some kind of guard that says
            -- @v0@ is in the range of @coerceTo c@, or equivalently
            -- that @unsafeFrom c v0@ will always succeed. We need
            -- to emit this guard (for correctness of the generated
            -- program) because if @v0@ isn't in the range of the
            -- coercion, then there's no possible way the program
            -- @e1@ could in fact be observed at @v0@. The only
            -- question is how to perform that check; for the
            -- 'Signed' coercions it's easy enough, but for the
            -- 'Continuous' coercions it's not really clear.
            abt '[] a -> abt '[] a -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] a -> Dis abt ()
constrainValue (Coercion a a -> abt '[] a -> abt '[] a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru)
       (b :: Hakaru).
ABT Term abt =>
Coercion a b -> abt '[] b -> abt '[] a
P.unsafeFrom_ Coercion a a
c abt '[] a
v0) abt vars a
abt '[] a
e1
        UnsafeFrom_ Coercion a b
c :$ abt vars a
e1 :* SArgs abt args
End ->
            -- TODO: to avoid returning garbage, we'd need to place
            -- some constraint on @e1@ so that if the original
            -- program would've crashed due to a bad unsafe-coercion,
            -- then we won't return a disintegrated program (since
            -- it too should always crash). Avoiding this check is
            -- sound (i.e., if the input program is well-formed
            -- then the output program is a well-formed disintegration),
            -- it just overgeneralizes.
            abt '[] b -> abt '[] b -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] a -> Dis abt ()
constrainValue  (Coercion a b -> abt '[] a -> abt '[] b
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru)
       (b :: Hakaru).
ABT Term abt =>
Coercion a b -> abt '[] a -> abt '[] b
P.coerceTo_ Coercion a b
c abt '[] a
v0) abt vars a
abt '[] b
e1
        NaryOp_     NaryOp a
o    Seq (abt '[] a)
es        -> abt '[] a -> NaryOp a -> Seq (abt '[] a) -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> NaryOp a -> Seq (abt '[] a) -> Dis abt ()
constrainNaryOp abt '[] a
v0 NaryOp a
o Seq (abt '[] a)
es
        PrimOp_     PrimOp typs a
o :$ SArgs abt args
es        -> abt '[] a -> PrimOp typs a -> SArgs abt args -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (typs :: [Hakaru])
       (args :: [([Hakaru], Hakaru)]) (a :: Hakaru).
(ABT Term abt, typs ~ UnLCs args, args ~ LCs typs) =>
abt '[] a -> PrimOp typs a -> SArgs abt args -> Dis abt ()
constrainPrimOp abt '[] a
v0 PrimOp typs a
o SArgs abt args
es

        Transform_ Transform args a
t :$ SArgs abt args
_            -> [Char] -> Dis abt ()
forall a. HasCallStack => [Char] -> a
error ([Char] -> Dis abt ()) -> [Char] -> Dis abt ()
forall a b. (a -> b) -> a -> b
$
          [[Char]] -> [Char]
forall (t :: * -> *) a. Foldable t => t [a] -> [a]
concat[[Char]
"constrainValue{", Transform args a -> [Char]
forall a. Show a => a -> [Char]
show Transform args a
t, [Char]
"}"
                ,[Char]
": cannot yet disintegrate transforms; expand them first"]

        Case_ abt '[] a
e [Branch a abt a]
bs ->
            -- First we try going forward on the scrutinee, to make
            -- pretty resulting programs; but if that doesn't work
            -- out, we fall back to just constraining the branches.
            do  Maybe (MatchResult (abt '[]) abt a)
match <- DatumEvaluator (abt '[]) (Dis abt)
-> abt '[] a
-> [Branch a abt a]
-> Dis abt (Maybe (MatchResult (abt '[]) abt a))
forall (abt :: [Hakaru] -> Hakaru -> *) (m :: * -> *)
       (ast :: Hakaru -> *) (a :: Hakaru) (b :: Hakaru).
(ABT Term abt, Monad m) =>
DatumEvaluator ast m
-> ast a -> [Branch a abt b] -> m (Maybe (MatchResult ast abt b))
matchBranches DatumEvaluator (abt '[]) (Dis abt)
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
DatumEvaluator (abt '[]) (Dis abt)
evaluateDatum abt '[] a
e [Branch a abt a]
bs
                case Maybe (MatchResult (abt '[]) abt a)
match of
                    Maybe (MatchResult (abt '[]) abt a)
Nothing ->
                        -- If desired, we could return the Hakaru program
                        -- that always crashes, instead of throwing a
                        -- Haskell error.
                        [Char] -> Dis abt ()
forall a. HasCallStack => [Char] -> a
error [Char]
"constrainValue{Case_}: nothing matched!"
                    Just MatchResult (abt '[]) abt a
GotStuck ->
                        abt '[] a -> abt '[] a -> [Branch a abt a] -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru)
       (b :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] b -> [Branch b abt a] -> Dis abt ()
constrainBranches abt '[] a
v0 abt '[] a
e [Branch a abt a]
bs
                    Just (Matched Assocs (abt '[])
rho abt '[] a
body) ->
                        [Statement abt Variable 'Impure]
-> abt '[] a -> Dis abt (abt '[] a)
forall (abt :: [Hakaru] -> Hakaru -> *) (m :: * -> *) (p :: Purity)
       (xs :: [Hakaru]) (a :: Hakaru).
(ABT Term abt, EvaluationMonad abt m p) =>
[Statement abt Variable p] -> abt xs a -> m (abt xs a)
pushes (Assocs (abt '[]) -> [Statement abt Variable 'Impure]
forall (abt :: [Hakaru] -> Hakaru -> *) (p :: Purity).
Assocs (abt '[]) -> [Statement abt Variable p]
toVarStatements Assocs (abt '[])
rho) abt '[] a
body Dis abt (abt '[] a) -> (abt '[] a -> Dis abt ()) -> Dis abt ()
forall (m :: * -> *) a b. Monad m => m a -> (a -> m b) -> m b
>>= abt '[] a -> abt '[] a -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] a -> Dis abt ()
constrainValue abt '[] a
v0
            Dis abt () -> Dis abt () -> Dis abt ()
forall (f :: * -> *) a. Alternative f => f a -> f a -> f a
<|> abt '[] a -> abt '[] a -> [Branch a abt a] -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru)
       (b :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] b -> [Branch b abt a] -> Dis abt ()
constrainBranches abt '[] a
v0 abt '[] a
e [Branch a abt a]
bs

        SCon args a
_ :$ SArgs abt args
_ -> [Char] -> Dis abt ()
forall a. HasCallStack => [Char] -> a
error [Char]
"constrainValue: the impossible happened"


-- | The default way of doing 'constrainValue' on a 'Case_' expression:
-- by constraining each branch. To do this we rely on the fact that
-- we're in a 'HMeasure' context (i.e., the continuation produces
-- programs of 'HMeasure' type). For each branch we first assert the
-- branch's pattern holds (via 'SGuard') and then call 'constrainValue'
-- on the body of the branch; and the final program is the superposition
-- (via 'choose') of all these branches.
--
-- TODO: how can we avoid duplicating the scrutinee expression?
-- Would pushing a 'SLet' statement before the superpose be sufficient
-- to achieve maximal sharing?
constrainBranches
    :: (ABT Term abt)
    => abt '[] a          -- ^ the value @v0@ each branch body is constrained to
    -> abt '[] b          -- ^ the scrutinee, guarded on by every branch
    -> [Branch b abt a]   -- ^ the branches of the 'Case_'
    -> Dis abt ()
constrainBranches v0 scrutinee branches =
    choose (map constrainOne branches)
    where
    -- Push an 'SGuard' asserting that the scrutinee (as a 'Thunk')
    -- matches this branch's pattern, opening the body's binders, and
    -- then constrain the opened body to @v0@.
    constrainOne (Branch pat body) = do
        let (vars, body') = caseBinds body
        body'' <- push (SGuard vars pat (Thunk scrutinee) []) body'
        constrainValue v0 body''


-- | Constrain the value @v0@ to match the datum @d@.
--
-- We compute the pattern matching @d@ together with @d@'s
-- subexpressions, generate one fresh variable per pattern variable,
-- and emit a 'Case_' on @v0@ whose first branch matches that pattern
-- (binding the fresh variables) and whose wildcard branch rejects.
-- Finally, each subexpression of @d@ is constrained to the value of
-- its corresponding fresh variable.
constrainDatum
    :: (ABT Term abt) => abt '[] a -> Datum (abt '[]) a -> Dis abt ()
constrainDatum v0 d =
    case patternOfDatum d of
    PatternOfDatum pat es -> do
        -- One (unnamed) hint per subexpression, carrying its type.
        let hints = fmap11 (Hint Text.empty . typeOf) es
        vars <- freshVars hints
        emit_ (\body ->
            syn (Case_ v0
                [ Branch pat   (binds_ vars body)
                , Branch PWild (P.reject (typeOf body))
                ]))
        constrainValues vars es

-- | Pairwise 'constrainValue' over two equally-indexed lists: for
-- each position, the expression is constrained to the value of the
-- variable at the same position, proceeding left to right.
constrainValues
    :: (ABT Term abt)
    => List1 Variable  xs
    -> List1 (abt '[]) xs
    -> Dis abt ()
constrainValues Nil1 Nil1 = return ()
constrainValues (Cons1 x xs) (Cons1 e es) = do
    constrainValue (var x) e
    constrainValues xs es
constrainValues _ _ = error "constrainValues: the impossible happened"


-- | A 'Pattern' packaged together with the subexpressions that it
-- would bind to its pattern-variables, as produced by
-- 'patternOfDatum'. The index @xs@ is existentially hidden, and is
-- shared by both fields, so the list has one entry per pattern-variable.
data PatternOfDatum (ast :: Hakaru -> *) (a :: Hakaru) where
    PatternOfDatum
        :: !(Pattern xs a)
        -> !(List1 ast xs)
        -> PatternOfDatum ast a

-- | Given a datum, return the pattern which will match it along
-- with the subexpressions which would be bound to its pattern-variables.
--
-- The datum's sum-of-products structure is walked in
-- continuation-passing style. Each layer ('DatumCode', 'DatumStruct',
-- 'DatumFun') is mapped to the corresponding pattern layer, and every
-- leaf becomes a 'PVar' whose expression is collected, in order, into
-- the resulting list.
patternOfDatum :: Datum ast a -> PatternOfDatum ast a
patternOfDatum (Datum hint _typ dcode) =
    fromCode dcode (\pcode vals -> PatternOfDatum (PDatum hint pcode) vals)
    where
    -- Mirror the 'Inl'/'Inr' spine with 'PInl'/'PInr'.
    fromCode
        :: DatumCode xss ast a
        -> (forall bs. PDatumCode xss bs a -> List1 ast bs -> r)
        -> r
    fromCode (Inl dstruct) k =
        fromStruct dstruct (\pstruct vals -> k (PInl pstruct) vals)
    fromCode (Inr drest) k =
        fromCode drest (\prest vals -> k (PInr prest) vals)

    -- Mirror the product ('Et'/'Done'), concatenating the collected
    -- subexpressions of each component.
    fromStruct
        :: DatumStruct xs ast a
        -> (forall bs. PDatumStruct xs bs a -> List1 ast bs -> r)
        -> r
    fromStruct Done k = k PDone Nil1
    fromStruct (Et dfun drest) k =
        fromFun dfun (\pfun vals1 ->
        fromStruct drest (\prest vals2 ->
        k (PEt pfun prest) (append1 vals1 vals2)))

    -- Each leaf is matched by a plain variable pattern.
    fromFun
        :: DatumFun x ast a
        -> (forall bs. PDatumFun x bs a -> List1 ast bs -> r)
        -> r
    fromFun (Konst e) k = k (PKonst PVar) (Cons1 e Nil1)
    fromFun (Ident e) k = k (PIdent PVar) (Cons1 e Nil1)


----------------------------------------------------------------
-- | Constrain the variable @x@ to take on the value @v0@.
--
-- N.B., as with 'constrainValue', we assume that the first
-- argument is emissible. So it is the caller's responsibility to
-- ensure this (by calling 'atomize' as appropriate).
--
-- We look @x@ up in the extras (via 'getExtras'). A missing entry
-- means @x@ is free, hence neutral, and we return 'bot'. Otherwise we
-- find the statement for the associated location (via 'select'):
--
-- * for 'SBind' we constrain the /outcome/ of the bound measure;
-- * for 'SLet' we constrain the bound /value/.
--
-- In both cases the bound expression is first re-indexed with 'apply'.
-- Afterwards the location is pushed back as an 'SLet' of @v0@
-- (as a neutral term) at the current indices.
--
-- TODO: capture the emissibility requirement on the first argument
-- in the types.
constrainVariable
    :: (ABT Term abt) => abt '[] a -> Variable a -> Dis abt ()
constrainVariable v0 x =
    do extras <- getExtras
       -- If we get 'Nothing', then it turns out @x@ is a free variable.
       -- If @x@ is a free variable, then it's a neutral term; and we
       -- return 'bot' for neutral terms
       maybe bot lookForLoc (lookupAssoc x extras)
    where
    lookForLoc (Loc l jxs) =
        -- If we get 'Nothing', then it turns out @l@ is a free
        -- location. This is an error because of the invariant:
        --   if there exists an 'Assoc x (Loc l _)' inside @extras@
        --   then there must be a statement on the 'ListContext' that binds @l@
        (maybe (freeLocError l) return =<<) . select l $ \s ->
            case s of
            SBind l' e ixs -> do
                Refl <- locEq l l'
                guard (length ixs == length jxs) -- will error otherwise
                Just $ do
                    inds <- getIndices
                    guard (jxs `permutes` inds) -- will bot otherwise
                    e' <- apply ixs inds (fromLazy e)
                    constrainOutcome v0 e'
                    unsafePush (SLet l (Whnf_ (Neutral v0)) inds)
            SLet l' e ixs -> do
                Refl <- locEq l l'
                guard (length ixs == length jxs) -- will error otherwise
                Just $ do
                    inds <- getIndices
                    guard (jxs `permutes` inds) -- will bot otherwise
                    e' <- apply ixs inds (fromLazy e)
                    constrainValue v0 e'
                    unsafePush (SLet l (Whnf_ (Neutral v0)) inds)
            SWeight _ _ -> Nothing
            -- Wildcards: the fields are unused, and naming them trips
            -- -Wunused-matches under this module's -Wall.
            SGuard _ _ _ _ -> error "TODO: constrainVariable{SGuard}"

----------------------------------------------------------------
-- | Constrain the value of an array-operation application to be @v0@.
--
-- N.B., as with 'constrainValue', we assume that the first
-- argument is emissible. So it is the caller's responsibility to
-- ensure this (by calling 'atomize' as appropriate).
--
-- Only 'Index' is currently supported. The array argument is
-- evaluated to WHNF, and the result depends on its form:
--
-- * a neutral array yields 'bot';
-- * for a 'WArray', the index is substituted into the element body,
--   which is then constrained;
-- * an empty array yields 'bot';
-- * an array literal is delegated to 'constrainValueIdxArrLit'.
--
-- 'Size' and 'Reduce' are not yet handled.
--
-- TODO: capture the emissibility requirement on the first argument
-- in the types.
constrainValueArrayOp
    :: forall abt typs args a
    .  (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] a
    -> ArrayOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainValueArrayOp v0 = go
    where
      go :: ArrayOp typs a -> SArgs abt args -> Dis abt ()
      go (Index _) (e1 :* e2 :* End) = do
        w1 <- evaluate_ e1
        case w1 of
          -- The neutral term itself is not needed, so don't bind it
          -- (avoids -Wunused-matches under this module's -Wall).
          Neutral _ -> bot
          Head_ (WArray _ b) -> caseBind b $ \x body ->
                                extSubst x e2 body >>= constrainValue v0
          Head_ (WEmpty _) -> bot -- TODO: check this
          Head_ arr@(WArrayLiteral _) -> constrainValueIdxArrLit v0 e2 arr
          _ -> error "constrainValue {ArrayOp Index}: uknown whnf of array type"
      go (Size   _) _ = error "TODO: disintegrate {ArrayOp Size}"
      go (Reduce _) _ = error "TODO: disintegrate {ArrayOp Reduce}"
      go _ _ = error "constrainValueArrayOp: unknown arrayOp"


-- | Special case for @[true, false] ! i@ and @[false, true] ! i@,
-- which is what lets us disintegrate @bern@.
--
-- Given the observed value @v0@ of the indexing expression and the
-- index expression @e2@, a two-element array literal whose elements
-- are the boolean literals @true@ and @false@ (in either order) is
-- inverted by constraining @e2@ to an @if@ on @v0@. An array literal
-- of any other length yields 'bot'; the remaining unsupported shapes
-- are reported with 'error'.
constrainValueIdxArrLit
     :: forall abt a
     .  (ABT Term abt)
     => abt '[] a
     -> abt '[] 'HNat
     -> Head abt ('HArray a)
     -> Dis abt ()
constrainValueIdxArrLit v0 e2 = go
    where
      go :: Head abt ('HArray a) -> Dis abt ()
      go (WArrayLiteral [a1, a2]) =
          case jmEq1 (typeOf v0) sBool of
          Nothing   -> error "TODO: constrainValue (Index [a1,a2] i)"
          Just Refl ->
              let -- @indexIs t f@ constrains the index @e2@ to be @t@
                  -- when @v0@ is true and @f@ when @v0@ is false.
                  indexIs :: abt '[] 'HNat -> abt '[] 'HNat -> Dis abt ()
                  indexIs whenTrue whenFalse =
                      constrainValue (P.if_ v0 whenTrue whenFalse) e2
              in case (isLitBool a1, isLitBool a2) of
                 (Just b1, Just b2)
                     | isLitTrue  b1 && isLitFalse b2 ->
                         indexIs (P.nat_ 0) (P.nat_ 1)
                     | isLitFalse b1 && isLitTrue  b2 ->
                         indexIs (P.nat_ 1) (P.nat_ 0)
                     | otherwise ->
                         error "constrainValue: cannot invert (Index [b,b] i)"
                 _ -> error "TODO: constrainValue (Index [b1,b2] i)"
      go (WArrayLiteral _) = bot
      go _ = error "constrainValueIdxArrLit: unknown ArrayLiteral form"

-- | Helper for disintegrating @bern@. If the term is syntactically a
-- 'Datum_' whose type is 'HBool', return that datum with its type
-- index refined to 'HBool'. For a variable, any other syntactic form,
-- or a datum of some other type, return 'Nothing'.
isLitBool
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a -> Maybe (Datum (abt '[]) HBool)
isLitBool e = caseVarSyn e (\_ -> Nothing) fromSyn
    where
    -- Only a boolean-typed datum literal counts; the type check via
    -- 'jmEq1' is what lets us return it at type 'HBool'.
    fromSyn :: Term abt a -> Maybe (Datum (abt '[]) HBool)
    fromSyn (Datum_ d@(Datum _ typ _)) =
        case jmEq1 typ sBool of
        Just Refl -> Just d
        Nothing   -> Nothing
    fromSyn _ = Nothing

-- | Does this boolean datum denote @true@?
--
-- The 'Datum' is already indexed at 'HBool', so only the constructor
-- position matters. @true@ is encoded as the first nullary
-- constructor (@Inl Done@). The hint text and the 'Sing' are ignored.
-- They are matched with wildcards so they are not rebound to names
-- that shadow the top-level 'tdTrue' and 'sBool'; that shadowing
-- triggered @-Wname-shadowing@ and @-Wunused-matches@ under @-Wall@,
-- and misleadingly suggested the hint was being compared.
isLitTrue :: (ABT Term abt) => Datum (abt '[]) HBool -> Bool
isLitTrue (Datum _ _ (Inl Done)) = True
isLitTrue _                      = False

-- | Does this boolean datum denote @false@?
--
-- The 'Datum' is already indexed at 'HBool', so only the constructor
-- position matters. @false@ is encoded as the second nullary
-- constructor (@Inr (Inl Done)@). The hint text and the 'Sing' are
-- ignored. They are matched with wildcards so they are not rebound to
-- names that shadow the top-level 'tdFalse' and 'sBool'; that
-- shadowing triggered @-Wname-shadowing@ and @-Wunused-matches@ under
-- @-Wall@.
isLitFalse :: (ABT Term abt) => Datum (abt '[]) HBool -> Bool
isLitFalse (Datum _ _ (Inr (Inl Done))) = True
isLitFalse _                            = False

----------------------------------------------------------------
-- | N.B., as with 'constrainValue', we assume that the first
-- argument is emissible. So it is the caller's responsibility to
-- ensure this (by calling 'atomize' as appropriate).
--
-- Each supported primitive measure is rebuilt from its arguments with
-- the matching @P.@ smart constructor, and that term is then passed to
-- 'constrainValue' together with @v0@. 'Counting' yields 'bot'.
--
-- TODO: capture the emissibility requirement on the first argument
-- in the types.
constrainValueMeasureOp
    :: forall abt typs args a
    .  (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] ('HMeasure a)
    -> MeasureOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainValueMeasureOp v0 = go
    where
    -- Shorthand for constraining against the fixed first argument.
    constrainTo :: abt '[] ('HMeasure a) -> Dis abt ()
    constrainTo = constrainValue v0

    -- TODO: for Lebesgue and Counting we use @bot@ because that's
    -- what the old finally-tagless code seems to have been doing.
    -- But is that right, or should they really be @return ()@?
    go :: MeasureOp typs a -> SArgs abt args -> Dis abt ()
    go Lebesgue    = \(x :* y :* End) -> constrainTo (P.lebesgue' x y)
    go Counting    = \End             -> bot -- TODO: see note above
    go Categorical = \(x :* End)      -> constrainTo (P.categorical x)
    go Uniform     = \(x :* y :* End) -> constrainTo (P.uniform' x y)
    go Normal      = \(x :* y :* End) -> constrainTo (P.normal' x y)
    go Poisson     = \(x :* End)      -> constrainTo (P.poisson' x)
    go Gamma       = \(x :* y :* End) -> constrainTo (P.gamma' x y)
    go Beta        = \(x :* y :* End) -> constrainTo (P.beta' x y)

----------------------------------------------------------------
-- | N.B., We assume that the first argument, @v0@, is already
-- atomized. So, this must be ensured before recursing, but we can
-- assume it's already been done by the IH.
--
-- N.B., we also rely on the fact that our 'HSemiring' instances
-- are actually all /commutative/ semirings. If that ever becomes
-- not the case, then we'll need to fix things here.
--
-- As written, this will do a lot of redundant work in atomizing
-- the subterms other than the one we choose to go backward on.
-- Because evaluation has side-effects on the heap and is heap
-- dependent, it seems like there may not be a way around that
-- issue. (I.e., we could use dynamic programming to efficiently
-- build up the 'M' computations, but not to evaluate them.) Of
-- course, really we shouldn't be relying on the structure of the
-- program here; really we should be looking at the heap-bound
-- variables in the term: choosing each @x@ to go backward on, treat
-- the term as a function of @x@, atomize that function (hence going
-- forward on the rest of the variables), and then invert it and
-- get the Jacobian.
--
-- TODO: find some way to capture in the type that the first argument
-- must be emissible.
constrainNaryOp
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a
    -> NaryOp a
    -> Seq (abt '[] a)
    -> Dis abt ()
constrainNaryOp v0 o =
    case o of
        -- v0 = rest + e, so go backward on e with e = v0 - rest.
        Sum theSemi -> lubSeq $ \before e after -> do
            rest   <- atomizeOthers before after
            target <- evaluate_ (P.unsafeMinus_ theSemi v0 rest)
            constrainValue (fromWhnf target) e
        -- v0 = rest * e, so e = v0 / rest, weighted by 1/|rest| to
        -- account for the scaling of the map x |-> rest * x.
        Prod theSemi -> lubSeq $ \before e after -> do
            rest   <- atomizeOthers before after -- TODO: emitLet?
            emitWeight . P.recip $ toProb_abs theSemi rest
            target <- evaluate_ (P.unsafeDiv_ theSemi v0 rest)
            constrainValue (fromWhnf target) e
        -- v0 = max rest e with e chosen as the maximum: require
        -- rest < v0 and constrain e to be v0 itself.
        -- NOTE(review): the strict guard drops ties (rest == v0); that
        -- is measure-zero for continuous types but may matter for
        -- discrete ones -- confirm intended.
        Max theOrd -> chooseSeq $ \before e after -> do
            rest <- atomizeOthers before after
            emitGuard (P.primOp2_ (Less theOrd) rest v0)
            constrainValue v0 e
        _ -> error ("TODO: constrainNaryOp{" ++ show o ++ "}")
    where
    -- Go forward on every operand except the one being inverted, by
    -- atomizing the same n-ary operation over the remaining operands.
    atomizeOthers
        :: Seq (abt '[] a) -> Seq (abt '[] a) -> Dis abt (abt '[] a)
    atomizeOthers before after =
        fromWhnf <$> atomize (syn (NaryOp_ o (before S.>< after)))


-- TODO: if this function (or the component @toProb@ and @semiringAbs@
-- parts) turn out to be useful elsewhere, then we should move it
-- to the Prelude.

-- | Take the absolute value of a semiring element and view it as an
-- 'HProb'. Naturals are converted directly, integers and reals go
-- through 'P.abs_' first, and probabilities are returned unchanged.
toProb_abs :: (ABT Term abt) => HSemiring a -> abt '[] a -> abt '[] 'HProb
toProb_abs theSemi e =
    case theSemi of
    HSemiring_Nat  -> P.nat2prob e
    HSemiring_Int  -> P.nat2prob (P.abs_ e)
    HSemiring_Prob -> e
    HSemiring_Real -> P.abs_ e


-- TODO: is there any way to optimise the zippering over the Seq, a la
-- 'S.inits' or 'S.tails'?
-- TODO: really we want a dynamic programming
-- approach to avoid unnecessary repetition of calling @evaluateNaryOp
-- evaluate_@ on the two subsequences...

-- | Run @f before x after@ for every way of splitting the sequence
-- around a single element @x@, from left to right, and combine the
-- results with '<|>' (right-nested, ending in 'empty'). An empty
-- sequence yields 'empty'.
lubSeq :: (Alternative m) => (Seq a -> a -> Seq a -> m b) -> Seq a -> m b
lubSeq f = foldr (<|>) empty . splits S.empty
    where
    -- @splits before rest@ applies @f@ at each position of @rest@,
    -- carrying the already-visited prefix in @before@.
    splits before rest =
        case S.viewl rest of
        S.EmptyL     -> []
        x S.:< after -> f before x after : splits (before S.|> x) after

-- | For every way of splitting the sequence around a single element
-- @x@, build the alternative @f before x after@, collecting them in
-- left-to-right order, and hand the resulting list to 'choose'.
-- An empty sequence gives @choose []@.
chooseSeq :: (ABT Term abt)
          => (Seq a -> a -> Seq a -> Dis abt b)
          -> Seq a
          -> Dis abt b
chooseSeq f es = choose (alternatives S.empty es)
    where
    -- @alternatives before rest@ applies @f@ at each position of
    -- @rest@, carrying the already-visited prefix in @before@.
    alternatives before rest =
        case S.viewl rest of
        S.EmptyL     -> []
        x S.:< after -> f before x after : alternatives (before S.|> x) after


----------------------------------------------------------------
-- HACK: for a lot of these, we can't use the prelude functions
-- because Haskell can't figure out our polymorphism, so we have
-- to define our own versions for manually passing dictionaries
-- around.
--
-- | N.B., We assume that the first argument, @v0@, is already
-- atomized. So, this must be ensured before recursing, but we can
-- assume it's already been done by the IH.
--
-- TODO: find some way to capture in the type that the first argument
-- must be emissible.
constrainPrimOp
    :: forall abt typs args a
    .  (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] a
    -> PrimOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainPrimOp :: abt '[] a -> PrimOp typs a -> SArgs abt args -> Dis abt ()
constrainPrimOp abt '[] a
v0 = PrimOp typs a -> SArgs abt args -> Dis abt ()
go
    where
    error_TODO :: [Char] -> a
error_TODO [Char]
op = [Char] -> a
forall a. HasCallStack => [Char] -> a
error ([Char] -> a) -> [Char] -> a
forall a b. (a -> b) -> a -> b
$ [Char]
"TODO: constrainPrimOp{" [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++ [Char]
op [Char] -> [Char] -> [Char]
forall a. [a] -> [a] -> [a]
++[Char]
"}"

    go :: PrimOp typs a -> SArgs abt args -> Dis abt ()
    go :: PrimOp typs a -> SArgs abt args -> Dis abt ()
go PrimOp typs a
Not  = \(abt vars a
e1 :* SArgs abt args
End)       -> [Char] -> Dis abt ()
forall a. [Char] -> a
error_TODO [Char]
"Not"
    go PrimOp typs a
Impl = \(abt vars a
e1 :* abt vars a
e2 :* SArgs abt args
End) -> [Char] -> Dis abt ()
forall a. [Char] -> a
error_TODO [Char]
"Impl"
    go PrimOp typs a
Diff = \(abt vars a
e1 :* abt vars a
e2 :* SArgs abt args
End) -> [Char] -> Dis abt ()
forall a. [Char] -> a
error_TODO [Char]
"Diff"
    go PrimOp typs a
Nand = \(abt vars a
e1 :* abt vars a
e2 :* SArgs abt args
End) -> [Char] -> Dis abt ()
forall a. [Char] -> a
error_TODO [Char]
"Nand"
    go PrimOp typs a
Nor  = \(abt vars a
e1 :* abt vars a
e2 :* SArgs abt args
End) -> [Char] -> Dis abt ()
forall a. [Char] -> a
error_TODO [Char]
"Nor"

    go PrimOp typs a
Pi = \SArgs abt args
End -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) a.
ABT Term abt =>
Dis abt a
bot -- because @dirac pi@ has no density wrt lebesgue

    go PrimOp typs a
Sin = \(abt vars a
e1 :* SArgs abt args
End) -> do
        abt '[] a
x0 <- abt '[] a -> Dis abt (abt '[] a)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> Dis abt (abt '[] a)
emitLet' abt '[] a
v0
        abt '[] 'HInt
n  <- Variable 'HInt -> abt '[] 'HInt
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (a :: k).
ABT syn abt =>
Variable a -> abt '[] a
var (Variable 'HInt -> abt '[] 'HInt)
-> Dis abt (Variable 'HInt) -> Dis abt (abt '[] 'HInt)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> abt '[] ('HMeasure 'HInt) -> Dis abt (Variable 'HInt)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] ('HMeasure a) -> Dis abt (Variable a)
emitMBind abt '[] ('HMeasure 'HInt)
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
abt '[] ('HMeasure 'HInt)
P.counting
        let tau_n :: abt '[] 'HReal
tau_n = Rational -> abt '[] 'HReal
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
Rational -> abt '[] 'HReal
P.real_ Rational
2 abt '[] 'HReal -> abt '[] 'HReal -> abt '[] 'HReal
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HSemiring_ a) =>
abt '[] a -> abt '[] a -> abt '[] a
P.* abt '[] 'HInt -> abt '[] 'HReal
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
abt '[] 'HInt -> abt '[] 'HReal
P.fromInt abt '[] 'HInt
n abt '[] 'HReal -> abt '[] 'HReal -> abt '[] 'HReal
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HSemiring_ a) =>
abt '[] a -> abt '[] a -> abt '[] a
P.* abt '[] 'HReal
forall (a :: Hakaru) (abt :: [Hakaru] -> Hakaru -> *).
(RealProb a, ABT Term abt) =>
abt '[] a
P.pi -- TODO: emitLet?
        abt '[] HBool -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
abt '[] HBool -> Dis abt ()
emitGuard (abt '[] a -> abt '[] a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HRing_ a) =>
abt '[] a -> abt '[] a
P.negate abt '[] a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HSemiring_ a) =>
abt '[] a
P.one abt '[] a -> abt '[] a -> abt '[] HBool
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HOrd_ a) =>
abt '[] a -> abt '[] a -> abt '[] HBool
P.< abt '[] a
x0 abt '[] HBool -> abt '[] HBool -> abt '[] HBool
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
abt '[] HBool -> abt '[] HBool -> abt '[] HBool
P.&& abt '[] a
x0 abt '[] a -> abt '[] a -> abt '[] HBool
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HOrd_ a) =>
abt '[] a -> abt '[] a -> abt '[] HBool
P.< abt '[] a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HSemiring_ a) =>
abt '[] a
P.one)
        abt '[] 'HReal
v  <- Variable 'HReal -> abt '[] 'HReal
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (a :: k).
ABT syn abt =>
Variable a -> abt '[] a
var (Variable 'HReal -> abt '[] 'HReal)
-> Dis abt (Variable 'HReal) -> Dis abt (abt '[] 'HReal)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> [abt '[] ('HMeasure 'HReal)] -> Dis abt (Variable 'HReal)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
[abt '[] ('HMeasure a)] -> Dis abt (Variable a)
emitSuperpose
            [ abt '[] 'HReal -> abt '[] ('HMeasure 'HReal)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] ('HMeasure a)
P.dirac (abt '[] 'HReal
tau_n abt '[] 'HReal -> abt '[] 'HReal -> abt '[] 'HReal
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HSemiring_ a) =>
abt '[] a -> abt '[] a -> abt '[] a
P.+ abt '[] 'HReal -> abt '[] 'HReal
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
abt '[] 'HReal -> abt '[] 'HReal
P.asin abt '[] a
abt '[] 'HReal
x0)
            , abt '[] 'HReal -> abt '[] ('HMeasure 'HReal)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] ('HMeasure a)
P.dirac (abt '[] 'HReal
tau_n abt '[] 'HReal -> abt '[] 'HReal -> abt '[] 'HReal
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HSemiring_ a) =>
abt '[] a -> abt '[] a -> abt '[] a
P.+ abt '[] 'HReal
forall (a :: Hakaru) (abt :: [Hakaru] -> Hakaru -> *).
(RealProb a, ABT Term abt) =>
abt '[] a
P.pi abt '[] 'HReal -> abt '[] 'HReal -> abt '[] 'HReal
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HRing_ a) =>
abt '[] a -> abt '[] a -> abt '[] a
P.- abt '[] 'HReal -> abt '[] 'HReal
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
abt '[] 'HReal -> abt '[] 'HReal
P.asin abt '[] a
abt '[] 'HReal
x0)
            ]
        abt '[] 'HProb -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
abt '[] 'HProb -> Dis abt ()
emitWeight
            (abt '[] 'HProb -> Dis abt ())
-> (abt '[] 'HReal -> abt '[] 'HProb)
-> abt '[] 'HReal
-> Dis abt ()
forall b c a. (b -> c) -> (a -> b) -> a -> c
. abt '[] 'HProb -> abt '[] 'HProb
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HFractional_ a) =>
abt '[] a -> abt '[] a
P.recip
            (abt '[] 'HProb -> abt '[] 'HProb)
-> (abt '[] 'HReal -> abt '[] 'HProb)
-> abt '[] 'HReal
-> abt '[] 'HProb
forall b c a. (b -> c) -> (a -> b) -> a -> c
. abt '[] 'HProb -> abt '[] 'HProb
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HRadical_ a) =>
abt '[] a -> abt '[] a
P.sqrt
            (abt '[] 'HProb -> abt '[] 'HProb)
-> (abt '[] 'HReal -> abt '[] 'HProb)
-> abt '[] 'HReal
-> abt '[] 'HProb
forall b c a. (b -> c) -> (a -> b) -> a -> c
. abt '[] 'HReal -> abt '[] 'HProb
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
abt '[] 'HReal -> abt '[] 'HProb
P.unsafeProb
            (abt '[] 'HReal -> Dis abt ()) -> abt '[] 'HReal -> Dis abt ()
forall a b. (a -> b) -> a -> b
$ (abt '[] a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HSemiring_ a) =>
abt '[] a
P.one abt '[] a -> abt '[] a -> abt '[] a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HRing_ a) =>
abt '[] a -> abt '[] a -> abt '[] a
P.- abt '[] a
x0 abt '[] a -> abt '[] 'HNat -> abt '[] a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
(ABT Term abt, HSemiring_ a) =>
abt '[] a -> abt '[] 'HNat -> abt '[] a
P.^ Natural -> abt '[] 'HNat
forall (abt :: [Hakaru] -> Hakaru -> *).
ABT Term abt =>
Natural -> abt '[] 'HNat
P.nat_ Natural
2)
        abt '[] 'HReal -> abt '[] 'HReal -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] a -> Dis abt ()
constrainValue abt '[] 'HReal
v abt vars a
abt '[] 'HReal
e1

    go Cos = \(e1 :* End) -> do
        -- Solve @cos e1 == v0@. Since cos is even with period 2*pi,
        -- every solution has the form @2*pi*n + acos v0@ or
        -- @2*pi*n - acos v0@ for some integer @n@. We enumerate @n@ with
        -- a counting measure and superpose the two families.
        x0 <- emitLet' v0
        n  <- var <$> emitMBind P.counting
        let tau_n = P.real_ 2 P.* P.fromInt n P.* P.pi -- TODO: emitLet?
        -- acos is only defined on [-1, 1].
        emitGuard (P.negate P.one P.< x0 P.&& x0 P.< P.one)
        acos_x0 <- emitLet' (P.acos x0)
        -- N.B., the second family must be @tau_n - acos x0@. Shifting by
        -- pi instead (@tau_n + acos x0 + pi@) yields points whose cosine
        -- is @-x0@, which are not preimages of @x0@.
        v  <- var <$> emitSuperpose
            [ P.dirac (tau_n P.+ acos_x0)
            , P.dirac (tau_n P.- acos_x0)
            ]
        -- Jacobian of the change of variables:
        -- |d(cos v)/dv| = |sin v| = sqrt (1 - x0^2) at either branch.
        emitWeight
            . P.recip
            . P.sqrt
            . P.unsafeProb
            $ (P.one P.- x0 P.^ P.nat_ 2)
        constrainValue v e1

    go Tan = \(e1 :* End) -> do
        -- tan has period pi, so the solutions of @tan e1 == v0@ are
        -- @n*pi + atan v0@ for integer @n@. The integer is drawn from a
        -- counting measure, and no range guard is emitted here.
        x0 <- emitLet' v0
        n  <- fmap var (emitMBind P.counting)
        let preimage = P.fromInt n P.* P.pi P.+ P.atan x0
        r  <- emitLet' preimage
        -- Jacobian of the change of variables: d(atan y)/dy = 1 / (1 + y^2).
        emitWeight (P.recip (P.one P.+ P.square x0))
        constrainValue r e1

    go Asin = \(e1 :* End) -> do
        -- Observing @asin e1 == v0@ means @e1 == sin v0@. This only has
        -- a solution when @v0@ lies in asin's range [-pi/2, pi/2].
        x0 <- emitLet' v0
        -- Outside that range the density must be zero. Without this
        -- guard @cos x0@ could also be negative, which would make the
        -- 'P.unsafeProb' below unsound.
        let halfPi = P.pi P./ P.real_ 2
        emitGuard (P.negate halfPi P.< x0 P.&& x0 P.< halfPi)
        -- Jacobian: |d(sin y)/dy| = cos y, which is non-negative on the
        -- interval guarded above.
        emitWeight $ P.unsafeProb (P.cos x0)
        constrainValue (P.sin x0) e1

    go Acos = \(e1 :* End) -> do
        -- Observing @acos e1 == v0@ means @e1 == cos v0@. This only has
        -- a solution when @v0@ lies in acos's range [0, pi].
        x0 <- emitLet' v0
        -- Outside that range the density must be zero. Without this
        -- guard @sin x0@ could also be negative, which would make the
        -- 'P.unsafeProb' below unsound.
        emitGuard (P.zero P.< x0 P.&& x0 P.< P.pi)
        -- Jacobian: |d(cos y)/dy| = sin y, which is non-negative on the
        -- interval guarded above.
        emitWeight $ P.unsafeProb (P.sin x0)
        constrainValue (P.cos x0) e1

    go Atan = \(e1 :* End) -> do
        -- Observing @atan e1 == v0@ means @e1 == tan v0@. This only has
        -- a solution when @v0@ lies in atan's range (-pi/2, pi/2).
        x0 <- emitLet' v0
        -- Outside that range the density must be zero.
        let halfPi = P.pi P./ P.real_ 2
        emitGuard (P.negate halfPi P.< x0 P.&& x0 P.< halfPi)
        -- Jacobian: d(tan y)/dy = 1 / cos(y)^2.
        emitWeight $ P.recip (P.unsafeProb (P.cos x0 P.^ P.nat_ 2))
        constrainValue (P.tan x0) e1

    -- The following primitives are stubbed out with 'error_TODO': no
    -- disintegration rule is implemented for them here.
    go Sinh   = \(_ :* End)      -> error_TODO "Sinh"
    go Cosh   = \(_ :* End)      -> error_TODO "Cosh"
    go Tanh   = \(_ :* End)      -> error_TODO "Tanh"
    go Asinh  = \(_ :* End)      -> error_TODO "Asinh"
    go Acosh  = \(_ :* End)      -> error_TODO "Acosh"
    go Atanh  = \(_ :* End)      -> error_TODO "Atanh"
    go Choose = \(_ :* _ :* End) -> error_TODO "Choose"
    go Floor  = \(_ :* End)      -> error_TODO "Floor"
    go RealPow = \(e1 :* e2 :* End) ->
        -- Observing @e1 ** e2 == v0@: either solve for the exponent @e2@
        -- (first alternative) or for the base @e1@ (second alternative).
        -- TODO: There's a discrepancy between @(**)@ and @pow_@ in
        -- the old code...
        do
            -- Solve for the exponent: @e2 == log v0 / log e1@, with
            -- Jacobian |d(e1 ** y)/dy| = |v0 * log e1|.
            -- TODO: if @v1@ is 0 or 1 then bot. Maybe the @log v1@ in
            -- @w@ takes care of the 0 case?
            u <- emitLet' v0
            -- either this from @(**)@:
            --   emitGuard  $ P.zero P.< u
            --   w <- atomize $ P.recip (P.abs (v0 P.* P.log v1))
            --   emitWeight $ P.unsafeProb (fromWhnf w)
            --   constrainValue (P.logBase v1 v0) e2
            -- or this from @pow_@:
            let w = P.recip (u P.* P.unsafeProb (P.abs (P.log e1)))
            emitWeight w
            constrainValue (P.log u P./ P.log e1) e2
            -- end.
        <|> do
            -- Solve for the base: @e1 == v0 ** (1 / e2)@, with Jacobian
            -- |d(y ** e2)/dy| = |e2 * v0 / e1|.
            -- TODO: if @v2@ is 0 then bot. Maybe the weight @w@ takes
            -- care of this case?
            u <- emitLet' v0
            -- Use the let-bound @u@ rather than @v0@ so the observed term
            -- is shared instead of being duplicated in the residual program.
            let ex = u P.** P.recip e2
            -- either this from @(**)@:
            --   emitGuard $ P.zero P.< u
            --   w <- atomize $ abs (ex / (v2 * v0))
            -- or this from @pow_@:
            let w = P.abs (P.fromProb ex P./ (e2 P.* P.fromProb u))
            -- end.
            emitWeight $ P.unsafeProb w
            constrainValue ex e1
    go Exp = \(e1 :* End) -> do
        -- Observing @exp e1 == v0@ means @e1 == log v0@. The change of
        -- variables contributes the factor |d(log y)/dy| = 1 / v0.
        x0 <- emitLet' v0
        -- TODO: do we still want\/need the @emitGuard (0 < x0)@ which
        -- is now equivalent to @emitGuard (0 /= x0)@ thanks to the
        -- types?
        emitWeight $ P.recip x0
        constrainValue (P.log x0) e1

    go Log = \(e1 :* End) ->
        -- Observing @log e1 == v0@ means @e1 == exp v0@. The change of
        -- variables contributes the factor d(exp y)/dy = exp v0, so the
        -- same let-bound term serves as both weight and target value.
        emitLet' (P.exp v0) >>= \expV0 -> do
            emitWeight expV0
            constrainValue expV0 e1

    -- The following primitives are stubbed out with 'error_TODO': no
    -- disintegration rule is implemented for them here.
    go (Infinity _) = \End             -> error_TODO "Infinity"  -- scalar0
    go GammaFunc    = \(_ :* End)      -> error_TODO "GammaFunc" -- scalar1
    go BetaFunc     = \(_ :* _ :* End) -> error_TODO "BetaFunc"  -- scalar2
    go (Equal  _)   = \(_ :* _ :* End) -> error_TODO "Equal"
    go (Less   _)   = \(_ :* _ :* End) -> error_TODO "Less"
    go (NatPow _)   = \(_ :* _ :* End) -> error_TODO "NatPow"
    go (Negate theRing) = \(e1 :* End) ->
        -- Observing @negate e1 == v0@ means @e1 == negate v0@. No weight
        -- is emitted for this case: we directly constrain @e1@ to the
        -- negation of @v0@.
        --
        -- TODO: figure out how to merge this implementation of @rr1
        -- negate@ with the one in 'evaluatePrimOp' to DRY
        -- TODO: just
        -- emitLet the @v0@ and pass the neutral term to the recursive
        -- call?
        --
        -- A possible partially-evaluating alternative for the negated term:
        -- case v0 of
        -- Neutral e ->
        --     Neutral $ syn (PrimOp_ (Negate theRing) :$ e :* End)
        -- Head_ v ->
        --     case theRing of
        --     HRing_Int  -> Head_ . reflect . negate $ reify v
        --     HRing_Real -> Head_ . reflect . negate $ reify v
        constrainValue (syn (PrimOp_ (Negate theRing) :$ v0 :* End)) e1

    go (Abs HRing a
theRing) = \(abt vars a
e1 :* SArgs abt args
End) -> do
        let theSemi :: HSemiring a
theSemi = HRing a -> HSemiring a
forall (a :: Hakaru). HRing a -> HSemiring a
hSemiring_HRing HRing a
theRing
            theOrd :: HOrd a
theOrd  =
                case HRing a
theRing of
                HRing a
HRing_Int  -> HOrd a
HOrd 'HInt
HOrd_Int
                HRing a
HRing_Real -> HOrd a
HOrd 'HReal
HOrd_Real
            theEq :: HEq a
theEq   = HOrd a -> HEq a
forall (a :: Hakaru). HOrd a -> HEq a
hEq_HOrd HOrd a
theOrd
            signed :: Coercion a a
signed  = PrimCoercion a a -> Coercion a a
forall (a :: Hakaru) (b :: Hakaru).
PrimCoercion a b -> Coercion a b
C.singletonCoercion (HRing a -> PrimCoercion (NonNegative a) a
forall (a :: Hakaru). HRing a -> PrimCoercion (NonNegative a) a
C.Signed HRing a
theRing)
            zero :: abt '[] a
zero    = HSemiring a -> abt '[] a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
HSemiring a -> abt '[] a
P.zero_ HSemiring a
theSemi
            lt :: abt '[] a -> abt '[] a -> abt '[] HBool
lt      = PrimOp '[a, a] HBool -> abt '[] a -> abt '[] a -> abt '[] HBool
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru) (b :: Hakaru)
       (c :: Hakaru).
ABT Term abt =>
PrimOp '[a, b] c -> abt '[] a -> abt '[] b -> abt '[] c
P.primOp2_ (PrimOp '[a, a] HBool -> abt '[] a -> abt '[] a -> abt '[] HBool)
-> PrimOp '[a, a] HBool -> abt '[] a -> abt '[] a -> abt '[] HBool
forall a b. (a -> b) -> a -> b
$ HOrd a -> PrimOp '[a, a] HBool
forall (a :: Hakaru). HOrd a -> PrimOp '[a, a] HBool
Less   HOrd a
theOrd
            eq :: abt '[] a -> abt '[] a -> abt '[] HBool
eq      = PrimOp '[a, a] HBool -> abt '[] a -> abt '[] a -> abt '[] HBool
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru) (b :: Hakaru)
       (c :: Hakaru).
ABT Term abt =>
PrimOp '[a, b] c -> abt '[] a -> abt '[] b -> abt '[] c
P.primOp2_ (PrimOp '[a, a] HBool -> abt '[] a -> abt '[] a -> abt '[] HBool)
-> PrimOp '[a, a] HBool -> abt '[] a -> abt '[] a -> abt '[] HBool
forall a b. (a -> b) -> a -> b
$ HEq a -> PrimOp '[a, a] HBool
forall (a :: Hakaru). HEq a -> PrimOp '[a, a] HBool
Equal  HEq a
theEq
            neg :: abt '[] a -> abt '[] a
neg     = PrimOp '[a] a -> abt '[] a -> abt '[] a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru)
       (b :: Hakaru).
ABT Term abt =>
PrimOp '[a] b -> abt '[] a -> abt '[] b
P.primOp1_ (PrimOp '[a] a -> abt '[] a -> abt '[] a)
-> PrimOp '[a] a -> abt '[] a -> abt '[] a
forall a b. (a -> b) -> a -> b
$ HRing a -> PrimOp '[a] a
forall (a :: Hakaru). HRing a -> PrimOp '[a] a
Negate HRing a
theRing

        abt '[] a
x0 <- abt '[] a -> Dis abt (abt '[] a)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> Dis abt (abt '[] a)
emitLet' (Coercion a a -> abt '[] a -> abt '[] a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru)
       (b :: Hakaru).
ABT Term abt =>
Coercion a b -> abt '[] a -> abt '[] b
P.coerceTo_ Coercion a a
signed abt '[] a
v0)
        abt '[] a
v  <- Variable a -> abt '[] a
forall k (syn :: ([k] -> k -> *) -> k -> *) (abt :: [k] -> k -> *)
       (a :: k).
ABT syn abt =>
Variable a -> abt '[] a
var (Variable a -> abt '[] a)
-> Dis abt (Variable a) -> Dis abt (abt '[] a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> abt '[] ('HMeasure a) -> Dis abt (Variable a)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] ('HMeasure a) -> Dis abt (Variable a)
emitMBind
            (abt '[] HBool
-> abt '[] ('HMeasure a)
-> abt '[] ('HMeasure a)
-> abt '[] ('HMeasure a)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] HBool -> abt '[] a -> abt '[] a -> abt '[] a
P.if_ (abt '[] a -> abt '[] a -> abt '[] HBool
lt abt '[] a
zero abt '[] a
x0)
                (abt '[] a -> abt '[] ('HMeasure a)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] ('HMeasure a)
P.dirac abt '[] a
x0 abt '[] ('HMeasure a)
-> abt '[] ('HMeasure a) -> abt '[] ('HMeasure a)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] ('HMeasure a)
-> abt '[] ('HMeasure a) -> abt '[] ('HMeasure a)
P.<|> abt '[] a -> abt '[] ('HMeasure a)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] ('HMeasure a)
P.dirac (abt '[] a -> abt '[] a
neg abt '[] a
x0))
                (abt '[] HBool
-> abt '[] ('HMeasure a)
-> abt '[] ('HMeasure a)
-> abt '[] ('HMeasure a)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] HBool -> abt '[] a -> abt '[] a -> abt '[] a
P.if_ (abt '[] a -> abt '[] a -> abt '[] HBool
eq abt '[] a
zero abt '[] a
x0)
                    (abt '[] a -> abt '[] ('HMeasure a)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] ('HMeasure a)
P.dirac abt '[] a
zero)
                    (Sing ('HMeasure a) -> abt '[] ('HMeasure a)
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
Sing ('HMeasure a) -> abt '[] ('HMeasure a)
P.reject (Sing ('HMeasure a) -> abt '[] ('HMeasure a))
-> (Sing a -> Sing ('HMeasure a))
-> Sing a
-> abt '[] ('HMeasure a)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Sing a -> Sing ('HMeasure a)
forall (a :: Hakaru). Sing a -> Sing ('HMeasure a)
SMeasure (Sing a -> abt '[] ('HMeasure a))
-> Sing a -> abt '[] ('HMeasure a)
forall a b. (a -> b) -> a -> b
$ abt '[] a -> Sing a
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> Sing a
typeOf abt '[] a
zero)))
        abt '[] a -> abt '[] a -> Dis abt ()
forall (abt :: [Hakaru] -> Hakaru -> *) (a :: Hakaru).
ABT Term abt =>
abt '[] a -> abt '[] a -> Dis abt ()
constrainValue abt '[] a
v abt vars a
abt '[] a
e1

    go (Signum theRing) = \(e1 :* End) ->
        -- Integer 'signum' is handled by guessing: draw a fresh @x@ from
        -- the counting measure, require @signum x == v0@, and then
        -- constrain @e1@ to be that @x@. Real 'signum' is not handled,
        -- so that case gives up via 'bot'.
        case theRing of
            HRing_Int  ->
                fmap var (emitMBind P.counting) >>= \x ->
                    emitGuard (P.signum x P.== v0)
                    >> constrainValue x e1
            HRing_Real -> bot

    go (Recip theFractional) = \(e1 :* End) -> do
        x0 <- emitLet' v0
        -- Observing @x0 == recip e1@ means @e1 == recip x0@; the weight
        -- emitted below is the Jacobian magnitude @recip (x0^2)@,
        -- coerced to 'HProb'.
        -- TODO: define a dictionary-passing variant of 'P.square'
        -- instead, to include the coercion in there explicitly...
        let x0Squared = square (hSemiring_HFractional theFractional) x0
        emitWeight (P.recip (P.unsafeProbFraction_ theFractional x0Squared))
        constrainValue (P.primOp1_ (Recip theFractional) x0) e1

    go (NatRoot theRadical) = \(e1 :* e2 :* End) ->
        case theRadical of
        HRadical_Prob -> do
            x0 <- emitLet' v0
            u2 <- fromWhnf <$> atomize e2
            -- Observing @x0 == natroot e1 u2@ amounts to constraining
            -- @e1 == x0 ^ u2@, so by change of variables the weight must
            -- be the Jacobian @d(x0^n)/d(x0) = n * x0^(n-1)@. Weighting by
            -- just @n * x0@ is only correct when @n == 2@ (i.e., 'sqrt').
            -- NOTE(review): assumes @u2 >= 1@; @natroot _ 0@ is undefined
            -- anyway, and 'P.unsafeMinusNat' does not guard against it.
            emitWeight
                (P.nat2prob u2 P.* (x0 P.^ P.unsafeMinusNat u2 (P.nat_ 1)))
            constrainValue (x0 P.^ u2) e1

    -- Disintegrating 'Erf' would require its inverse, which is not
    -- available yet; the pattern still checks the argument shape.
    go (Erf _) = \(_ :* End) ->
        error "TODO: constrainPrimOp: need InvErf to disintegrate Erf"


-- | Build the term @e ^ 2@ directly as a 'NatPow' primop, using the
-- explicitly supplied 'HSemiring' dictionary rather than a class
-- constraint.
--
-- HACK: can't use @(P.^)@ because Haskell can't figure out our polymorphism
square :: (ABT Term abt) => HSemiring a -> abt '[] a -> abt '[] a
square theSemiring e =
    syn $ PrimOp_ (NatPow theSemiring) :$ (e :* P.nat_ 2 :* End)


----------------------------------------------------------------
----------------------------------------------------------------
-- TODO: do we really want the first argument to be a term at all,
-- or do we want something more general like patterns for capturing
-- measurable events?
--
-- | This is a helper function for 'constrainValue' to handle 'SBind'
-- statements (just as the 'perform' argument to 'evaluate' is a
-- helper for handling 'SBind' statements).
--
-- N.B., we assume that the first argument, @v0@, is already
-- atomized. So, this must be ensured before recursing, but we can
-- assume it's already been done by the IH. Technically, we don't
-- care whether the first argument is in normal form or not, just
-- so long as it doesn't contain any heap-bound variables.
--
-- This is the function called @(<<|)@ in the paper, though notably
-- we swap the argument order.
--
-- TODO: find some way to capture in the type that the first argument
-- must be emissible, to help avoid accidentally passing the arguments
-- in the wrong order!
--
-- TODO: under what circumstances is @constrainOutcome x m@ different
-- from @constrainValue x =<< perform m@? If they're always the
-- same, then we should just use that as the definition in order
-- to avoid repeating ourselves
constrainOutcome
    :: forall abt a
    .  (ABT Term abt)
    => abt '[] a
    -> abt '[] ('HMeasure a)
    -> Dis abt ()
constrainOutcome v0 e0 =
#ifdef __TRACE_DISINTEGRATE__
    getExtras >>= \extras ->
    getIndices >>= \inds ->
    trace (
        let s = "-- constrainOutcome"
        in "\n" ++ s ++ ": "
            ++ show (pretty v0)
            ++ "\n" ++ replicate (length s) ' ' ++ ": "
            ++ show (pretty e0) ++ "\n"
            ++ "at " ++  show (ppInds inds) ++ "\n"
            ++ show (prettyExtras extras)
          ) $
#endif
    do  w0 <- evaluate_ e0
        case w0 of
            -- A neutral (stuck) measure cannot be inspected, so we give
            -- up on this branch via 'bot'.
            Neutral _ -> bot
            Head_   v -> go v
    where
    impossible = error "constrainOutcome: the impossible happened"

    -- Dispatch on the head-normal form of the measure being constrained.
    go :: Head abt ('HMeasure a) -> Dis abt ()
    go (WLiteral _)          = impossible
    -- go (WDatum _)         = impossible
    -- go (WEmpty _)         = impossible
    -- go (WArray _ _)       = impossible
    -- go (WLam _)           = impossible
    -- go (WIntegrate _ _ _) = impossible
    -- go (WSummate   _ _ _) = impossible
    go (WCoerceTo   _ _)     = impossible
    go (WUnsafeFrom _ _)     = impossible
    go (WMeasureOp o es)     = constrainOutcomeMeasureOp v0 o es
    go (WDirac e1)           = constrainValue v0 e1
    go (WMBind e1 e2)        =
        -- Push the bound measure onto the heap as a lazy 'SBind' under
        -- the current indices, then keep constraining the body.
        caseBind e2 $ \x e2' -> do
            i <- getIndices
            push (SBind x (Thunk e1) i) e2' >>= constrainOutcome v0
    go (WPlate e1 e2)        = do
        x' <- pushPlate e1 e2
        constrainValue v0 x'
    go (WChain _ _ _)        = error "TODO: constrainOutcome{Chain}"
    go (WReject _)           = emit_ $ \m -> P.reject (typeOf m)
    go (WSuperpose pes) = do
        i <- getIndices
        -- Refuse (via 'bot') to fork into several branches while
        -- 'getIndices' is non-empty; otherwise push each branch's weight
        -- and constrain each branch's measure separately.
        if not (null i) && L.length pes > 1 then bot else
          emitFork_ (P.superpose . fmap ((,) P.one))
                    (fmap (\(p,e) -> push (SWeight (Thunk p) i) e >>= constrainOutcome v0)
                          pes)

-- TODO: should this really be different from 'constrainValueMeasureOp'?
--
-- TODO: find some way to capture in the type that the first argument
-- must be emissible.
constrainOutcomeMeasureOp
    :: (ABT Term abt, typs ~ UnLCs args, args ~ LCs typs)
    => abt '[] a
    -> MeasureOp typs a
    -> SArgs abt args
    -> Dis abt ()
constrainOutcomeMeasureOp v0 = go
    where
    -- Each case below constrains @v0@ to be an outcome of the primitive
    -- measure: where needed it let-binds @v0@, pushes a guard restricting
    -- where that value may lie, and pushes the measure's density at that
    -- value as a weight.

    -- Only the support guard is pushed here; no weight.
    -- TODO: optimize the cases where lo is -∞ or hi is ∞
    go Lebesgue = \(lo :* hi :* End) ->
        emitLet' v0 >>= \x ->
            pushGuard (lo P.<= x P.&& x P.<= hi)

    -- TODO: I think, based on Hakaru v0.2.0
    go Counting = \End -> return ()

    -- TODO: check that v0 is less than the length of probs
    go Categorical = \(probs :* End) ->
        pushWeight (P.densityCategorical probs v0)

    -- Per the paper
    go Uniform = \(lo :* hi :* End) ->
        emitLet' v0 >>= \x ->
            pushGuard (lo P.<= x P.&& x P.<= hi)
            >> pushWeight (P.densityUniform lo hi x)

    -- TODO: Add fallback handling of Normal that does not atomize mu,sd.
    -- This fallback is as if Normal were defined in terms of Lebesgue
    -- and a density Weight.  This fallback is present in Hakaru v0.2.0
    -- in order to disintegrate a program such as
    --  x <~ normal(0,1)
    --  y <~ normal(x,1)
    --  return ((x+(y+y),x)::pair(real,real))
    --
    -- N.B., if\/when extending this to higher dimensions, the
    -- real equation is
    -- @recip (sqrt (2*pi*sd^2) ^ n) *
    --  exp (negate (norm_n (v0 - mu) ^ 2) /
    --  (2*sd^2))@
    -- for @Real^n@.
    go Normal = \(mu :* sd :* End) ->
        pushWeight (P.densityNormal mu sd v0)

    -- NOTE(review): the first conjunct of the guard is always true for
    -- an 'HNat' value; it is kept so the emitted program is unchanged.
    go Poisson = \(p1 :* End) ->
        emitLet' v0 >>= \x ->
            pushGuard (P.nat_ 0 P.<= x P.&& P.prob_ 0 P.< p1)
            >> pushWeight (P.densityPoisson p1 x)

    go Gamma = \(p1 :* p2 :* End) ->
        emitLet' v0 >>= \x ->
            pushGuard (P.prob_ 0 P.< x  P.&&
                       P.prob_ 0 P.< p1 P.&&
                       P.prob_ 0 P.< p2)
            >> pushWeight (P.densityGamma p1 p2 x)

    go Beta = \(p1 :* p2 :* End) ->
        emitLet' v0 >>= \x ->
            pushGuard (P.prob_ 0 P.<= x  P.&&
                       P.prob_ 1 P.>= x  P.&&
                       P.prob_ 0 P.<  p1 P.&&
                       P.prob_ 0 P.<  p2)
            >> pushWeight (P.densityBeta p1 p2 x)

----------------------------------------------------------------
----------------------------------------------------------- fin.