{-# LANGUAGE GADTs #-}

module Test.Credit.Heap.LazyPairing where

import Prettyprinter (Pretty)

import Control.Monad.Credit
import Test.Credit
import Test.Credit.Heap.Base

-- Okasaki does not present an amortized analysis of lazy pairing heaps and
-- instead merely conjectures that insert and splitMin have O(log n) amortized
-- cost (Section 6.5).
-- An amortized analysis in a sequential setting is given by Nipkow and Brinkop
-- in 'Amortized Complexity Verified' (2019), Section 8.
-- Below we generalize it to the persistent setting.

-- | A lazy pairing heap over elements of type 'a'. The parameter 's' is the
--   type constructor in which the heap's thunks live (in this module it is
--   always instantiated with the credit monad).
data LazyPairing a s
  = Empty
  | Heap Size a (LazyPairing a s) (PThunk s (LazyPairing a s))
  -- ^ 'Heap sm x a m' holds the root 'x', an eagerly held child 'a' and a
  --   suspended heap 'm'. The change from Okasaki's version is that we
  --   annotate the node with the size 'sm' of the thunk.
  --   Invariant: For 'Heap sm x a m', we have either:
  --   - 'a' is Empty and 'm' has (log2 sm) credits
  --   - 'a' is not Empty and 'm' has (2 * log2 sm) credits
  --   - Right before forcing, 'm' has (3 * log2 sm) credits

-- | Size of a heap as tracked by the annotations: 'Empty' counts as 0, and a
--   'Heap sm _ a _' node counts as 1 (its root) plus the annotated size 'sm'
--   of its thunk plus the size of its eagerly held child 'a'. The thunk itself
--   is never inspected or forced.
size :: LazyPairing a s -> Size
size Empty = 0
size (Heap sm _ a _) = 1 + sm + size a

-- | The suspended computations that a heap thunk ('PThunk') can hold. How each
--   one is executed is given by the 'HasStep' instance for 'PLazyCon'.
data PLazyCon m a where
  Em :: PLazyCon m (LazyPairing a m)
  Link :: Ord a => Size -> LazyPairing a m -> LazyPairing a m -> Thunk m (PLazyCon m) (LazyPairing a m) -> PLazyCon m (LazyPairing a m)
  -- ^ Merging 'h = Link(a, b, m)' costs one tick, performs two links, and assigns some credits to 'm'.
  --   Because 'link a b' costs 'log2 (sa + sb)' credits, we have total costs of:
  --     2 + 2*log2 (sa + sb + sm) + 2*log2 (sa + sb) + 2*log2 sm
  --     <= 6 * log2 sh (since sa + sb + sm <= sh)

-- | How a suspended 'PLazyCon' computation is executed when its thunk is
--   forced.
instance MonadCredit m => HasStep (PLazyCon m) m where
  -- 'Em' is the suspended empty heap: no tick, no credits.
  step Em = pure Empty
  -- 'Link sm a b m': pay one tick, top up the inner thunk 'm' with
  -- 'log2 sm' credits and force it, then link 'a' with 'b' and link that
  -- result with the forced heap. The trailing comments give the cost
  -- attributed to each line.
  step (Link sm a b m) = tick >> do -- 1
    creditWith m (log2 sm) -- log2 sm
    -- Bind the forced heap to a fresh name instead of shadowing the thunk 'm'.
    m' <- force m -- free
    ab <- link a b -- log2 (sa + sb)
    link ab m' -- log2 (sa + sb + sm)

-- | A thunk living in 's' whose suspended computation is a 'PLazyCon'.
type PThunk s = Thunk s (PLazyCon s)

-- | A heap that cannot be empty: the same four fields as the 'Heap'
--   constructor of 'LazyPairing' (thunk size, root, child, thunk), without an
--   'Empty' alternative. 'link' builds one from the node with the smaller root
--   and passes it to 'mergePairs'.
data NEHeap s a = NEHeap Size a (LazyPairing a s) (PThunk s (LazyPairing a s))

