{-# LANGUAGE CPP #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE ExplicitNamespaces #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE OverloadedRecordDot #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE RoleAnnotations #-}
{-# LANGUAGE ViewPatterns #-}
{-# OPTIONS_GHC -Wno-duplicate-exports #-}

module Data.DeBruijn.Index.Fast (
  -- * DeBruijn Indexes
  Ix (FZ, FS),
  eqIx,
  fromIx,
  fromIxRaw,
  isPos,
  thin,
  thick,
  inject,
  raise,

  -- * Existential Wrapper
  SomeIx (..),
  withSomeIx,
  toSomeIx,
  toSomeIxRaw,
  fromSomeIx,
  fromSomeIxRaw,

  -- * Fast
  IxRep,
  intToIxRep,
  ixRepToInt,
  snatRepToIxRep,
  ixRepToSNatRep,
  Ix (UnsafeIx, ixRep),
) where

import Control.DeepSeq (NFData (..))
import Data.Bifunctor (Bifunctor (..))
import Data.Kind (Type)
import Data.Type.Equality (type (:~:) (Refl))
import Data.Type.Nat (Nat (..), Pos, Pred, type (+))
import Data.Type.Nat.Singleton.Fast (SNat (..), SNatRep, decSNat)
import Text.Printf (printf)
import Unsafe.Coerce (unsafeCoerce)

#if defined(IX_AS_WORD8) || defined(SNAT_AS_WORD8)
import Control.Exception (ArithException (Overflow, Underflow), throw)
import Data.Word (Word8)
#endif

{- $setup
>>> import Data.DeBruijn.Index.Fast.Arbitrary
-}

--------------------------------------------------------------------------------
-- DeBruijn Index Representation
--------------------------------------------------------------------------------

-- The machine representation of an index is chosen at compile time via CPP.
-- Exactly one of IX_AS_WORD8 / IX_AS_INT is expected to be defined.
#ifdef IX_AS_WORD8
-- | Machine representation of a DeBruijn index ('Word8' build).
type IxRep = Word8
#elif defined(IX_AS_INT)
-- | Machine representation of a DeBruijn index ('Int' build).
type IxRep = Int
#elif !defined(__HLINT__)
#error "cpp: define one of [IX_AS_WORD8, IX_AS_INT]"
#endif

-- | The representation of the index 'FZ', i.e. zero.
mkFZRep :: IxRep
mkFZRep = 0
{-# INLINE mkFZRep #-}

-- | Successor on index representations (the representation of 'FS').
-- The addition is not overflow-checked.
mkFSRep :: IxRep -> IxRep
mkFSRep rep = rep + 1
{-# INLINE mkFSRep #-}

-- | Predecessor on index representations (undoes 'mkFSRep').
-- Not underflow-checked; callers in this module only apply it to
-- representations already known to be non-zero.
unFSRep :: IxRep -> IxRep
unFSRep rep = rep - 1
{-# INLINE unFSRep #-}

-- | Eliminator on representations: @'elIxRep' z s r@ is @z@ when @r@
-- represents 'FZ', and @s@ applied to the predecessor of @r@ otherwise.
elIxRep :: a -> (IxRep -> a) -> IxRep -> a
elIxRep onZero onSucc rep
  | rep == mkFZRep = onZero
  | otherwise = onSucc (unFSRep rep)
{-# INLINE elIxRep #-}

-- | Thinning on representations: @'thinRep' i j@ skips over @i@, so every
-- @j@ at or above @i@ is shifted up by one and smaller @j@ are unchanged.
thinRep :: IxRep -> IxRep -> IxRep
thinRep i j = if i <= j then mkFSRep j else j

-- | Thickening on representations, the partial inverse of 'thinRep':
-- 'Nothing' when @j == i@, otherwise @j@ with values above @i@ shifted down.
thickRep :: IxRep -> IxRep -> Maybe IxRep
thickRep i j
  | i < j = Just (unFSRep j)
  | i == j = Nothing
  | otherwise = Just j

--------------------------------------------------------------------------------
-- DeBruijn Indexes
--------------------------------------------------------------------------------

-- | @'Ix' n@ is the type of DeBruijn indices less than @n@.
--
-- At runtime an index is just its 'IxRep'; the bound @n@ lives only in the
-- type. The role is nominal so that 'Data.Coerce.coerce' cannot change @n@.
type Ix :: Nat -> Type
newtype Ix n = UnsafeIx
  { ixRep :: IxRep
  -- ^ The raw numeric value of the index.
  }

type role Ix nominal

-- | Heterogeneous equality: compares the numeric values of two indices
-- whose bounds may differ.
eqIx :: Ix n -> Ix m -> Bool
eqIx i j = i.ixRep == j.ixRep

-- | Equality of indices with the same bound, via 'eqIx'.
instance Eq (Ix n) where
  (==) :: Ix n -> Ix n -> Bool
  i == j = eqIx i j

-- | Shows an index as a chain of 'FS' applications ending in 'FZ', e.g.
-- @FS (FS FZ)@.
--
-- Only an 'FS' application is parenthesised at high precedence; the nullary
-- 'FZ' never is, matching the convention of derived 'Show' instances.
-- (Previously the whole case was wrapped in 'showParen', producing
-- @FS (FZ)@.)
instance Show (Ix n) where
  showsPrec :: Int -> Ix n -> ShowS
  showsPrec p = \case
    FZ -> showString "FZ"
    FS i -> showParen (p > 10) $ showString "FS " . showsPrec 11 i

-- | An index is fully evaluated once its representation is.
instance NFData (Ix n) where
  rnf :: Ix n -> ()
  rnf i = rnf i.ixRep

-- | The zero index in any non-empty context.
mkFZ :: Ix (S n)
mkFZ = UnsafeIx {ixRep = mkFZRep}
{-# INLINE mkFZ #-}

-- | Successor index: lifts an @'Ix' n@ into @'Ix' ('S' n)@ and adds one.
mkFS :: Ix n -> Ix (S n)
mkFS i = UnsafeIx (mkFSRep i.ixRep)
{-# INLINE mkFS #-}

-- | Eliminator for indices: @'elIx' z s i@ is @z@ when @i@ is zero and
-- @s@ applied to the predecessor of @i@ otherwise.
elIx :: a -> (Ix (Pred n) -> a) -> Ix n -> a
elIx onFZ onFS i = elIxRep onFZ (\rep -> onFS (UnsafeIx rep)) i.ixRep
{-# INLINE elIx #-}

-- | @'fromIx' i@ returns the numeric value of the index @i :: 'Ix' n@,
-- converted to any 'Integral' type via 'fromIntegral'.
fromIx :: (Integral i) => Ix n -> i
fromIx i = fromIntegral i.ixRep
{-# INLINE fromIx #-}

-- | @'fromIxRaw' i@ returns the raw 'IxRep' representation of the index
-- @i :: 'Ix' n@, without any conversion.
fromIxRaw :: Ix n -> IxRep
fromIxRaw i = i.ixRep
{-# INLINE fromIxRaw #-}

-- | @'IxF'@ is the base functor of @'Ix'@: a single layer of an index whose
-- recursive position is filled by @ix@.
type IxF :: (Nat -> Type) -> Nat -> Type
data IxF ix n where
  -- | One layer representing the zero index.
  FZF :: IxF ix (S m)
  -- | One layer representing a successor; the predecessor is strict.
  FSF :: !(ix m) -> IxF ix (S m)

-- | Expose the outermost layer of an index as an 'IxF'.
--
-- SAFETY: building 'FZF' or 'FSF' at type @'IxF' 'Ix' n@ needs the equality
-- @n ~ 'S' m@, which 'elIx' does not provide. The 'unsafeCoerce' calls assert
-- it. This relies on the module invariant that a value of type @'Ix' n@
-- only exists when @n@ has a predecessor, so that @n ~ 'S' ('Pred' n)@.
projectIx :: Ix n -> IxF Ix n
projectIx = elIx zeroLayer succLayer
  where
    zeroLayer = unsafeCoerce FZF
    succLayer j = unsafeCoerce (FSF j)
{-# INLINE projectIx #-}

-- | Rebuild an index from one 'IxF' layer (inverse of 'projectIx').
embedIx :: IxF Ix n -> Ix n
embedIx layer = case layer of
  FZF -> mkFZ
  FSF j -> mkFS j
{-# INLINE embedIx #-}

-- NOTE:
--   Pattern synonym signatures list the *required* constraints first and the
--   *provided* constraints second (hence @() => Pos n => ...@ below), see:
--   https://ghc.gitlab.haskell.org/ghc/doc/users_guide/exts/pattern_synonyms.html#typing-of-pattern-synonyms

-- | The zero index. Matching on 'FZ' provides @'Pos' n@; constructing it
-- requires @'Pos' n@.
pattern FZ :: () => (Pos n) => Ix n
pattern FZ <- (projectIx -> FZF)
  where
    FZ = embedIx FZF
{-# INLINE FZ #-}

-- | The successor of an index. Matching on 'FS' provides @'Pos' n@ and
-- binds the predecessor; constructing it requires @'Pos' n@.
pattern FS :: () => (Pos n) => Ix (Pred n) -> Ix n
pattern FS j <- (projectIx -> FSF j)
  where
    FS j = embedIx (FSF j)
{-# INLINE FS #-}

{-# COMPLETE FZ, FS #-}

-- | If any value of type @'Ix' n@ exists, @n@ must have a predecessor.
--
-- Inspects the index (both 'FZ' and 'FS' provide @'Pos' n@) and then
-- runs the continuation with that evidence in scope.
isPos :: Ix n -> ((Pos n) => a) -> a
isPos i r = case i of
  FZ -> r
  FS _ -> r

-- | Thinning: @'thin' i j@ embeds @j :: 'Ix' n@ into @'Ix' ('S' n)@ while
-- skipping the position @i@; indices at or above @i@ move up by one.
thin :: Ix (S n) -> Ix n -> Ix (S n)
thin i j = UnsafeIx $ thinRep (fromIxRaw i) (fromIxRaw j)

-- | Thickening, the partial inverse of 'thin': 'Nothing' when @j@ equals
-- @i@, otherwise @j@ with the position @i@ removed.
thick :: Ix (S n) -> Ix (S n) -> Maybe (Ix n)
thick i j = fmap UnsafeIx (thickRep (fromIxRaw i) (fromIxRaw j))

-- | Inject: weaken the bound of @i@ from @n@ to @n + m@. The numeric value
-- is unchanged and the singleton argument is never inspected.
inject :: Ix n -> SNat m -> Ix (n + m)
inject i _ = UnsafeIx (fromIxRaw i)

-- | Raise: shift @j :: 'Ix' m@ up by @n@, giving an index into @n + m@.
-- The addition on representations is not overflow-checked.
raise :: SNat n -> Ix m -> Ix (n + m)
raise n j = UnsafeIx (offset + j.ixRep)
  where
    offset = snatRepToIxRep n.snatRep

-- | Convert an 'Int' to an 'IxRep'.
--
-- With the 'Word8' representation the input is range-checked: negative values
-- throw 'Underflow' and values above @maxBound :: Word8@ throw 'Overflow'.
-- With the 'Int' representation this is the identity.
intToIxRep :: Int -> IxRep
#ifdef IX_AS_WORD8
-- TODO: Make this safe.
intToIxRep int
  | int < 0 = throw Underflow
  | int > fromIntegral (maxBound @Word8) = throw Overflow
  | otherwise = fromIntegral @Int @Word8 int
{-# INLINE intToIxRep #-}
#else
intToIxRep = id @Int
{-# INLINE intToIxRep #-}
#endif

-- | Convert an 'IxRep' to an 'Int'. This is a widening (or identity)
-- conversion, so it never fails.
ixRepToInt :: IxRep -> Int
#ifdef IX_AS_WORD8
ixRepToInt = fromIntegral @Word8 @Int
#else
ixRepToInt = id @Int
#endif
{-# INLINE ixRepToInt #-}

-- | Convert an 'SNatRep' to an 'IxRep'.
--
-- Only the narrowing 'Int' to 'Word8' case is partial: negative values throw
-- 'Underflow' and values above @maxBound :: Word8@ throw 'Overflow'. All
-- other representation combinations convert without failure.
snatRepToIxRep :: SNatRep -> IxRep
#ifdef SNAT_AS_WORD8
#ifdef IX_AS_WORD8
snatRepToIxRep = id @Word8
#else
snatRepToIxRep = fromIntegral @Word8 @Int
#endif
#else
#ifdef IX_AS_WORD8
-- Int -> Word8 (narrowing, range-checked)
snatRepToIxRep rep
  | rep < 0 = throw Underflow
  | rep > fromIntegral (maxBound @Word8) = throw Overflow
  | otherwise = fromIntegral rep
#else
snatRepToIxRep = id @Int
#endif
#endif
-- One pragma for every branch, so the range-checked case is inlined too.
{-# INLINE snatRepToIxRep #-}

-- | Convert an 'IxRep' to an 'SNatRep'.
--
-- Only the narrowing 'Int' to 'Word8' case is partial: negative values throw
-- 'Underflow' and values above @maxBound :: Word8@ throw 'Overflow'. All
-- other representation combinations convert without failure.
ixRepToSNatRep :: IxRep -> SNatRep
#ifdef SNAT_AS_WORD8
#ifdef IX_AS_WORD8
ixRepToSNatRep = id @Word8
#else
-- Int -> Word8 (narrowing, range-checked)
ixRepToSNatRep rep
  | rep < 0 = throw Underflow
  | rep > fromIntegral (maxBound @Word8) = throw Overflow
  | otherwise = fromIntegral rep
#endif
#else
#ifdef IX_AS_WORD8
-- Word8 -> Int (widening). The type applications were previously reversed
-- (@Int @Word8), which does not match IxRep -> SNatRep in this configuration.
ixRepToSNatRep = fromIntegral @Word8 @Int
#else
ixRepToSNatRep = id @Int
#endif
#endif
-- One pragma for every branch, so no configuration is left uninlined.
{-# INLINE ixRepToSNatRep #-}

--------------------------------------------------------------------------------
-- Existential Wrapper
--------------------------------------------------------------------------------

-- | An existential wrapper around indexes: an index packaged with a
-- singleton for its otherwise hidden bound.
type SomeIx :: Type
data SomeIx = forall (n :: Nat). SomeIx
  { bound :: {-# UNPACK #-} !(SNat n)
  -- ^ The bound @n@ of the wrapped index.
  , index :: {-# UNPACK #-} !(Ix n)
  -- ^ The wrapped index, of type @'Ix' n@.
  }

-- | Fully evaluates the bound first, then the index.
instance NFData SomeIx where
  rnf :: SomeIx -> ()
  rnf some = case some of
    SomeIx n i -> rnf n `seq` rnf i

-- | Two wrapped indexes are equal when their bounds are equal (as decided by
-- 'decSNat') and their indexes are equal.
instance Eq SomeIx where
  (==) :: SomeIx -> SomeIx -> Bool
  SomeIx n i == SomeIx m j = case decSNat n m of
    Just Refl -> eqIx i j
    Nothing -> False

-- Standalone (stock) deriving: a deriving clause cannot be attached to the
-- existential 'SomeIx' declaration itself.
deriving stock instance Show SomeIx

-- | Eliminate a 'SomeIx': pass its bound and index to a continuation that
-- works for every bound.
withSomeIx :: (forall n. SNat n -> Ix n -> a) -> SomeIx -> a
withSomeIx k = \case
  SomeIx n i -> k n i

{-| @'toSomeIx' (n, i)@ wraps the index @i@ together with its bound @n@.

Both components are converted with 'fromIntegral' (so values outside the
representation's range wrap rather than fail) and then validated by
'toSomeIxRaw'.

prop> toSomeIx (fromSomeIx i) == i
-}
toSomeIx :: (Integral n, Integral i) => (n, i) -> SomeIx
toSomeIx (n, i) = toSomeIxRaw (fromIntegral n, fromIntegral i)

{-| @'toSomeIxRaw' (n, i)@ wraps the raw index @i@ with the raw bound @n@.

Calls 'error' if @i@ is negative, or if @n@ is not strictly greater than
@i@ (checked in that order).

prop> toSomeIxRaw (fromSomeIxRaw i) == i
-}
toSomeIxRaw :: (SNatRep, IxRep) -> SomeIx
toSomeIxRaw (bnd, ix)
  | ix < 0 = negativeIndex
  | snatRepToIxRep bnd <= ix = boundTooSmall
  | otherwise = SomeIx (UnsafeSNat bnd) (UnsafeIx ix)
  where
    negativeIndex :: SomeIx
    negativeIndex =
      error $ printf "index cannot contain negative value, found index %d" ix
    boundTooSmall :: SomeIx
    boundTooSmall =
      error $ printf "bound must be larger than index, found bound %d and index %d" bnd ix

-- | @'fromSomeIx' i@ returns the bound and the index wrapped in @i@ as a
-- pair of numbers, each converted with 'fromIntegral'.
fromSomeIx :: (Integral n, Integral i) => SomeIx -> (n, i)
fromSomeIx = bimap fromIntegral fromIntegral . fromSomeIxRaw

-- | @'fromSomeIxRaw' i@ returns the raw representations of the bound and the
-- index wrapped in @i@, as an @('SNatRep', 'IxRep')@ pair (not necessarily
-- 'Int', depending on the CPP configuration).
fromSomeIxRaw :: SomeIx -> (SNatRep, IxRep)
fromSomeIxRaw (SomeIx (UnsafeSNat boundRep) (UnsafeIx indexRep)) =
  -- Binders renamed so they no longer shadow the 'bound' / 'index' fields.
  (boundRep, indexRep)