{-# LANGUAGE InstanceSigs #-}
module Codec.Arithmetic.Variety.BitVec
  ( BitVec

  -- * Construction
  , bitVec

  -- * Conversion
  , fromBits
  , toBits
  , fromBytes
  , toBytes
  , fromInteger
  , toInteger
  , fromString
  , toString

  -- * Methods
  , empty
  , null
  , length
  , singleton
  , append
  , take
  , drop
  , splitAt
  , replicate
  , countLeadingZeros
  , (!!)
  , (!?)

  -- * Extra
  , bitLen
  ) where

import Prelude hiding
  (null, length, take, drop, splitAt, replicate, (!!), fromInteger, toInteger)
import GHC.Num (integerLog2)
import Control.Exception (assert)

import Data.Bits ((.&.),(.|.))
import qualified Data.Bits as Bits
import Data.Word (Word8)
import qualified Data.List as L
import Data.ByteString.Lazy (ByteString)
import qualified Data.ByteString.Lazy as BS

-- | Abort with an error whose message is tagged with this module's name,
-- so failures raised from here are easy to attribute.
err :: String -> a
err msg = error ("Variety.BitVec: " ++ msg)

-- | A vector of bits.
--
-- The bits are stored as a length together with an 'Integer'; reading the
-- integer's binary expansion from bit @len-1@ down to bit @0@ gives the
-- bits from left to right. The derived 'Eq' and 'Ord' instances compare
-- the length first and then the integer.
--
-- NOTE(review): the derived instances assume the integer has no bits set
-- at or above position @len@; the raw constructor itself does not
-- enforce that.
data BitVec = BitVec
  !Int      -- ^ number of bits in the vector
  !Integer  -- ^ the bits, most significant (left-most) bit highest
  deriving (Show, Read, Eq, Ord)

-- | Bit vectors combine by concatenation ('append').
instance Semigroup BitVec where
  (<>) :: BitVec -> BitVec -> BitVec
  lhs <> rhs = append lhs rhs

-- | The identity for concatenation is the zero-length vector ('empty').
instance Monoid BitVec where
  mempty :: BitVec
  mempty = empty

-- | Construct a 'BitVec' from a length and an 'Integer' whose low-order
-- @len@ bits hold the vector (the most significant of those bits is the
-- left-most bit of the vector).
--
-- The arguments are normalised so that the integer always fits in the
-- stated number of bits:
--
-- * a negative length is treated as @0@;
-- * any bits of the integer at or above position @len@ are discarded
--   (a negative integer keeps its low @len@ two's-complement bits).
--
-- Without this, other operations misbehave on out-of-range input. For
-- example, 'append' adds the right operand's integer into the low end,
-- so @bitVec 1 0 <> bitVec 1 3@ would read back as @"11"@ instead of
-- @"01"@, and equal-looking vectors would compare unequal.
bitVec :: Int -> Integer -> BitVec
bitVec len int = BitVec len' (int .&. (Bits.bit len' - 1))
  where
    -- Clamp first so 'Bits.bit' is never asked for a negative shift.
    len' = max 0 len

-----------------
-- CONVERSIONS --
-----------------

-- | Construct from a list of bits, where 'True' is @1@ and 'False' is
-- @0@. The head of the list becomes the left-most (most significant) bit.
fromBits :: [Bool] -> BitVec
fromBits bits = BitVec (L.length bits) value
  where
    -- Horner-style accumulation: shift everything seen so far one place
    -- to the left, then drop the next bit into the vacated low position.
    value = L.foldl' push 0 bits
    push acc b = Bits.shiftL acc 1 .|. (if b then 1 else 0)

-- | List the bits from left (most significant) to right, where 'True' is
-- @1@ and 'False' is @0@. A non-positive length yields @[]@.
toBits :: BitVec -> [Bool]
toBits (BitVec n x) = [Bits.testBit x i | i <- [n - 1, n - 2 .. 0]]

