{-# LANGUAGE CPP              #-}
{-# LANGUAGE BangPatterns     #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes       #-}
#ifndef BITVEC_THREADSAFE
module Data.Bit.Mutable
#else
module Data.Bit.MutableTS
#endif
  ( castFromWordsM
  , castToWordsM
  , cloneToWordsM
  , zipInPlace
  , invertInPlace
  , selectBitsInPlace
  , excludeBitsInPlace
  , reverseInPlace
  ) where
import Control.Monad.Primitive
#ifndef BITVEC_THREADSAFE
import Data.Bit.Internal
#else
import Data.Bit.InternalTS
#endif
import Data.Bit.Utils
import Data.Bits
import qualified Data.Vector.Primitive as P
import qualified Data.Vector.Unboxed as U
import qualified Data.Vector.Unboxed.Mutable as MU
-- | View a mutable vector of 'Word's as a mutable vector of 'Bit's.
--
-- No data is copied: the result refers to the same underlying array. The
-- word offset and word length are converted to bit units via 'mulWordSize'.
castFromWordsM :: MVector s Word -> MVector s Bit
castFromWordsM (MU.MV_Word (P.MVector wordOff wordLen arr)) =
  BitMVec bitOff bitLen arr
 where
  bitOff = mulWordSize wordOff
  bitLen = mulWordSize wordLen
-- | Try to view a mutable vector of 'Bit's as a mutable vector of 'Word's
-- without copying.
--
-- Succeeds only when both the bit offset and the bit length satisfy
-- 'aligned'; otherwise returns 'Nothing'. On success the result shares the
-- underlying array with the input, with offset and length converted to word
-- units via 'divWordSize'.
castToWordsM :: MVector s Bit -> Maybe (MVector s Word)
castToWordsM (BitMVec s n ws) =
  if aligned s && aligned n
    then Just (MU.MV_Word (P.MVector (divWordSize s) (divWordSize n) ws))
    else Nothing
-- | Copy a mutable bit vector into freshly allocated word-sized storage.
--
-- The new buffer holds @nWords (length v)@ words. The first @length v@ bits
-- are copied from @v@, and the unused high bits of the final word are
-- explicitly cleared to @'Bit' False@, so the padding is deterministic.
-- The input vector is not modified.
cloneToWordsM
  :: PrimMonad m => MVector (PrimState m) Bit -> m (MVector (PrimState m) Word)
cloneToWordsM src = do
  let nBits    = MU.length src
      nWs      = nWords nBits
      capacity = mulWordSize nWs
      padding  = capacity - nBits
  dst@(BitMVec _ _ buf) <- MU.unsafeNew capacity
  MU.unsafeCopy (MU.slice 0 nBits dst) src
  -- Clear the tail of the last word so the words carry no uninitialized bits.
  MU.set (MU.slice nBits padding dst) (Bit False)
  pure (MU.MV_Word (P.MVector 0 nWs buf))
{-# INLINE cloneToWordsM #-}
-- | Combine two bit vectors word by word with @f@, storing the result in the
-- second (mutable) argument.
--
-- Only the first @min (length xs) (length ys)@ bits of @ys@ are updated.
-- Bits of @ys@ beyond the length of @xs@ are left untouched. This holds even
-- when the final processed word straddles that boundary.
zipInPlace
  :: PrimMonad m
  => (forall a . Bits a => a -> a -> a)
  -> Vector Bit
  -> MVector (PrimState m) Bit
  -> m ()
zipInPlace f xs ys = loop 0
 where
  !n = min (U.length xs) (MU.length ys)
  loop !i
    | i >= n = pure ()
    | otherwise = do
      let x = indexWord xs i
      y <- readWord ys i
      -- The final word may extend past @n@. Its high bits in @x@ are not
      -- elements of @xs@. Writing @f x y@ unmasked would therefore corrupt
      -- the bits of @ys@ beyond @n@ whenever @ys@ is longer than @xs@. Keep
      -- only the low @n - i@ bits of the result and restore the rest from @y@.
      let z = f x y
          z' | n - i >= wordSize = z
             | otherwise         = meld (n - i) z y
      writeWord ys i z'
      loop (i + wordSize)
{-# INLINE zipInPlace #-}
-- | Flip every bit of the mutable vector in place, processing one word at a
-- time starting from bit offset 0.
invertInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> m ()
invertInPlace xs = go 0
 where
  !len = MU.length xs
  go !pos =
    if pos < len
      then do
        readWord xs pos >>= writeWord xs pos . complement
        go (pos + wordSize)
      else pure ()
{-# INLINE invertInPlace #-}
-- | Compact, to the front of @xs@, the bits of @xs@ at positions where the
-- mask @is@ has a set bit. Returns the number of bits kept.
--
-- Only the first @min (length is) (length xs)@ positions are considered.
-- Contents of @xs@ past the returned count may have been overwritten.
-- Callers should treat them as unspecified and trim the result.
-- NOTE(review): assumes 'selectWord' returns (number of selected bits,
-- selected bits packed into the low end of the word); confirm against
-- Data.Bit.Utils.
selectBitsInPlace
  :: PrimMonad m => U.Vector Bit -> U.MVector (PrimState m) Bit -> m Int
selectBitsInPlace is xs = go 0 0
 where
  !lim = min (U.length is) (MU.length xs)
  go !src !dst
    | src < lim = do
      chunk <- readWord xs src
      -- Mask off positions at or beyond @lim@ in the last partial word.
      let keep = masked (lim - src) (indexWord is src)
      let !(cnt, packed) = selectWord keep chunk
      writeWord xs dst packed
      go (src + wordSize) (dst + cnt)
    | otherwise = pure dst
-- | Compact, to the front of @xs@, the bits of @xs@ at positions where the
-- mask @is@ has a cleared bit. Returns the number of bits kept.
--
-- Only the first @min (length is) (length xs)@ positions are considered.
-- Contents of @xs@ past the returned count may have been overwritten.
-- Callers should treat them as unspecified and trim the result.
-- NOTE(review): assumes 'selectWord' returns (number of selected bits,
-- selected bits packed into the low end of the word); confirm against
-- Data.Bit.Utils.
excludeBitsInPlace
  :: PrimMonad m => U.Vector Bit -> U.MVector (PrimState m) Bit -> m Int
excludeBitsInPlace is xs = go 0 0
 where
  !lim = min (U.length is) (MU.length xs)
  go !src !dst
    | src < lim = do
      chunk <- readWord xs src
      -- Invert the mask, then drop positions at or beyond @lim@.
      let keep = masked (lim - src) (complement (indexWord is src))
      let !(cnt, packed) = selectWord keep chunk
      writeWord xs dst packed
      go (src + wordSize) (dst + cnt)
    | otherwise = pure dst
-- | Reverse the order of bits in a mutable bit vector, in place.
--
-- Full words are swapped pairwise from both ends towards the middle, each
-- reversed with 'reverseWord'. The leftover span of fewer than two words is
-- then finished in a single step. An empty vector is left untouched.
reverseInPlace :: PrimMonad m => U.MVector (PrimState m) Bit -> m ()
reverseInPlace xs | len == 0  = pure ()
                  | otherwise = loop 0
 where
  len = MU.length xs
  -- Invariant: bits outside the span [i, j) are already in final position.
  loop !i
    -- Two full, non-overlapping words remain at the front and back of the
    -- span: swap them, reversing each, and continue inwards.
    | i' <= j' = do
      x <- readWord xs i
      y <- readWord xs j'
      writeWord xs i  (reverseWord y)
      writeWord xs j' (reverseWord x)
      loop i'
    -- The span is strictly longer than one word but shorter than two. Swap
    -- its low half [i, i + w) with its high half [k, j), reversing both. A
    -- middle bit (odd length) stays put. This completes the span, so stop
    -- here instead of recursing into an empty span with a negative width.
    | i' < j = do
      let w = (j - i) `shiftR` 1
          k = j - w
      x <- readWord xs i
      y <- readWord xs k
      writeWord xs i (meld w (reversePartialWord w y) x)
      writeWord xs k (meld w (reversePartialWord w x) y)
    -- At most one word remains: reverse its low (j - i) bits in place.
    | otherwise = do
      let w = j - i
      x <- readWord xs i
      writeWord xs i (meld w (reversePartialWord w x) x)
   where
    !j  = len - i
    !i' = i + wordSize
    !j' = j - wordSize