{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE FunctionalDependencies #-}
{-# LANGUAGE InstanceSigs #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE Trustworthy #-}
{-# LANGUAGE ViewPatterns #-}
-- | Type-indexed contexts and assignments.
--
-- This module re-exports one of "Data.Parameterized.Context.Unsafe" or
-- "Data.Parameterized.Context.Safe" (chosen at compile time by a CPP flag)
-- and adds utilities written against that interface: pattern synonyms,
-- context embeddings, statically checked indexing, currying helpers, and
-- fixed size and index constants.
module Data.Parameterized.Context
 (
    -- * Re-exported implementation
#ifdef UNSAFE_OPS
    module Data.Parameterized.Context.Unsafe
#else
    module Data.Parameterized.Context.Safe
#endif
    -- * Assignment utilities
  , singleton
  , toVector
  , pattern (:>)
  , pattern Empty
  , decompose
  , Data.Parameterized.Context.null
  , Data.Parameterized.Context.init
  , Data.Parameterized.Context.last
  , Data.Parameterized.Context.view
  , Data.Parameterized.Context.take
  , forIndexM
  , generateSome
  , generateSomeM
  , fromList
  , traverseAndCollect
  , traverseWithIndex_
  , dropPrefix
    -- * Context extension and embedding utilities
  , CtxEmbedding(..)
  , ExtendContext(..)
  , ExtendContext'(..)
  , ApplyEmbedding(..)
  , ApplyEmbedding'(..)
  , identityEmbedding
  , extendEmbeddingRightDiff
  , extendEmbeddingRight
  , extendEmbeddingBoth
  , appendEmbedding
  , ctxeSize
  , ctxeAssignment
    -- * Static indexing and lenses for assignments
  , Idx
  , field
  , natIndex
  , natIndexProxy
    -- * Currying and uncurrying for assignments
  , CurryAssignment
  , CurryAssignmentClass(..)
    -- * Size and index constants
  , size1, size2, size3, size4, size5, size6
  , i1of2, i2of2
  , i1of3, i2of3, i3of3
  , i1of4, i2of4, i3of4, i4of4
  , i1of5, i2of5, i3of5, i4of5, i5of5
  , i1of6, i2of6, i3of6, i4of6, i5of6, i6of6
  ) where
import           Control.Applicative (liftA2)
import           Control.Lens hiding (Index, (:>), Empty)
import           Data.Functor (void)
import qualified Data.Vector as V
import qualified Data.Vector.Mutable as MV
import           GHC.TypeLits (Nat, type (-))
import           Data.Monoid ((<>))
import           Data.Parameterized.Classes
import           Data.Parameterized.Some
import           Data.Parameterized.TraversableFC
#ifdef UNSAFE_OPS
import           Data.Parameterized.Context.Unsafe
#else
import           Data.Parameterized.Context.Safe
#endif
-- | Build the assignment that holds exactly one element, by extending the
-- empty assignment with it.
singleton :: f tp -> Assignment f (EmptyCtx ::> tp)
singleton x = extend empty x
-- | Run an applicative action for every index handed out by
-- 'forIndexRange' (started at 0) for a context of the given size.
--
-- The action for an index is sequenced with '*>' in front of the
-- accumulated remainder; the fold is seeded with @pure ()@.
forIndexM :: forall ctx m
           . Applicative m
          => Size ctx
          -> (forall tp . Index ctx tp -> m ())
          -> m ()
forIndexM sz act = forIndexRange 0 sz step (pure ())
  where
    -- Explicit signature: the callback must stay polymorphic in @tp@.
    step :: forall tp . Index ctx tp -> m () -> m ()
    step idx rest = act idx *> rest
-- | Build an assignment of @n@ elements whose context is hidden behind
-- 'Some'.
--
-- The element appended at step @i@ (for @i@ from @0@ up to @n-1@) is the
-- payload of @f i@, so @f 0@ ends up leftmost and @f (n-1)@ is the most
-- recently extended element.
--
-- A non-positive @n@ yields the empty assignment.  Previously only @0@
-- terminated the recursion, so a negative count looped forever.
generateSome :: forall f
              . Int
             -> (Int -> Some f)
             -> Some (Assignment f)
generateSome n f = go n
  where go :: Int -> Some (Assignment f)
        go i
          -- Treat every non-positive count as the base case.
          | i <= 0 = Some empty
          | otherwise =
              case go (i-1) of
                Some a ->
                  case f (i-1) of
                    Some e -> Some (a `extend` e)
-- | Applicative variant of 'generateSome': build an assignment of @n@
-- elements, with hidden context, from an effectful generator.
--
-- Effects are sequenced with '<*>' so that the recursive prefix runs before
-- the generator call for the current position; hence @f 0@ runs first and
-- @f (n-1)@ runs last.
--
-- A non-positive @n@ yields @pure (Some empty)@.  Previously only @0@
-- terminated the recursion, so a negative count looped forever.
generateSomeM :: forall m f
              .  Applicative m
              => Int
              -> (Int -> m (Some f))
              -> m (Some (Assignment f))
generateSomeM n f = go n
  where go :: Int -> m (Some (Assignment f))
        go i
          -- Treat every non-positive count as the base case.
          | i <= 0 = pure (Some empty)
          | otherwise = snoc <$> go (i-1) <*> f (i-1)

        -- Append one existentially wrapped element to a wrapped assignment.
        snoc :: Some (Assignment f) -> Some f -> Some (Assignment f)
        snoc (Some a) (Some e) = Some (a `extend` e)
-- | Convert an assignment into a boxed 'V.Vector' by applying a
-- type-erasing function to every element.
--
-- A mutable vector of @sizeInt (size a)@ slots is allocated, and the slot at
-- @indexVal i@ receives the converted element stored at index @i@.
toVector :: Assignment f tps -> (forall tp . f tp -> e) -> V.Vector e
toVector asgn conv = V.create $ do
  let sz = size asgn
  buf <- MV.new (sizeInt sz)
  forIndexM sz (\idx -> MV.write buf (indexVal idx) (conv (asgn ! idx)))
  pure buf
{-# INLINABLE toVector #-}
-- | Try to strip a known prefix from an assignment.
--
-- Arguments, in order: the full assignment, the expected prefix, the value
-- to return when the prefix does not match, and a continuation that
-- receives the remaining suffix together with the proof that the full
-- context is the prefix appended with that suffix.
--
-- Trailing elements are peeled off while the remaining length exceeds the
-- prefix length; what is left is compared with the prefix via
-- 'testEquality'.
dropPrefix :: forall f xs prefix a.
  TestEquality f =>
  Assignment f xs      ->
  Assignment f prefix  ->
  a  ->
  (forall addl. (xs ~ (prefix <+> addl)) => Assignment f addl -> a)
     ->
  a
dropPrefix xs0 prefix err = strip xs0 (sizeInt (size xs0))
  where
  prefixLen = sizeInt (size prefix)

  -- Worker: the Int tracks the length of the assignment still to inspect,
  -- and the continuation re-attaches every element peeled off so far.
  strip :: forall ys.
   Assignment f ys ->
   Int ->
   (forall addl. (ys ~ (prefix <+> addl)) => Assignment f addl -> a) ->
   a
  strip asgn len k =
    case asgn of
      rest :> z | len > prefixLen ->
        strip rest (len - 1) (\zs -> k (zs :> z))
      _ ->
        case testEquality asgn prefix of
          Just Refl -> k Empty
          Nothing   -> err
-- | Bidirectional pattern synonym for the empty assignment.
--
-- As a pattern it inspects the assignment through 'viewAssign' and, on an
-- 'AssignEmpty' result, provides the equality @ctx ~ EmptyCtx@.  As an
-- expression it is 'empty'.
pattern Empty :: () => ctx ~ EmptyCtx => Assignment f ctx
pattern Empty <- (viewAssign -> AssignEmpty)
  where Empty = empty
-- Left-associative, so @a :> b :> c@ groups as @(a :> b) :> c@.  No
-- precedence is written, so the language default applies.
infixl :>
-- | Bidirectional pattern synonym for a non-empty assignment.
--
-- As a pattern it inspects the assignment through 'viewAssign' and, on an
-- 'AssignExtend' result, binds the initial assignment and the final element
-- while providing the equality @ctx' ~ (ctx ::> tp)@.  As an expression it
-- is 'extend'.
pattern (:>) :: () => ctx' ~ (ctx ::> tp) => Assignment f ctx -> f tp -> Assignment f ctx'
pattern (:>) a v <- (viewAssign -> AssignExtend a v)
  where a :> v = extend a v
#if MIN_VERSION_base(4,10,0)
-- Tells the exhaustiveness checker that these two synonyms together cover
-- every Assignment; only emitted when base is at least 4.10.
{-# COMPLETE (:>), Empty :: Assignment  #-}
#endif
-- | 'True' exactly when the assignment views as 'AssignEmpty'.
null :: Assignment f ctx -> Bool
null asgn
  | AssignEmpty <- viewAssign asgn = True
  | otherwise                      = False
-- | Split a non-empty assignment into everything but its final element and
-- the final element itself.  The pair is built lazily from this module's
-- 'Data.Parameterized.Context.init' and 'Data.Parameterized.Context.last'.
decompose :: Assignment f (ctx ::> tp) -> (Assignment f ctx, f tp)
decompose asgn = (front, final)
  where
    front = Data.Parameterized.Context.init asgn
    final = Data.Parameterized.Context.last asgn
-- | Drop the final element of a non-empty assignment.
init :: Assignment f (ctx '::> tp) -> Assignment f ctx
init (front :> _) = front
-- | Return the final element of a non-empty assignment.
last :: Assignment f (ctx '::> tp) -> f tp
last (_ :> final) = final
{-# DEPRECATED view "Use viewAssign or the Empty and :> patterns instead." #-}
-- | Deprecated synonym for 'viewAssign'.
view :: forall f ctx . Assignment f ctx -> AssignView f ctx
view asgn = viewAssign asgn
-- | Keep the leading @ctx@ portion of an assignment over @ctx \<+\> ctx'@.
--
-- The result is generated index by index: each index of @ctx@ is widened
-- into the appended context and looked up in the input assignment.
take :: forall f ctx ctx'. Size ctx -> Size ctx' -> Assignment f (ctx <+> ctx') -> Assignment f ctx
take sz sz' asgn = generate sz pick
  where
    -- Difference between the prefix context and the full appended context.
    widen :: Diff ctx (ctx <+> ctx')
    widen = appendDiff sz'

    pick :: forall tp. Index ctx tp -> f tp
    pick idx = asgn ! extendIndex' widen idx
-- | An embedding of the context @ctx@ into the context @ctx'@: every entry
-- of @ctx@ is mapped to an index of the same type in @ctx'@.
data CtxEmbedding (ctx :: Ctx k) (ctx' :: Ctx k)
  = CtxEmbedding { _ctxeSize       :: Size ctx'
                   -- ^ Size of the target context.
                 , _ctxeAssignment :: Assignment (Index ctx') ctx
                   -- ^ For each entry of the source context, its index in
                   -- the target context.
                 }
-- | Lens onto the target-context size stored in a 'CtxEmbedding'.
ctxeSize :: Simple Lens (CtxEmbedding ctx ctx') (Size ctx')
ctxeSize inject emb =
  fmap (\sz -> emb { _ctxeSize = sz }) (inject (_ctxeSize emb))
-- | Type-changing lens onto the index assignment stored in a
-- 'CtxEmbedding'; replacing the assignment may change the source context
-- while the target context stays fixed.
ctxeAssignment :: Lens (CtxEmbedding ctx1 ctx') (CtxEmbedding ctx2 ctx')
                       (Assignment (Index ctx') ctx1) (Assignment (Index ctx') ctx2)
ctxeAssignment inject emb =
  fmap (\asgn -> emb { _ctxeAssignment = asgn }) (inject (_ctxeAssignment emb))
-- | Context-indexed types whose values can be transported along a
-- 'CtxEmbedding' from the source context to the target context.
class ApplyEmbedding (f :: Ctx k -> *) where
  -- | Re-index a value from @ctx@ to @ctx'@ using the given embedding.
  applyEmbedding :: CtxEmbedding ctx ctx' -> f ctx -> f ctx'
-- | Variant of 'ApplyEmbedding' for types that carry one further type
-- parameter after the context; that parameter is left unchanged.
class ApplyEmbedding' (f :: Ctx k -> k' -> *) where
  -- | Re-index a value from @ctx@ to @ctx'@ using the given embedding.
  applyEmbedding' :: CtxEmbedding ctx ctx' -> f ctx v -> f ctx' v
-- | Context-indexed types whose values can be moved to a larger context
-- described by a 'Diff'.
class ExtendContext (f :: Ctx k -> *) where
  -- | Move a value from @ctx@ to @ctx'@ along the given difference.
  extendContext :: Diff ctx ctx' -> f ctx -> f ctx'
-- | Variant of 'ExtendContext' for types that carry one further type
-- parameter after the context; that parameter is left unchanged.
class ExtendContext' (f :: Ctx k -> k' -> *) where
  -- | Move a value from @ctx@ to @ctx'@ along the given difference.
  extendContext' :: Diff ctx ctx' -> f ctx v -> f ctx' v
-- | An index is transported by looking it up in the index assignment held
-- by the embedding.
instance ApplyEmbedding' Index where
  applyEmbedding' emb idx = _ctxeAssignment emb ! idx
-- | An index is moved to a larger context with 'extendIndex''.
instance ExtendContext' Index where
  extendContext' diff idx = extendIndex' diff idx
-- | The embedding of a context into itself: the target size is the given
-- size and every index is mapped to itself.
identityEmbedding :: Size ctx -> CtxEmbedding ctx ctx
identityEmbedding sz =
  CtxEmbedding
    { _ctxeSize       = sz
    , _ctxeAssignment = generate sz id
    }
-- | Post-compose an embedding with a context extension: given a difference
-- from @ctx'@ to @ctx''@, turn an embedding into @ctx'@ into one into
-- @ctx''@.
--
-- The stored size is extended by the difference and every stored index is
-- widened with 'extendIndex''.
extendEmbeddingRightDiff :: forall ctx ctx' ctx''.
                            Diff ctx' ctx''
                            -> CtxEmbedding ctx ctx'
                            -> CtxEmbedding ctx ctx''
extendEmbeddingRightDiff diff (CtxEmbedding innerSize innerIdxs) =
  CtxEmbedding
    { _ctxeSize       = extSize innerSize diff
    , _ctxeAssignment = fmapFC (extendIndex' diff) innerIdxs
    }
-- | Extend the target context of an embedding by one entry on the right,
-- using the statically known difference.
extendEmbeddingRight :: CtxEmbedding ctx ctx' -> CtxEmbedding ctx (ctx' ::> tp)
extendEmbeddingRight emb = extendEmbeddingRightDiff knownDiff emb
-- | Embed a context into itself appended with a second context.
--
-- The target size is the sum of both sizes, and every index of the first
-- context is widened along the difference produced by 'appendDiff'.
appendEmbedding :: Size ctx -> Size ctx' -> CtxEmbedding ctx (ctx <+> ctx')
appendEmbedding sz sz' =
  CtxEmbedding
    { _ctxeSize       = addSize sz sz'
    , _ctxeAssignment = generate sz (extendIndex' (appendDiff sz'))
    }
-- | Extend both the source and the target context of an embedding by the
-- same entry.
--
-- The target is first widened with 'extendEmbeddingRight'; the new source
-- entry is then mapped to the index just past the original target size.
extendEmbeddingBoth :: forall ctx ctx' tp. CtxEmbedding ctx ctx' -> CtxEmbedding (ctx ::> tp) (ctx' ::> tp)
extendEmbeddingBoth emb = over ctxeAssignment (`extend` freshIdx) widened
  where
    -- Index of the entry added to the target context.
    freshIdx :: Index (ctx' ::> tp) tp
    freshIdx = nextIndex (emb ^. ctxeSize)

    widened :: CtxEmbedding ctx (ctx' ::> tp)
    widened = extendEmbeddingRight emb
-- | Lens onto the element of an assignment selected by the type-level
-- number @n@, which is meant to be supplied by type application.  The index
-- is computed by 'natIndex' and turned into a lens by 'ixF''.
field :: forall n ctx f r. Idx n ctx r => Lens' (Assignment f ctx) (f r)
field = ixF' (natIndex @n @ctx @r)
-- | Constraint stating that the type-level number @n@ selects an entry of
-- type @r@ in @ctx@.  It combines the validity check 'ValidIx' with the
-- worker class 'Idx'', which is queried at @FromLeft ctx n@.
-- NOTE(review): 'FromLeft' presumably converts a left-based position into
-- the right-based one that 'Idx'' recurses on; confirm at its definition.
type Idx n ctx r = (ValidIx n ctx, Idx' (FromLeft ctx n) ctx r)
-- | Compute the 'Index' selected by the type-level number @n@, which is
-- meant to be supplied by type application.  Delegates to 'natIndex'' at
-- the position @FromLeft ctx n@.
natIndex :: forall n ctx r. Idx n ctx r => Index ctx r
natIndex = natIndex' @_ @(FromLeft ctx n) @ctx @r
-- | Variant of 'natIndex' that takes the type-level number through a proxy
-- argument instead of a type application.  The proxy value is ignored.
natIndexProxy :: forall n ctx r proxy. Idx n ctx r => proxy n -> Index ctx r
natIndexProxy _proxy = natIndex @n @ctx @r
-- | Worker class behind 'Idx': resolves a type-level number @n@ against a
-- context @ctx@ to the entry type @r@ (determined by @n@ and @ctx@) and
-- produces the matching 'Index'.
class KnownContext ctx => Idx' (n :: Nat) (ctx :: Ctx k) (r :: k) | n ctx -> r where
  -- | The index selected by @n@ in @ctx@.
  natIndex' :: Index ctx r
-- Base case: position 0 selects the final entry of a non-empty context.
instance KnownContext xs => Idx' 0 (xs '::> x) x where
  natIndex' = lastIndex knownSize
-- Recursive case: any other position skips the final entry and resolves
-- @n-1@ in the remaining context.  Marked as overlapping because its head
-- is more general than the base case above.
instance {-# Overlaps #-} (KnownContext xs, Idx' (n-1) xs r) =>
  Idx' n (xs '::> x) r where
  natIndex' = skipIndex (natIndex' @_ @(n-1))
-- | The curried function type that takes one @f a@ argument per entry of
-- the context and finally returns @x@.
--
-- Each step peels the final entry @a@ off the context and pushes @f a ->@
-- in front of the result, so the argument for the first entry ends up
-- outermost.  For the empty context the result is just @x@.
type family CurryAssignment (ctx :: Ctx k) (f :: k -> *) (x :: *) :: * where
   CurryAssignment EmptyCtx    f x = x
   CurryAssignment (ctx ::> a) f x = CurryAssignment ctx f (f a -> x)
-- | Contexts for which functions over an 'Assignment' can be converted to
-- and from the curried form described by 'CurryAssignment'.
class CurryAssignmentClass (ctx :: Ctx k) where
  -- | Turn a function that consumes a whole assignment into a curried
  -- function taking the elements one argument at a time.
  curryAssignment   :: (Assignment f ctx -> x) -> CurryAssignment ctx f x
  -- | Turn a curried function into one that consumes a whole assignment.
  uncurryAssignment :: CurryAssignment ctx f x -> (Assignment f ctx -> x)
-- | Base case: with no entries the curried form is the result itself, so
-- currying applies the function to the empty assignment and uncurrying
-- ignores the assignment argument.
instance CurryAssignmentClass EmptyCtx where
  curryAssignment = ($ empty)
  uncurryAssignment = const
-- | Inductive case: the final entry becomes the innermost curried argument.
-- Currying rebuilds the assignment from its front and the final element;
-- uncurrying splits the assignment and applies the curried function to the
-- front first and the final element last.
instance CurryAssignmentClass ctx => CurryAssignmentClass (ctx ::> a) where
  curryAssignment k = curryAssignment (\front final -> k (extend front final))
  uncurryAssignment k (front :> final) = uncurryAssignment k front final
-- | Build an assignment, with hidden context, from a list of existentially
-- wrapped elements.  The first list element becomes the first entry.
--
-- The accumulated assignment is forced to weak head normal form before
-- each recursive step.
fromList :: [Some f] -> Some (Assignment f)
fromList = build empty
  where build :: Assignment f ctx -> [Some f] -> Some (Assignment f)
        build acc elts =
          case elts of
            []            -> Some acc
            Some e : rest ->
              let acc' = extend acc e
              in acc' `seq` build acc' rest
-- | Applicative wrapper that discards the nominal result type @a@ and only
-- carries an effectful monoidal value @m w@.
newtype Collector m w a = Collector { runCollector :: m w }

-- | Mapping cannot touch anything: the result type is phantom.
instance Functor (Collector m w) where
  fmap _ = Collector . runCollector

-- | 'pure' collects 'mempty'; '<*>' runs the left effect before the right
-- one and combines the two collected values with '<>'.
instance (Applicative m, Monoid w) => Applicative (Collector m w) where
  pure = const (Collector (pure mempty))
  lhs <*> rhs = Collector (liftA2 (<>) (runCollector lhs) (runCollector rhs))
-- | Visit every element of an assignment together with its index, run an
-- effectful action that yields a monoidal value, and return the combined
-- result.
--
-- Implemented with 'traverseWithIndex' over the 'Collector' applicative,
-- so no assignment is rebuilt.
traverseAndCollect ::
  (Monoid w, Applicative m) =>
  (forall tp. Index ctx tp -> f tp -> m w) ->
  Assignment f ctx ->
  m w
traverseAndCollect f asgn =
  runCollector (traverseWithIndex (\idx x -> Collector (f idx x)) asgn)
-- | Visit every element of an assignment together with its index purely
-- for the effects of the action; built on 'traverseAndCollect' at the unit
-- monoid.
traverseWithIndex_ :: Applicative m
                   => (forall tp . Index ctx tp -> f tp -> m ())
                   -> Assignment f ctx
                   -> m ()
traverseWithIndex_ f asgn = void (traverseAndCollect f asgn)
-- Sizes of contexts with one to six entries.  Each is obtained by applying
-- 'incSize' to the previous one, starting from 'zeroSize'.

-- | Size of a context with one entry.
size1 :: Size (EmptyCtx ::> a)
size1 = incSize zeroSize
-- | Size of a context with two entries.
size2 :: Size (EmptyCtx ::> a ::> b)
size2 = incSize size1
-- | Size of a context with three entries.
size3 :: Size (EmptyCtx ::> a ::> b ::> c)
size3 = incSize size2
-- | Size of a context with four entries.
size4 :: Size (EmptyCtx ::> a ::> b ::> c ::> d)
size4 = incSize size3
-- | Size of a context with five entries.
size5 :: Size (EmptyCtx ::> a ::> b ::> c ::> d ::> e)
size5 = incSize size4
-- | Size of a context with six entries.
size6 :: Size (EmptyCtx ::> a ::> b ::> c ::> d ::> e ::> f)
size6 = incSize size5
-- Indices into a two-entry context, numbered from the leftmost entry.

-- | Index of the first entry of a two-entry context: the base index, moved
-- past one later entry with 'skipIndex'.
i1of2 :: Index (EmptyCtx ::> a ::> b) a
i1of2 = skipIndex baseIndex
-- | Index of the second (final) entry of a two-entry context.
i2of2 :: Index (EmptyCtx ::> a ::> b) b
i2of2 = nextIndex size1
-- Indices into a three-entry context, numbered from the leftmost entry.
-- All but the final one reuse the two-entry index via 'skipIndex'.

-- | Index of the first entry of a three-entry context.
i1of3 :: Index (EmptyCtx ::> a ::> b ::> c) a
i1of3 = skipIndex i1of2
-- | Index of the second entry of a three-entry context.
i2of3 :: Index (EmptyCtx ::> a ::> b ::> c) b
i2of3 = skipIndex i2of2
-- | Index of the third (final) entry of a three-entry context.
i3of3 :: Index (EmptyCtx ::> a ::> b ::> c) c
i3of3 = nextIndex size2
-- Indices into a four-entry context, numbered from the leftmost entry.
-- All but the final one reuse the three-entry index via 'skipIndex'.

-- | Index of the first entry of a four-entry context.
i1of4 :: Index (EmptyCtx ::> a ::> b ::> c ::> d) a
i1of4 = skipIndex i1of3
-- | Index of the second entry of a four-entry context.
i2of4 :: Index (EmptyCtx ::> a ::> b ::> c ::> d) b
i2of4 = skipIndex i2of3
-- | Index of the third entry of a four-entry context.
i3of4 :: Index (EmptyCtx ::> a ::> b ::> c ::> d) c
i3of4 = skipIndex i3of3
-- | Index of the fourth (final) entry of a four-entry context.
i4of4 :: Index (EmptyCtx ::> a ::> b ::> c ::> d) d
i4of4 = nextIndex size3
-- Indices into a five-entry context, numbered from the leftmost entry.
-- All but the final one reuse the four-entry index via 'skipIndex'.

-- | Index of the first entry of a five-entry context.
i1of5 :: Index (EmptyCtx ::> a ::> b ::> c ::> d ::> e) a
i1of5 = skipIndex i1of4
-- | Index of the second entry of a five-entry context.
i2of5 :: Index (EmptyCtx ::> a ::> b ::> c ::> d ::> e) b
i2of5 = skipIndex i2of4
-- | Index of the third entry of a five-entry context.
i3of5 :: Index (EmptyCtx ::> a ::> b ::> c ::> d ::> e) c
i3of5 = skipIndex i3of4
-- | Index of the fourth entry of a five-entry context.
i4of5 :: Index (EmptyCtx ::> a ::> b ::> c ::> d ::> e) d
i4of5 = skipIndex i4of4
-- | Index of the fifth (final) entry of a five-entry context.
i5of5 :: Index (EmptyCtx ::> a ::> b ::> c ::> d ::> e) e
i5of5 = nextIndex size4
-- Indices into a six-entry context, numbered from the leftmost entry.
-- All but the final one reuse the five-entry index via 'skipIndex'.

-- | Index of the first entry of a six-entry context.
i1of6 :: Index (EmptyCtx ::> a ::> b ::> c ::> d ::> e ::> f) a
i1of6 = skipIndex i1of5
-- | Index of the second entry of a six-entry context.
i2of6 :: Index (EmptyCtx ::> a ::> b ::> c ::> d ::> e ::> f) b
i2of6 = skipIndex i2of5
-- | Index of the third entry of a six-entry context.
i3of6 :: Index (EmptyCtx ::> a ::> b ::> c ::> d ::> e ::> f) c
i3of6 = skipIndex i3of5
-- | Index of the fourth entry of a six-entry context.
i4of6 :: Index (EmptyCtx ::> a ::> b ::> c ::> d ::> e ::> f) d
i4of6 = skipIndex i4of5
-- | Index of the fifth entry of a six-entry context.
i5of6 :: Index (EmptyCtx ::> a ::> b ::> c ::> d ::> e ::> f) e
i5of6 = skipIndex i5of5
-- | Index of the sixth (final) entry of a six-entry context.
i6of6 :: Index (EmptyCtx ::> a ::> b ::> c ::> d ::> e ::> f) f
i6of6 = nextIndex size5