{-# LANGUAGE CPP
           , BangPatterns
           , DataKinds
           , EmptyCase
           , ExistentialQuantification
           , FlexibleContexts
           , FlexibleInstances
           , GADTs
           , GeneralizedNewtypeDeriving
           , KindSignatures
           , MultiParamTypeClasses
           , OverloadedStrings
           , PolyKinds
           , ScopedTypeVariables
           , StandaloneDeriving
           , TupleSections
           , TypeFamilies
           , TypeOperators
           , UndecidableInstances
           #-}

{-# OPTIONS_GHC -Wall -fwarn-tabs #-}
----------------------------------------------------------------
--                                                    2017.02.01
-- |
-- Module      :  Language.Hakaru.Syntax.Hoist
-- Copyright   :  Copyright (c) 2016 the Hakaru team
-- License     :  BSD3
-- Maintainer  :
-- Stability   :  experimental
-- Portability :  GHC-only
--
-- Hoist expressions to the point where their data dependencies are met.
-- This pass duplicates *a lot* of work and relies on the CSE and pruning
-- passes to clean up the junk (most of which is trivial to do, but we don't
-- know what is junk until after CSE has occurred).
--
-- NOTE: This pass assumes globally unique variable ids, as two subterms may
-- otherwise bind the same variable. Those variables would potentially shadow
-- each other if hoisted upward to a common scope.
--
----------------------------------------------------------------
module Language.Hakaru.Syntax.Hoist (hoist) where

import           Control.Applicative             (liftA2)
import           Control.Monad.RWS               hiding ((<>))
import qualified Data.Foldable                   as F
import qualified Data.Graph                      as G
import qualified Data.IntMap.Strict              as IM
import qualified Data.List                       as L
import           Data.Maybe                      (mapMaybe)
import           Data.Number.Nat
import           Data.Proxy                      (KProxy (..))
import qualified Data.Vector                     as V

import           Language.Hakaru.Syntax.ABT
import           Language.Hakaru.Syntax.ANF      (isValue)
import           Language.Hakaru.Syntax.AST
import           Language.Hakaru.Syntax.AST.Eq   (alphaEq)
import           Language.Hakaru.Syntax.Gensym
import           Language.Hakaru.Syntax.IClasses
import           Language.Hakaru.Types.DataKind
import           Language.Hakaru.Types.Sing      (Sing)

#if __GLASGOW_HASKELL__ < 710
import           Control.Applicative
#endif

#if !(MIN_VERSION_base(4,11,0))
import Data.Semigroup
#endif

-- | A hoistable expression together with the information needed to decide
-- where it may be reintroduced. The Hakaru type @a@ of the expression is
-- existentially quantified, so entries of different types can share a list.
data Entry (abt :: Hakaru -> *)
  = forall (a :: Hakaru) . Entry
  { -- | Variables the expression depends on.
    varDependencies :: !(VarSet (KindOf a))
    -- | The expression itself.
  , expression      :: !(abt a)
    -- | The type of the expression, to allow for easy comparison of types.
    -- The typeOf operator is technically O(n) in the size of the expression
    -- and we may need to call it many times.
  , sing            :: !(Sing a)
    -- | The variables that are bound to this expression.
  , bindings        :: ![Variable a]
  }

-- | Debugging rendering of an entry: only the dependency set and the bound
-- variables are shown; the expression and its type are omitted.
instance Show (Entry abt) where
  show (Entry deps _ _ vars) =
    concat ["Entry (", show deps, ") (", show vars, ")"]

-- Kind proxy that pins variable sets and existential variables to the
-- 'Hakaru' kind.
type HakaruProxy = ('KProxy :: KProxy Hakaru)
-- Set of variables over Hakaru types; this is the environment type carried by
-- the Reader layer of 'HoistM'.
type LiveSet     = VarSet HakaruProxy
-- A variable wrapped in 'SomeVariable', so that variables of different Hakaru
-- types can be kept in one list.
type HakaruVar   = SomeVariable HakaruProxy

-- | The @HoistM@ monad makes use of three monadic layers to propagate
-- information both downwards to the leaves and upwards to the root node of
-- the AST.
--
-- * The Writer layer propagates the live expressions which may be hoisted
--   (i.e. all their data dependencies are currently filled) from each
--   subexpression to their parents.
--
-- * The Reader layer propagates the currently bound variables which will be
--   used to decide when to introduce new bindings.
--
-- * The State layer is just to provide a counter in order to gensym new
--   variables, since the process of adding new bindings is a little tricky.
--   What we want is to fully duplicate bindings without altering the original
--   variable identifiers. To do so, all original variable names are preserved
--   and new variables are added outside the range of existing variables.
newtype HoistM (abt :: [Hakaru] -> Hakaru -> *) a
  = HoistM
  { runHoistM :: RWS LiveSet (ExpressionSet abt) Nat a
  }

-- Standalone-derived instances for 'HoistM'. Every instance except 'Functor'
-- carries an @ABT Term abt@ context, matching the context of the 'Monoid'
-- instance for 'ExpressionSet' used as the writer output.
deriving instance                   Functor (HoistM abt)
deriving instance (ABT Term abt) => Applicative (HoistM abt)
deriving instance (ABT Term abt) => Monad (HoistM abt)
deriving instance (ABT Term abt) => MonadState Nat (HoistM abt)
deriving instance (ABT Term abt) => MonadWriter (ExpressionSet abt) (HoistM abt)
deriving instance (ABT Term abt) => MonadReader LiveSet (HoistM abt)

-- A collection of entries for closed-binder (@abt '[]@) expressions. It is
-- stored as a plain list; the set-like behaviour comes from the union and
-- intersection functions defined in this module.
newtype ExpressionSet (abt :: [Hakaru] -> Hakaru -> *)
  = ExpressionSet [Entry (abt '[])]

-- | Combine two entries into one. The dependencies, expression and type are
-- taken from the first entry; the bound variables of both entries are
-- concatenated with duplicates removed.
--
-- Calls 'error' when the two entries do not have the same type.
mergeEntry :: (ABT Term abt) => Entry (abt '[]) -> Entry (abt '[]) -> Entry (abt '[])
mergeEntry (Entry deps expr ty vars) (Entry _ _ ty' vars')
  | Just Refl <- jmEq1 ty ty' = Entry deps expr ty (L.nub (vars ++ vars'))
  | otherwise                 = error "cannot union mismatched entries"

-- | Decide whether two entries describe the same expression: they must have
-- equal dependency sets, the same type, and alpha-equivalent expressions.
-- The bound variables are not compared.
entryEqual :: (ABT Term abt) => Entry (abt '[]) -> Entry (abt '[]) -> Bool
entryEqual (Entry d1 e1 s1 _) (Entry d2 e2 s2 _)
  -- The cheap dependency check runs first; the type witness is only computed
  -- when it succeeds, and alphaEq only when the types also agree.
  | d1 == d2
  , Just Refl <- jmEq1 s1 s2 = alphaEq e1 e2
  | otherwise                = False

-- | Union two expression sets, collapsing equivalent entries.
--
-- All entries of @xs ++ ys@ that are equal according to 'entryEqual' are
-- folded into a single entry with 'mergeEntry', so the result keeps the
-- expression of the first member of each class and the bindings of all its
-- members. Classes appear in the order of their first occurrence.
--
-- NOTE: the classes are built with 'L.partition' rather than 'L.groupBy'.
-- 'L.groupBy' only groups /adjacent/ equal elements, so equivalent entries
-- separated by an unrelated entry would be left unmerged. The price is a
-- quadratic number of 'entryEqual' comparisons in the worst case.
unionEntrySet
  :: forall abt
  .  (ABT Term abt)
  => ExpressionSet abt
  -> ExpressionSet abt
  -> ExpressionSet abt
unionEntrySet (ExpressionSet xs) (ExpressionSet ys) =
  ExpressionSet (collapse (xs ++ ys))
  where
    -- Peel off the class of the leading entry, merge it, and recurse on the
    -- entries that were not equivalent to it.
    collapse :: [Entry (abt '[])] -> [Entry (abt '[])]
    collapse []       = []
    collapse (z : zs) = L.foldl' mergeEntry z same : collapse rest
      where
        (same, rest) = L.partition (entryEqual z) zs

-- | Intersect two expression sets. Every pair drawn from the two sets whose
-- members are equal according to 'entryEqual' contributes one merged entry
-- (see 'mergeEntry') to the result.
intersectEntrySet
  :: forall abt
  .  (ABT Term abt)
  => ExpressionSet abt
  -> ExpressionSet abt
  -> ExpressionSet abt
intersectEntrySet (ExpressionSet xs) (ExpressionSet ys) =
  ExpressionSet
    [ mergeEntry x y
    | (x, y) <- liftA2 (,) xs ys  -- all pairs, left set varying slowest
    , entryEqual x y
    ]

-- The general case for generating the entry set for a term is to simply union
-- the sets for all the subterms, so union is the combining operation used by
-- the Writer layer.
instance (ABT Term abt) => Semigroup (ExpressionSet abt) where
  xs <> ys = unionEntrySet xs ys

-- The empty expression set is the unit; on versions of base that predate the
-- Semigroup superclass, 'mappend' is defined through '<>'.
instance (ABT Term abt) => Monoid (ExpressionSet abt) where
  mempty  = ExpressionSet []
#if !(MIN_VERSION_base(4,11,0))
  mappend = (<>)
#endif

-- Given a list of entries to introduce, order them so that their data
-- dependencies are satisfied: an entry appears after every entry that binds
-- a variable it depends on.
topSortEntries
  :: forall abt
  .  [Entry (abt '[])]
  -> [Entry (abt '[])]
topSortEntries entryList = [ entries V.! i | i <- G.topSort graph ]
  where
    -- Entries indexed by position; the positions are the graph's vertices.
    entries :: V.Vector (Entry (abt '[]))
    !entries = V.fromList entryList

    -- Identifiers of the variables introduced by an entry.
    boundIDs :: Entry (abt '[]) -> [Int]
    boundIDs (Entry _ _ _ vars) = [ fromNat (varID v) | v <- vars ]

    -- Record the entry at index @idx@ as the owner of each variable it
    -- introduces. A given entry may introduce multiple bindings, since an
    -- entry stores all α-equivalent variable definitions.
    register :: IM.IntMap Int -> Int -> Entry (abt '[]) -> IM.IntMap Int
    register table idx entry =
      L.foldl' (\acc vid -> IM.insert vid idx acc) table (boundIDs entry)

    -- Mapping from variable IDs to the index of the entry introducing them.
    owner :: IM.IntMap Int
    !owner = V.ifoldl' register IM.empty entries

    -- One edge from each entry that introduces a dependency of the entry at
    -- @idx@ to @idx@ itself. Dependencies that no entry in the list
    -- introduces produce no edge.
    edgesInto :: Int -> Entry (abt '[]) -> [G.Edge]
    edgesInto idx (Entry deps _ _ _) =
      [ (src, idx)
      | src <- mapMaybe (\vid -> IM.lookup vid owner) (varSetKeys deps)
      ]

    -- All dependency edges of the graph.
    edges :: [G.Edge]
    !edges = concat . V.toList $ V.imap edgesInto entries

    -- The full graph structure to be topologically sorted.
    graph :: G.Graph
    !graph = G.buildG (0, V.length entries - 1) edges

-- | Emit, through the Writer layer, a single entry stating that the variable
-- @v@ is bound to the expression @abt@. The entry's dependencies are the free
-- variables of the expression and its type is the type of the variable.
recordEntry
  :: (ABT Term abt)
  => Variable a
  -> abt '[] a
  -> HoistM abt ()
recordEntry v abt =
  tell . ExpressionSet $
    [ Entry
        { varDependencies = freeVars abt
        , expression      = abt
        , sing            = varType v
        , bindings        = [v]
        }
    ]

-- | Run a 'HoistM' computation with an empty set of bound variables and the
-- given initial gensym counter. Only the computed value is returned; the
-- final counter and the accumulated entries are discarded.
execHoistM :: Nat -> HoistM abt a -> a
execHoistM counter act =
  let (result, _finalCounter, _entries) =
        runRWS (runHoistM act) emptyVarSet counter
  in  result

-- | An expression is considered "toplevel" if it can be hoisted outside all
-- binders, i.e. its set of variable dependencies is empty.
toplevelEntry
  :: Entry abt
  -> Bool
toplevelEntry (Entry deps _ _ _) = sizeVarSet deps == 0

-- | Run an action and return the entries it emitted alongside its result.
-- The emitted entries are replaced by 'mempty' in the enclosing Writer
-- output, so they only reach the caller through the returned pair.
captureEntries
  :: (ABT Term abt)
  => HoistM abt a
  -> HoistM abt (a, ExpressionSet abt)
captureEntries act = censor (const mempty) (listen act)

-- | Entry point of the pass. The term is traversed with @hoist'@, the entries
-- that escape the traversal are captured, and they are reintroduced around
-- the result by 'introduceToplevel' starting from an empty set of available
-- variables. The gensym counter starts at @nextFreeOrBind abt@.
hoist
  :: (ABT Term abt)
  => abt '[] a
  -> abt '[] a
hoist abt = execHoistM (nextFreeOrBind abt) $ do
  (body, entries) <- captureEntries (hoist' abt)
  introduceToplevel emptyVarSet body entries

-- | Split an expression set in two: the first component holds the entries
-- satisfying the predicate, the second those that do not. The relative order
-- of entries is preserved in both halves.
partitionEntrySet
  :: (Entry (abt '[]) -> Bool)
  -> ExpressionSet abt
  -> (ExpressionSet abt, ExpressionSet abt)
partitionEntrySet p (ExpressionSet xs) =
  (ExpressionSet (fst halves), ExpressionSet (snd halves))
  where
    halves = L.partition p xs

-- | Wrap a transformed term in the bindings for the given entries.
--
-- The entries are split into toplevel ones (no data dependencies), most of
-- which should be eliminated by constant propagation, and the rest. The term
-- is first wrapped in the entries that depend on toplevel definitions, and
-- that result is then wrapped in the toplevel definitions themselves.
introduceToplevel
  :: (ABT Term abt)
  => LiveSet
  -> abt '[] a
  -> ExpressionSet abt
  -> HoistM abt (abt '[] a)
introduceToplevel avail abt entries = do
  inner <- introduceBindings liveVars abt dependent
  wrapExpr inner toplevel
  where
    (ExpressionSet toplevel, dependent) =
      partitionEntrySet toplevelEntry entries

    -- Variables handed to 'introduceBindings': those bound by the toplevel
    -- entries followed by those that were already available.
    liveVars :: [HakaruVar]
    liveVars = concatMap getBoundVars toplevel ++ fromVarSet avail

-- | Run an action with the given variable added to the set of bound
-- variables held by the Reader layer.
bindVar
  :: (ABT Term abt)
  => Variable (a :: Hakaru)
  -> HoistM abt b
  -> HoistM abt b
bindVar v = local (insertVarSet v)

-- | Run an action with the given variable bound (see 'bindVar') and capture
-- the entries it emits (see 'captureEntries') instead of passing them on to
-- the enclosing Writer output.
isolateBinder
  :: (ABT Term abt)
  => Variable (a :: Hakaru)
  -> HoistM abt b
  -> HoistM abt (b, ExpressionSet abt)
isolateBinder v act = captureEntries (bindVar v act)

-- | Traverse one ABT node. Binders are peeled off and collected; once the
-- underlying term is reached it is processed with @hoistTerm@, and any
-- captured entries are reintroduced with @introduceBindings@ for the
-- collected binders.
hoist'
  :: forall abt xs a . (ABT Term abt)
  => abt xs a
  -> HoistM abt (abt xs a)
hoist' = descend [] . viewABT
  where
    -- Add every variable of the list to the live set.
    extendLive :: [HakaruVar] -> LiveSet -> LiveSet
    extendLive vars live =
      L.foldl' (\acc (SomeVariable v) -> insertVarSet v acc) live vars

    -- Run an action with the given variables bound, capturing the entries it
    -- emits instead of passing them on to the enclosing Writer output.
    withBinders :: [HakaruVar] -> HoistM abt c -> HoistM abt (c, ExpressionSet abt)
    withBinders vars = captureEntries . local (extendLive vars)

    -- @descend@ takes 2 parameters.
    --
    -- 1. The list of variables bound so far
    -- 2. The current term we are recurring over
    --
    -- A variable is added to the first every time a @Bind@ term is hit, and
    -- when a @Syn@ term is finally reached, the hoisted values captured from
    -- it are introduced for these new variables.
    descend
      :: forall ys b
      .  [HakaruVar]
      -> View (Term abt) ys b
      -> HoistM abt (abt ys b)
    descend bound view =
      case view of
        Var v    -> return (var v)
        Bind v b -> bind v <$> descend (SomeVariable v : bound) b
        Syn s
          -- With no binders collected there is nothing to introduce here, so
          -- the expensive call to introduceBindings is skipped.
          | null bound -> hoistTerm s
          | otherwise  -> do
              (term, entries) <- withBinders bound (hoistTerm s)
              introduceBindings bound term entries

-- | The variables listed in an entry's @bindings@ field, each wrapped as an
-- existentially-typed 'HakaruVar' (via 'SomeVariable') so that they can be
-- collected in a single homogeneous list.
getBoundVars :: Entry x -> [HakaruVar]
getBoundVars Entry{bindings=b} = fmap SomeVariable b

-- | Wrap an expression in let-bindings for a list of entries.
--
-- The entries are folded from the right with 'F.foldrM', so the last entry
-- in the list is wrapped around the expression first (it becomes the
-- innermost binding) and the first entry ends up outermost.
wrapExpr
  :: forall abt b . (ABT Term abt)
  => abt '[] b
  -> [Entry (abt '[])]
  -> HoistM abt (abt '[] b)
wrapExpr = F.foldrM wrap
  where
    -- Build @let v = e in b@. When the body is nothing but the variable @v@
    -- itself, the let would be trivial, so @e@ is returned directly instead.
    mklet :: abt '[] a -> Variable a -> abt '[] b -> abt '[] b
    mklet e v b =
      case viewABT b of
        Var v' | Just Refl <- varEq v v' -> e
        _      -> syn (Let_ :$ e :* bind v b :* End)

    -- Binds the Entry's expression to a fresh variable and rebinds any other
    -- variable uses to the fresh variable.
    wrap :: Entry (abt '[]) -> abt '[] b -> HoistM abt (abt '[] b)
    -- No variable is associated with the entry: obtain one from varForExpr
    -- and bind the expression to it.
    wrap Entry{expression=e,bindings=[]} acc = do
      tmp <- varForExpr e
      return $ mklet e tmp acc
    -- The entry already carries variables: the first one, @x@, is bound to
    -- the expression, and every remaining variable is let-bound to @x@
    -- around the accumulated body.
    wrap Entry{expression=e,bindings=(x:xs)} acc = do
      let rhs  = var x
          body = foldr (mklet rhs) acc xs
      return $ mklet e x body

-- This will introduce all binders which must be introduced by binding the
-- @newVars@ set. As a side effect, the remaining entries are written into the
-- Writer layer of the stack.
--
-- The entries selected for introduction are passed through @topSortEntries@
-- and then wrapped around @body@ with 'wrapExpr'.
introduceBindings
  :: forall (a :: Hakaru) abt
  .  (ABT Term abt)
  => [HakaruVar]
  -> abt '[] a
  -> ExpressionSet abt
  -> HoistM abt (abt '[] a)
introduceBindings newVars body (ExpressionSet entries) = do
  tell (ExpressionSet leftOver)
  wrapExpr body (topSortEntries resultEntries)
  where
    -- @resultEntries@ are bound at this point; @leftOver@ is re-emitted for
    -- an enclosing scope to deal with.
    resultEntries, leftOver :: [Entry (abt '[])]
    (resultEntries, leftOver) = loop entries newVars

    -- An entry is introduced by a variable when that variable is a member of
    -- the entry's @varDependencies@ set.
    introducedBy
      :: forall (b :: Hakaru)
      .  Variable b
      -> Entry (abt '[])
      -> Bool
    introducedBy v Entry{varDependencies=deps} = memberVarSet v deps

    -- Worklist over variables. For each variable, the pending entries are
    -- split into those it introduces and the rest. The variables bound by
    -- the newly introduced entries are appended to the worklist, because
    -- they may in turn introduce further entries. Once the worklist is
    -- empty, whatever is still pending is the left-over set.
    loop
      :: [Entry (abt '[])]
      -> [HakaruVar]
      -> ([Entry (abt '[])], [Entry (abt '[])])
    loop exprs []                    = ([], exprs)
    loop exprs (SomeVariable v : xs) = (introduced ++ intro, acc)
      where
        ~(intro, acc)      = loop rest (xs ++ vars)
        vars               = concatMap getBoundVars introduced
        (introduced, rest) = L.partition (introducedBy v) exprs

-- Contrary to the other binding forms, let expressions are killed by the
-- hoisting pass. Their RHSs are floated upward in the AST and re-introduced
-- where their data dependencies are fulfilled. Thus, the result of hoisting
-- a let expression is just the hoisted body.
hoistTerm
  :: forall (a :: Hakaru) (abt :: [Hakaru] -> Hakaru -> *)
  .  (ABT Term abt)
  => Term abt a
  -> HoistM abt (abt '[] a)
-- Let: hoist the right-hand side, record it as an entry for the let-bound
-- variable, and return only the hoisted body (processed under @bindVar v@).
-- No @Let_@ node is rebuilt here.
hoistTerm (Let_ :$ rhs :* body :* End) =
  caseBind body $ \ v body' -> do
    rhs' <- hoist' rhs
    recordEntry v rhs'
    bindVar v (hoist' body')

-- Lambda: hoist the body via @isolateBinder v@, which hands back the entries
-- the body produced alongside its result. Those entries are then passed to
-- @introduceToplevel@ together with the current reader environment extended
-- by the lambda's own variable, and the lambda is rebuilt around the result.
hoistTerm (Lam_ :$ body :* End) =
  caseBind body $ \ v body' -> do
    available         <- fmap (insertVarSet v) ask
    (body'', entries) <- isolateBinder v (hoist' body')
    finalized         <- introduceToplevel available body'' entries
    return $ syn (Lam_ :$ bind v finalized :* End)

-- Every other term: hoist all immediate subterms and rebuild the node. When
-- @isValue@ holds for the result it is returned in place; otherwise the
-- result is recorded as an entry under a variable obtained from @varForExpr@
-- and a reference to that variable is returned instead.
hoistTerm term = do
  result <- syn <$> traverse21 hoist' term
  if isValue result
    then return result
    else do fresh <- varForExpr result
            recordEntry fresh result
            return (var fresh)