{-# LANGUAGE CPP
           , GADTs
           , KindSignatures
           , DataKinds
           , Rank2Types
           , ScopedTypeVariables
           , MultiParamTypeClasses
           , FlexibleContexts
           , FlexibleInstances
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2016.04.24
-- |
-- Module      :  Language.Hakaru.Evaluation.ExpectMonad
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :  wren@community.haskell.org
-- Stability   :  experimental
-- Portability :  GHC-only
--
--
----------------------------------------------------------------
module Language.Hakaru.Evaluation.ExpectMonad
    ( pureEvaluate
    
    -- * The expectation-evaluation monad
    -- ** List-based version
    , ListContext(..), ExpectAns, Expect(..), runExpect
    , residualizeExpectListContext
    -- ** TODO: IntMap-based version
    
    -- * ...
    , emit
    , emit_
    ) where

import           Prelude              hiding (id, (.))
import           Control.Category     (Category(..))
#if __GLASGOW_HASKELL__ < 710
import           Data.Functor         ((<$>))
import           Control.Applicative  (Applicative(..))
#endif
import qualified Data.Foldable        as F

import Language.Hakaru.Syntax.IClasses (Some2(..))
import Language.Hakaru.Syntax.ABT      (ABT(..), caseVarSyn, subst, maxNextFreeOrBind)
import Language.Hakaru.Syntax.Variable (memberVarSet)
import Language.Hakaru.Syntax.AST      hiding (Expect)
import Language.Hakaru.Syntax.Transform (TransformCtx(..))
import Language.Hakaru.Evaluation.Types
import Language.Hakaru.Evaluation.Lazy (evaluate)
import Language.Hakaru.Evaluation.PEvalMonad (ListContext(..))


-- The rest of these are used chiefly by 'emit' and 'emit_' (both exported above).
import Data.Text                       (Text)
import Language.Hakaru.Syntax.Variable (Variable())
import Language.Hakaru.Types.DataKind
import Language.Hakaru.Types.Sing      (Sing)
#ifdef __TRACE_DISINTEGRATE__
import Debug.Trace                     (trace)
#endif

----------------------------------------------------------------
----------------------------------------------------------------
-- | The answer type of the 'Expect' continuation monad: given the
-- current evaluation context, produce a term of type @'HProb@.
type ExpectAns abt = ListContext abt 'ExpectP -> abt '[] 'HProb

-- | A continuation-passing monad for expectation evaluation. A
-- computation receives its continuation (which consumes the result
-- together with the current 'ListContext') and must itself produce
-- an 'ExpectAns'.
newtype Expect abt x =
    Expect { unExpect :: (x -> ExpectAns abt) -> ExpectAns abt }

-- | Close off an expectation term by wrapping it in every statement
-- of the given evaluation context.
--
-- Statements are folded in list order. The head of the list (the
-- most recently pushed statement) is applied first, so it ends up
-- innermost. Each 'SLet' is handled in one of three ways:
--
-- * dropped, if its variable is not free in the term built so far;
-- * substituted away, if the bound expression is a variable or literal;
-- * otherwise residualized as a 'Let_'.
--
-- 'SStuff0' and 'SStuff1' statements apply their stored wrapper
-- function to the term built so far.
residualizeExpectListContext
    :: forall abt
    .  (ABT Term abt)
    => abt '[] 'HProb
    -> ListContext abt 'ExpectP
    -> abt '[] 'HProb
residualizeExpectListContext e0 =
    -- Strict left fold: each intermediate term is forced to WHNF as
    -- we go, instead of building one thunk per statement.
    F.foldl' step e0 . statements
    where
    -- TODO: make parametric in the purity, so we can combine
    -- 'residualizeListContext' with this function.
    step :: abt '[] 'HProb -> Statement abt Location 'ExpectP -> abt '[] 'HProb
    step e s =
        case s of
        SLet (Location x) body _
            -- BUG: this trick for dropping unused let-bindings doesn't
            -- seem to work anymore... (cf., 'Tests.Expect.test4')
            | not (x `memberVarSet` freeVars e) -> e
            -- TODO: if used exactly once in @e@, then inline.
            | otherwise ->
                case getLazyVariable body of
                Just y  -> subst x (var y) e
                Nothing ->
                    case getLazyLiteral body of
                    Just v  -> subst x (syn $ Literal_ v) e
                    Nothing ->
                        syn (Let_ :$ fromLazy body :* bind x e :* End)
        SStuff0    f _ -> f e
        SStuff1 _x f _ -> f e


-- | Evaluate a term in the 'Expect' monad using the lazy evaluator.
--
-- The measure evaluator handed to 'evaluate' must never be invoked
-- here. If evaluation ever needs to perform a measure-typed term, the
-- 'Expect' invariant has been broken, and we fail via 'brokenInvariant'.
pureEvaluate :: (ABT Term abt) => TermEvaluator abt (Expect abt)
pureEvaluate = evaluate (brokenInvariant "perform")

-- | Abort with a message reporting that the 'Expect' invariant was
-- broken at the named location (e.g. @"perform"@).
brokenInvariant :: String -> a
brokenInvariant loc = error (loc ++ ": Expect's invariant broken")


-- | Run a computation in the 'Expect' monad, residualizing out all
-- the statements in the final evaluation context.
--
-- The third argument, of type @abt '[a] 'HProb@, is the body into
-- which the computation's result is plugged. The last argument should
-- include all the terms altered by the 'Expect' computation. This is
-- necessary for hygiene: fresh variables are numbered starting from
-- the largest next-free index of those terms, the body, and the
-- 'TransformCtx'. For example:
--
-- > runExpect (pureEvaluate e) ctx f [Some2 e]
--
-- We use 'Some2' on the inputs because it doesn't matter what their
-- type or locally-bound variables are, so we want to allow the
-- container to hold terms with different indices.
runExpect
    :: forall abt f a
    .  (ABT Term abt, F.Foldable f)
    => Expect abt (abt '[] a)
    -> TransformCtx
    -> abt '[a] 'HProb
    -> f (Some2 abt)
    -> abt '[] 'HProb
runExpect (Expect m) ctx f es =
    m c0 h0
    where
    -- First fresh variable index: the largest next-free index of the
    -- body, the extra terms, and the context.
    i0 = maximum [nextFreeOrBind f, maxNextFreeOrBind es, nextFreeVar ctx]

    -- Start from an empty context.
    h0 = ListContext i0 []

    -- Final continuation: plug the result @e@ into the body @f@, then
    -- wrap that in every statement accumulated in the context.
    c0 e =
        residualizeExpectListContext $
        caseVarSyn e
            -- A variable result is substituted directly into the body.
            (\x -> caseBind f $ \y f' -> subst y (var x) f')
            -- Any other result is let-bound around the body.
            (\_ -> syn (Let_ :$ e :* f :* End))
        -- TODO: make this smarter still, to drop the let-binding
        -- entirely if it's not used in @f@.


----------------------------------------------------------------
-- | Mapping post-composes the function onto the continuation.
instance Functor (Expect abt) where
    fmap f (Expect m) = Expect $ \c -> m (c . f)

-- | Standard continuation-passing 'Applicative'. The function
-- computation runs first, then the argument computation.
instance Applicative (Expect abt) where
    -- Hand the value straight to the continuation.
    pure x                  = Expect $ \c -> c x
    -- Run @mf@, then @mx@, and pass the application to the continuation.
    Expect mf <*> Expect mx = Expect $ \c -> mf $ \f -> mx $ \x -> c (f x)

-- | Standard continuation-passing 'Monad'. Bind runs @m@ and feeds its
-- result to @k@, whose computation then receives the outer continuation.
instance Monad (Expect abt) where
    return         = pure
    Expect m >>= k = Expect $ \c -> m $ \x -> unExpect (k x) c

-- | 'Expect' threads a 'ListContext' through its continuation. The
-- context holds a fresh-variable counter and a stack of statements,
-- with the most recently pushed statement at the head of the list.
instance (ABT Term abt) => EvaluationMonad abt (Expect abt) 'ExpectP where
    -- Hand out the current counter value and bump the counter.
    freshNat =
        Expect $ \c (ListContext i ss) ->
            c i (ListContext (i+1) ss)

    -- Push one statement onto the front of the context.
    unsafePush s =
        Expect $ \c (ListContext i ss) ->
            c () (ListContext i (s:ss))

    -- N.B., the use of 'reverse' is necessary so that the order
    -- of pushing matches that of 'pushes'
    unsafePushes ss =
        Expect $ \c (ListContext i ss') ->
            c () (ListContext i (reverse ss ++ ss'))

    -- Pop statements until one binds @x@ and is accepted by @p@. The
    -- popped statements before it are saved in @ss@ (most recently
    -- popped first) and pushed back afterwards in their original order.
    -- The matching statement itself is not pushed back. If nothing
    -- matches, everything is restored and the result is 'Nothing'.
    select x p = loop []
        where
        -- TODO: use a DList to avoid reversing inside 'unsafePushes'
        loop ss = do
            ms <- unsafePop
            case ms of
                Nothing -> do
                    unsafePushes ss
                    return Nothing
                Just s  ->
                    -- Alas, @p@ will have to recheck 'isBoundBy'
                    -- in order to grab the 'Refl' proof we erased;
                    -- but there's nothing to be done for it.
                    case x `isBoundBy` s >> p s of
                    Nothing -> loop (s:ss)
                    Just mr -> do
                        r <- mr
                        unsafePushes ss
                        return (Just r)

-- TODO: make parametric in the purity
-- | Remove and return the most recently pushed statement, or 'Nothing'
-- (leaving the context unchanged) if there are no statements.
-- Not exported because we only need it for defining 'select' on 'Expect'.
unsafePop :: Expect abt (Maybe (Statement abt Location 'ExpectP))
unsafePop =
    Expect $ \c h@(ListContext i ss) ->
        case ss of
        []    -> c Nothing  h
        s:ss' -> c (Just s) (ListContext i ss')

----------------------------------------------------------------
-- | Allocate a fresh variable (named from @hint@, of type @typ@) and
-- return it to the rest of the computation. Once that computation has
-- produced its answer, the answer is abstracted over the variable with
-- 'bind' and wrapped with @f@.
emit
    :: (ABT Term abt)
    => Text
    -> Sing a
    -> (abt '[a] 'HProb -> abt '[] 'HProb)
    -> Expect abt (Variable a)
emit hint typ f = do
    x <- freshVar hint typ
    Expect $ \c h -> (f . bind x) $ c x h

-- | Wrap the answer produced by the rest of the computation with @f@,
-- without binding any new variable.
emit_
    :: (ABT Term abt)
    => (abt '[] 'HProb -> abt '[] 'HProb)
    -> Expect abt ()
emit_ f = Expect $ \c h -> f $ c () h

----------------------------------------------------------------
----------------------------------------------------------- fin.