-- | Since the arithmetic operations of composition might get
-- computationally expensive on very large codes, a similar interface is
-- provided here which produces and consumes bits in chunks whenever
-- spaces are about to grow beyond a certain size given in bytes, at the
-- cost of at most one bit per chunk.
--
-- While the Haskell language standard defines `Integer` as having no
-- upper bound, GHC most commonly uses the GNU Multiple Precision
-- Arithmetic Library (GMP) as a backend for it, which incurs a limit of
-- 16GiB (or a little over 17GB) on the size of `Integer` values.
module Codec.Arithmetic.Variety.Bounded
  ( encode
  , codeLen
  , decode
  ) where

import Data.Bits (Bits(bit))

import qualified Codec.Arithmetic.Variety as V
import Codec.Arithmetic.Variety.BitVec (BitVec)
import qualified Codec.Arithmetic.Variety.BitVec as BV

-- | Abort with an error message tagged with this module's name.
--
-- The given text is appended to the prefix @"Variety.Bounded: "@ and
-- passed to 'error', so the result is bottom at any type.
err :: String -> a
err msg = error ("Variety.Bounded: " ++ msg)

-- | Split a list into consecutive chunks and pair each chunk with the
-- product of the bases of its elements.
--
-- @getBase@ extracts the base of an element and @prec@ is the precision
-- in bytes. Elements are taken in order; a chunk is closed as soon as
-- multiplying in the base of the next element would push the running
-- product above the bound derived from @prec@. While the running
-- product is still @1@ the next element is accepted without checking
-- the bound, so an element whose own base exceeds the bound is still
-- placed in a chunk.
--
-- The order of elements is preserved both across and within chunks.
-- No empty chunk is produced, so an empty input yields an empty list.
--
-- Throws an error if @prec@ is negative.
groupWithinPrec :: (a -> Integer) -> Int -> [a] -> [(Integer,[a])]
groupWithinPrec getBase prec
  | prec < 0 = err "negative precision"
  -- Chunks are accumulated back to front, hence the final 'reverse' of
  -- each chunk.
  | otherwise = ffmap reverse . go 1 []
  where
    -- Largest product of bases allowed in one chunk: @2^(8*prec+1) - 1@.
    -- NOTE(review): that is one bit more than fits in @prec@ bytes;
    -- confirm the extra bit is intended.
    maxBase :: Integer
    maxBase = bit (prec * 8 + 1) - 1

    -- @go base group rest@: @base@ is the product of the bases of the
    -- current chunk and @group@ holds its elements in reverse order.
    go base group [] = filter (not . null . snd) [(base, group)]
    go 1 group (a:as) = go (getBase a) (a : group) as
    go base group (a:as)
      | base' > maxBase = (base, group) : go b [a] as
      | otherwise = go base' (a : group) as
      where
        b = getBase a
        base' = base * b
{-# INLINE groupWithinPrec #-}

-- | Given a max precision in bytes, encode a series of value-base pairs
-- into a single bit vector. Bases must be at least equal to @1@ and the
-- associated values must exist in the range @[0..base-1]@.
--
-- The pairs are split into chunks with @groupWithinPrec@, each chunk is
-- encoded on its own with @V.encode@ and the resulting bit vectors are
-- concatenated in order. Throws an error if the precision is negative.
encode :: Int -> [(Integer,Integer)] -> BitVec
encode = mconcat
         -- here 'snd' drops the chunk's base product and keeps its pairs
         . fmap (V.encode . snd)
         -- here 'snd' picks the base out of a (value, base) pair
         .: groupWithinPrec snd

-- | Return the length in bits of the code of a sequence of values,
-- given the precision in bytes and the list of bases.
--
-- The bases are split into chunks with @groupWithinPrec@ and the result
-- is the sum of @V.codeLen1@ applied to the product of the bases of
-- each chunk. Throws an error if the precision is negative.
codeLen :: Int -> [Integer] -> Int
codeLen = sum
          . fmap (V.codeLen1 . fst)
          .: groupWithinPrec id

-- | Decode a sequence of values from a bit vector given the same
-- precision and list of bases that was used to encode it.
--
-- The bases are split into chunks with @groupWithinPrec@; for each
-- chunk the required number of bits is taken off the front of the
-- vector and decoded with @V.decode@. The vector must hold exactly the
-- code of the sequence: an error is thrown if a chunk cannot be given
-- its full number of bits, or if bits are left over once every chunk
-- has been decoded. Also throws an error if the precision is negative.
decode :: Int -> [Integer] -> BitVec -> [Integer]
decode = go .: groupWithinPrec id
  where
    go :: [(Integer, [Integer])] -> BitVec -> [Integer]
    go [] bv
      | BV.null bv = []
      | otherwise = err "decode: too many bits"
    go ((base, bases) : rest) bv
      | BV.length hd /= len = err "decode: not enough bits"
      | otherwise = V.decode bases hd ++ go rest tl
      where
        -- number of bits taken for this chunk, derived from the
        -- largest value the chunk can hold (@base - 1@)
        len = BV.bitLen (base - 1)
        (hd, tl) = BV.splitAt len bv

-- | Compose a one-argument function after a two-argument function:
-- @(f .: g) x y = f (g x y)@.
(.:) :: (b -> c) -> (a1 -> a2 -> b) -> a1 -> a2 -> c
-- Only two arguments appear on the left so that the INLINE pragma
-- still fires at use sites that supply just the two functions.
f .: g = \x y -> f (g x y)
infixr 8 .:
{-# INLINE (.:) #-}

-- | Map a function over the contents of two nested functors, i.e.
-- @ffmap f = fmap (fmap f)@.
ffmap :: (Functor f, Functor g) => (a -> b) -> f (g a) -> f (g b)
ffmap f = fmap (fmap f)
{-# INLINE ffmap #-}