-- | 'mergePairs' costs up to 'log2 (sz + sa)' credits
--
--   @mergePairs (NEHeap sm x b m) a@ attaches the heap 'a' below the root 'x':
--
--   * If the child slot 'b' is 'Empty', 'a' simply takes that slot. The
--     existing thunk 'm' receives 'log2 sm' credits and the size annotation
--     stays 'sm'.
--   * Otherwise 'a' and 'b' are paired lazily: a new thunk suspending
--     'Link sm a b m' is created, the child slot is reset to 'Empty', and the
--     new thunk is annotated with (and credited for) the combined size
--     'size a + size b + sm'.
mergePairs :: MonadCredit m => Ord a => NEHeap m a -> LazyPairing a m -> m (LazyPairing a m)
mergePairs (NEHeap sm x Empty m) a = do
  creditWith m (log2 sm)
  pure $ Heap sm x a m
mergePairs (NEHeap sm x b m) a = do
  t <- delay $ Link sm a b m
  -- Everything that now hangs off the new thunk: 'a', 'b' and the old thunk.
  let sz = size a + size b + sm
  creditWith t (log2 sz)
  pure $ Heap sz x Empty t

-- | 'link' costs up to 'log2 (sz + sa) + 1' credits
--
--   Combines two heaps without ticking. If either heap is 'Empty' the other is
--   returned unchanged. Otherwise the node with the smaller root (ties go to
--   the first argument) becomes the root, and the other heap is attached
--   beneath it via 'mergePairs'.
link :: MonadCredit m => Ord a => LazyPairing a m -> LazyPairing a m -> m (LazyPairing a m)
link a Empty = pure a
link Empty b = pure b
link a@(Heap sa x a1 a2) b@(Heap sb y b1 b2)
  | x <= y    = mergePairs (NEHeap sa x a1 a2) b
  | otherwise = mergePairs (NEHeap sb y b1 b2) a

-- | Heap operations for lazy pairing heaps in the credit monad.
instance Heap LazyPairing where
  -- The empty heap; no ticks, no thunks.
  empty = pure Empty
  -- Build a singleton heap (size annotation 0, empty child, thunk suspending
  -- 'Em') and merge it into 'a'.
  insert x a = do
    t <- delay Em
    merge (Heap 0 x Empty t) a
  -- | 'merge' costs '1 + log2 (sa + sb)' credits
  merge a b = tick >> link a b
  -- Nothing to split off from an empty heap.
  splitMin Empty = pure Nothing
  -- Return the root 'x' together with the merge of the child 'a' and the
  -- forced thunk.
  splitMin (Heap sm x a m) = do
    creditWith m (2 * log2 sm) -- in case 'a' is Empty
    -- Bind the forced heap to a fresh name instead of shadowing the thunk 'm'.
    m' <- force m
    am <- merge a m'
    pure $ Just (x, am)

-- | Credit bounds claimed for each heap operation, as a function of 'n'
--   (presumably the current heap size; confirm against 'BoundedHeap').
--   Insert and merge are bounded by @1 + log2 (n + 1)@, splitMin by
--   @1 + 3 * log2 (n + 1)@.
instance BoundedHeap LazyPairing where
  hcost n (Insert _) = 1 + log2 (n + 1)
  hcost n Merge = 1 + log2 (n + 1)
  hcost n SplitMin = 1 + 3 * log2 (n + 1)

-- | Memory rendering of a suspended heap computation.
instance (MonadMemory m, MemoryCell m a) => MemoryCell m (PLazyCon m a) where
  -- 'Em' is shown as a childless cell labelled "Empty".
  prettyCell Em = pure $ mkMCell "Empty" []
  -- 'Link' is shown with its two heaps and its thunk as children; the size
  -- annotation is not rendered.
  prettyCell (Link _ a b m) = do
    a' <- prettyCell a
    b' <- prettyCell b
    m' <- prettyCell m
    pure $ mkMCell "Link" [a', b', m']

-- | Memory rendering of a lazy pairing heap.
instance (MonadMemory m, MemoryCell m a) => MemoryCell m (LazyPairing a m) where
  -- The empty heap is shown as a childless cell labelled "Empty".
  prettyCell Empty = pure $ mkMCell "Empty" []
  -- A node is shown with all four fields as children, in constructor order:
  -- size annotation, root element, child heap, thunk.
  prettyCell (Heap sz x a m) = do
    sz' <- prettyCell sz
    x' <- prettyCell x
    a' <- prettyCell a
    m' <- prettyCell m
    pure $ mkMCell "Heap" [sz', x', a', m']

-- | A heap of 'PrettyCell' elements is rendered as a whole structure by
--   delegating to its 'MemoryCell' rendering.
instance Pretty a => MemoryStructure (LazyPairing (PrettyCell a)) where
  prettyStructure = prettyCell