-- | Construct from a lazy @ByteString@. Every byte contributes exactly
-- 8 bits, most significant bit first, so the result's length is
-- @8 * number of bytes@.
fromBytes :: ByteString -> BitVec
fromBytes bytes = fromBits [b | byte <- BS.unpack bytes, b <- unpack8bits byte]

-- | Pack the bits into a lazy @ByteString@, most significant bit first.
-- When the length is not a multiple of 8, the vector is padded on the
-- left with @0@ bits, so it is the first byte that absorbs the padding.
toBytes :: BitVec -> ByteString
toBytes v@(BitVec len _) = BS.pack (pack8bits <$> groupsOf8 (padding ++ toBits v))
  where
    -- Number of zero bits needed to round the length up to a multiple of 8.
    padLen = negate len `mod` 8
    padding = assert ((len + padLen) `mod` 8 == 0) (L.replicate padLen False)

    -- Split a list into consecutive runs of 8 elements (the last run may
    -- be shorter if the input length is not a multiple of 8).
    groupsOf8 bits = case L.splitAt 8 bits of
      ([], _)      -> []
      (byte, rest) -> byte : groupsOf8 rest

-- | Read bits from the binary representation of an @Integer@. The length
-- is exactly 'bitLen' of the integer, so the result never has leading
-- zeros (and @fromInteger 0@ is the empty vector). Use `bitVec` when a
-- specific length, including leading zeros, is required.
--
-- NOTE(review): only meaningful for non-negative integers; 'bitLen' is
-- documented for positive inputs only.
fromInteger :: Integer -> BitVec
fromInteger n = BitVec (bitLen n) n

-- | The @Integer@ whose binary expansion stores the bits. The length is
-- not part of the result, so leading zeros are not recoverable from it.
toInteger :: BitVec -> Integer
toInteger bv = case bv of
  BitVec _ bits -> bits

-- | Parse a bit vector from a string of @'0'@ and @'1'@ characters, read
-- left to right. Any other character raises an error (via 'err') when
-- that bit is demanded.
fromString :: String -> BitVec
fromString str = fromBits (toBit <$> str)
  where
    toBit c = case c of
      '0' -> False
      '1' -> True
      _   -> err ("Non-binary char encountered: " ++ [c])

-- | Render the bits, left to right, as a string of @'0'@ and @'1'@
-- characters.
toString :: BitVec -> String
toString bv = [if b then '1' else '0' | b <- toBits bv]

-------------
-- METHODS --
-------------

-- | The zero-length bit vector; the identity for 'append'.
empty :: BitVec
empty = BitVec 0 (0 :: Integer)

-- | Returns 'True' iff the bit vector holds no bits.
--
-- Only the length is inspected. Matching the literal @BitVec 0 0@ (as
-- before) wrongly reported a zero-length vector carrying a nonzero
-- integer (e.g. from 'bitVec') as non-empty, even though 'toBits' yields
-- @[]@ for it. Non-positive lengths are treated as empty, matching
-- 'toBits' and the @n <= 0@ conventions of 'take' and 'drop'.
null :: BitVec -> Bool
null (BitVec len _) = len <= 0

-- | The number of bits in the vector, as recorded in its length field.
length :: BitVec -> Int
length bv = case bv of
  BitVec n _ -> n

-- | Concatenate two bit vectors: the bits of the first are followed by
-- the bits of the second.
append :: BitVec -> BitVec -> BitVec
append (BitVec lenL intL) (BitVec lenR intR) = BitVec totalLen combined
  where
    totalLen = lenL + lenR
    -- Shift the left operand up to make room, then add the right operand
    -- into the vacated low-order bits.
    combined = intL `Bits.shiftL` lenR + intR

-- | A one-bit vector holding the given bit ('True' is @1@).
singleton :: Bool -> BitVec
singleton b = BitVec 1 (if b then 1 else 0)

