{-# LANGUAGE BangPatterns
           , DeriveLift
           , RankNTypes
           , ScopedTypeVariables #-}

module Data.RadixNTree.Word8.Strict.Pointer
  ( Pointer (..)
  , pointer0
  , pointer1

  , follow0
  , follow1
  ) where

import           Data.ByteArray.NonEmpty
import           Data.RadixNTree.Word8.Key
import           Data.RadixNTree.Word8.Strict

import           Control.Monad.ST
import           Data.Bits
import           Data.Primitive.ByteArray
import           Data.Word
import           Language.Haskell.TH.Syntax



-- | Pure compressed tree reference.
--
--   A pointer stores how deep the referenced node lies together with a
--   bitmask recording which branch was taken at every 'Bin' on the way there.
--
--   @since 1.1
data Pointer =
  Pointer
    {-# UNPACK #-} !Word -- ^ Node depth (0 is root).
    !ByteArray           -- ^ Little-endian bitmask of size @depth@.
                         --   'Bin' choices are represented as 0 and 1 for
                         --   left and right respectively;
                         --   'Tip's can hold any data.
  deriving (Show, Lift)



-- | Set the bit for depth @n@ in the bitmask to @1@, leaving every other bit
--   untouched.
--
--   The bit lives in byte number @shiftR n 3@, at bit position @n .&. 7@
--   counted from the least significant end.
mark :: MutableByteArray s -> Word -> ST s ()
mark marr n = do
  byte <- readByteArray marr byteIx
  writeByteArray marr byteIx (byte .|. unsafeShiftL 1 bitIx :: Word8)
  where
    byteIx = fromIntegral (n `unsafeShiftR` 3) :: Int
    bitIx  = fromIntegral (n .&. 0x07)         :: Int

-- | Check whether the bit for depth @n@ in the bitmask is @0@, i.e. whether
--   the recorded path takes the left branch at that depth.
--
--   No bounds checking is performed: byte number @shiftR n 3@ must exist.
left :: ByteArray -> Word -> Bool
left arr n = not (testBit byte bitIx)
  where
    byte  = indexByteArray arr (fromIntegral (n `unsafeShiftR` 3)) :: Word8
    bitIx = fromIntegral (n .&. 0x07) :: Int

-- | Create a zeroed bitmask with room for the bits of depths @0@ through @n@
--   (@shiftR n 3 + 1@ bytes), let @populate@ write into it, then freeze it.
--
--   The unsafe freeze is fine because the mutable array never escapes this
--   'runST' block.
form :: (forall s. MutableByteArray s -> ST s ()) -> Word -> ByteArray
form populate n =
  runST $ do
    marr <- newByteArray size
    fillByteArray marr 0 size 0x00
    populate marr
    unsafeFreezeByteArray marr
  where
    size :: Int
    size = fromIntegral (n `unsafeShiftR` 3) + 1



{-# INLINE pointer0 #-}
-- | Build a 'Pointer' to the value stored under the key produced by the feed,
--   or return 'Nothing' if the tree holds no value there.
--
--   An empty key refers to the root value of the tree, which is addressed by a
--   depth-@0@ pointer with an empty bitmask.
pointer0 :: Feed -> RadixTree a -> Maybe Pointer
pointer0 (Feed feed) = \(RadixTree mx t) ->
  let root = Pointer 0 emptyByteArray <$ mx
  in feed $ \step s ->
       case step s of
         Done     -> root
         More w z -> pointer_ step w z t

{-# INLINE pointer1 #-}
-- | Build a 'Pointer' to the value stored under the non-empty key produced by
--   the feed, or return 'Nothing' if the tree holds no value there.
pointer1 :: Feed1 -> Radix1Tree a -> Maybe Pointer
pointer1 (Feed1 firstByte feed) = feed (\next -> pointer_ next firstByte)

{-# INLINE pointer_ #-}
-- | Shared worker behind 'pointer0' and 'pointer1'.
--
--   Descends the tree following the key bytes yielded by @step@, starting with
--   byte @w@ and feed state @s@. Every 'Bin' and every 'Tip' visited counts as
--   one level of depth.
--
--   Each time a 'Bin' sends the search right, a deferred 'mark' for the current
--   depth is composed onto the accumulator. The bitmask is only materialised
--   (through 'form') once the key ends exactly at the end of a 'Tip' that holds
--   a value. The resulting pointer depth is one past that 'Tip''s depth.
pointer_
  :: (x -> Step Word8 x)
  -> Word8 -> x -> Radix1Tree a -> Maybe Pointer
pointer_ (step :: x -> Step Word8 x) = descend 0 (\_ -> pure ())
  where
    descend
      :: Word -> (forall s. MutableByteArray s -> ST s ())
      -> Word8 -> x -> Radix1Tree a -> Maybe Pointer
    descend !depth marks !w !s tree =
      case tree of
        Nil -> Nothing

        Bin p l r
          | w < p     -> descend (depth + 1) marks w s l
          | otherwise ->
              descend (depth + 1) (\marr -> mark marr depth >> marks marr) w s r

        Tip arr mx dx -> matchArr w s 0
          where
            lastIx = sizeofByteArray arr - 1

            -- Compare key bytes against the Tip's array, one byte at a time.
            matchArr :: Word8 -> x -> Int -> Maybe Pointer
            matchArr v !z ix
              | v /= indexByteArray arr ix = Nothing

                -- More of the array to match: the key must continue.
              | ix < lastIx =
                  case step z of
                    More u z' -> matchArr u z' (ix + 1)
                    Done      -> Nothing

                -- Whole array matched: either go below or stop here.
              | otherwise =
                  case step z of
                    More u z' -> descend (depth + 1) marks u z' dx
                    Done      -> Pointer (depth + 1) (form marks depth) <$ mx



-- | Look up the value a 'Pointer' refers to, falling back to @d@ if the path
--   it describes does not end at a value in this tree.
--
--   A depth-@0@ pointer refers to the root value of the 'RadixTree'.
follow0 :: a -> Pointer -> RadixTree a -> a
follow0 d (Pointer len arr) (RadixTree mx dx)
  | len /= 0  = follow_ d len arr dx
  | otherwise = maybe d id mx

-- | Look up the value a 'Pointer' refers to in a 'Radix1Tree', falling back to
--   @d@ if the path it describes does not end at a value.
follow1 :: a -> Pointer -> Radix1Tree a -> a
follow1 d (Pointer depth mask) tree = follow_ d depth mask tree

-- | Walk the tree along the path encoded by a pointer of depth @len@ and
--   bitmask @arr@, returning the value stored there or @d@ if the path does
--   not end at a value.
--
--   Every 'Bin' and 'Tip' visited counts as one level of depth, starting at
--   @0@. The target 'Tip' therefore sits at depth @len - 1@.
follow_ :: a -> Word -> ByteArray -> Radix1Tree a -> a
follow_ d len arr = go 0
  where
    go !i t =
      case t of
        Bin _ l r
            -- A 'Bin' at depth @len - 1@ or deeper can never lead to the target
            -- 'Tip' at depth @len - 1@, so give up before consulting the mask.
            -- This also keeps 'left' from reading past the end of @arr@.
            -- For pointers built by 'pointer_', @arr@ is allocated by 'form'
            -- for depth @len - 1@. For a depth-@0@ pointer, @arr@ may be empty
            -- (e.g. a root pointer handed to 'follow1').
          | i + 1 >= len -> d
          | otherwise    -> go (i + 1) $ if left arr i then l else r

        Tip _ mx dx ->
          let i' = i + 1
          in case compare i' len of
               GT -> d
               EQ -> case mx of
                       Just x  -> x
                       Nothing -> d
               LT -> go i' dx

        Nil -> d