-- | @`take` n bv@ keeps the first (left-most) @n@ bits of @bv@. A
-- non-positive @n@ gives 'empty'; an @n@ of at least the length returns
-- @bv@ unchanged.
take :: Int -> BitVec -> BitVec
take n bv@(BitVec len int) =
  if n <= 0
    then empty
    else if n >= len
      then bv
      -- Discard the trailing (len - n) low-order bits.
      else BitVec n (Bits.shiftR int (len - n))

-- | @`drop` n bv@ removes the first (left-most) @n@ bits of @bv@. A
-- non-positive @n@ returns @bv@ unchanged; an @n@ of at least the length
-- gives 'empty'.
drop :: Int -> BitVec -> BitVec
drop n bv@(BitVec len int) =
  if n <= 0
    then bv
    else if n >= len
      then empty
      else BitVec remaining (int .&. lowMask)
  where
    remaining = len - n
    -- All-ones mask covering exactly the 'remaining' low-order bits.
    lowMask = Bits.bit remaining - 1

-- | @`splitAt` n bv@ splits @bv@ after its first @n@ bits; it is
-- equivalent to @(`take` n bv, `drop` n bv)@.
splitAt :: Int -> BitVec -> (BitVec, BitVec)
splitAt n bv = (front, back)
  where
    front = take n bv
    back  = drop n bv

-- | @`replicate` n b@ constructs a bit vector of length @n@ in which every
-- bit equals @b@.
--
-- A negative @n@ is treated as @0@ (like 'Data.List.replicate'). Before,
-- it produced a negative-length vector for 'False', and for 'True' it
-- asked 'Bits.bit' for a negative shift.
replicate :: Int -> Bool -> BitVec
replicate n b = BitVec len (if b then Bits.bit len - 1 else 0)
  where
    len = max 0 n

-- | Count the @0@ bits preceding the first @1@ bit, i.e. the vector's
-- length minus the 'bitLen' of its integer. An all-zero vector yields its
-- full length.
countLeadingZeros :: BitVec -> Int
countLeadingZeros (BitVec len int) = len - bitLen int

-- | Returns the bit at a given index, with @0@ being the index of the
-- most significant (left-most) bit.
--
-- No bounds check is performed; use '!?' for a checked lookup.
(!!) :: BitVec -> Int -> Bool
bv !! i = case bv of
  -- Index 0 is the highest stored bit, position (len - 1).
  BitVec len int -> Bits.testBit int (len - 1 - i)
infixl 9 !!

-- | Returns the bit at a given index, with @0@ being the index of the
-- most significant (left-most) bit, or 'Nothing' when the index lies
-- outside @[0, length)@.
(!?) :: BitVec -> Int -> Maybe Bool
(BitVec len int) !? i =
  if 0 <= i && i < len
    then Just (Bits.testBit int (len - 1 - i))
    else Nothing
infixl 9 !?

-------------
-- HELPERS --
-------------

-- | Pack a list of bits, most significant first, into a byte. Callers are
-- expected to pass exactly 8 bits; with more, only the last 8 survive the
-- modulo-256 arithmetic of 'Word8'.
pack8bits :: [Bool] -> Word8
pack8bits = L.foldl' step 0
  where
    -- Doubling shifts the accumulator left one place; adding 0/1 fills
    -- the freshly cleared low bit.
    step acc b = 2 * acc + (if b then 1 else 0)

-- | The 8 bits of a byte, most significant first.
unpack8bits :: Word8 -> [Bool]
unpack8bits w = [Bits.testBit w i | i <- [7, 6 .. 0]]

-- | The number of bits in the binary expansion of a positive integer,
-- i.e. one more than its base-2 logarithm. Leading zeros are not counted,
-- so for consistency with inductive definitions @`bitLen` 0 == 0@.
--
-- NOTE(review): intended for non-negative input only; the result for a
-- negative integer depends on 'integerLog2' and is not meaningful here.
bitLen :: Integer -> Int
bitLen n
  | n == 0    = 0
  | otherwise = 1 + fromIntegral (integerLog2 n)