-- | Module      : Data.BaseSystem
--   Description : Defines + implements the BaseSystem type-class
--   Copyright   : Zoey McBride (c) 2026
--   License     : BSD-3-Clause
--   Maintainer  : zoeymcbride@mailbox.org
--   Stability   : experimental
--
-- This file provides methods for encoding/decoding binary data to/from strings
-- of digits in some base system. Digits can represent a sum of value placements
-- (RadixSystem, e.g. binary, base10) *OR* a concatenation of fixed-width bit
-- groups (BitwiseSystem, e.g. base64, base32).
module Data.BaseSystem
  ( Encoder,
    Decoder,
    BaseSystem (encoder, decoder),
  )
where

import Data.BaseSystem.Alphabet (Alphabet, alphaRadix)
import Data.BaseSystem.Alphabet qualified as Alpha
import Data.BaseSystem.Internal
import Data.Bits ((.&.), (.<<.), (.>>.), (.|.))
import Data.ByteString (ByteString)
import Data.ByteString qualified as Bytes
import Data.Maybe (fromJust, fromMaybe)
import Data.Text qualified as Text

-- | Type-class (interface) for implementing encode/decode functionality for
-- some data structure @a@ that describes a base system (the instances in this
-- module each carry an Alphabet).
class BaseSystem a where
  -- | Builds the 'Encoder' (bytes to digit string) for the given base system.
  encoder :: a -> Encoder
  -- | Builds the 'Decoder' (digit string to bytes) for the given base system.
  decoder :: a -> Decoder

-- | Function signature for encoding some ByteString to a String of digits.
-- Encoding is a plain function: the signature has no failure channel.
type Encoder = ByteString -> String

-- | Function signature for decoding some String of digits to a ByteString.
-- The result is wrapped in 'Maybe'; presumably 'Nothing' is produced when a
-- digit in the String is not a valid Symbol of the Alphabet (confirm against
-- the decoder helpers in "Data.BaseSystem.Internal").
type Decoder = String -> Maybe ByteString

-- | Implements BaseSystem for number systems built by modular arithmetic w/ the
-- radix, where the radix is taken from the system's Alphabet ('alphaRadix').
instance BaseSystem RadixSystem where
  -- Converts the input bytes to a single number ('bytesToInteger') and
  -- repeatedly applies 'divMod' by the radix. Each remainder is one digit;
  -- digits come out least-significant first and are reversed before being
  -- joined into the result String.
  encoder :: RadixSystem -> ByteString -> String
  encoder (RadixSystem _ abc) =
    let radix = fromIntegral $ alphaRadix abc
     in concatMap Text.unpack
          . divModSymbols
          . takeWhile divModContinue
          . iterateInit (\(num, _) -> num `divMod` radix) mkDivMod
          . bytesToInteger
    where
      -- Initial (quotient, remainder) pair to iterate divMod on.
      mkDivMod numerator = (numerator, 0)
      -- Keeps iterating while there is still a quotient to divide or a
      -- remainder (digit) to emit; stops at the first (0, 0) pair.
      divModContinue (nextinput, curresult) = nextinput > 0 || curresult > 0
      -- Resolves the symbols from the remainders produced by iterating divMod.
      -- 'fromJust' crashes if a digit has no symbol, which would mean the
      -- Alphabet does not cover every value below its own radix.
      divModSymbols =
        reverse
          . fromJust
          . mapM (\(_, digit) -> Alpha.resolveSymbol abc digit)
          -- NOTE(review): replaceNull presumably substitutes this default
          -- when the list is empty (i.e. the input number is 0), so zero is
          -- encoded as the symbol for digit 0. The first component is never
          -- read, so a plain 0 is used instead of 'undefined'.
          . replaceNull [(0, 0)]

  -- Folds the digits left-to-right with @acc * radix + digit@. The step runs
  -- in 'Maybe' because resolving a symbol to its value can fail; how that
  -- failure is surfaced is up to 'binaryDecoder'. No bit finalizer is needed
  -- for radix systems, hence the 'Nothing' first argument.
  decoder :: RadixSystem -> String -> Maybe ByteString
  decoder (RadixSystem _ abc) str =
    let radix = fromIntegral $ alphaRadix abc
     in binaryDecoder Nothing (Text.pack str) $
          \curvalue symbol ->
            (\digit -> curvalue * radix + digit)
              <$> Alpha.resolveValue abc symbol

-- | Resolves symbols from Alphabet for a BitwiseSystem's encoder. Partitions
-- a ByteString into N sized bitgroups where N is the bitwidth of the Alphabet's
-- radix (@symbits@), most significant group first, and resolves each group's
-- value to its symbol.
--
-- *IMPORTANT*: this function requires the Alphabet's radix to be a power of
-- two, because the bit mask for one group is computed as @alphaRadix abc - 1@.
--
-- Calls 'error' with @"invalid group size"@ when 'fitsGroup' rejects the
-- input, and crashes via 'fromJust' if a group value has no symbol.
groupSymbols :: Alphabet -> Int -> Int -> ByteString -> [Alpha.Symbol]
groupSymbols abc symbits groupsize groupbytes =
  let -- The whole group as one Int; validated against the group size first.
      groupint =
        case bytesToInteger groupbytes of
          groupnum
            | fitsGroup groupsize groupbytes groupnum -> fromIntegral groupnum
            | otherwise -> error "invalid group size"
   in -- Crash if the implementation isn't complete
      fromJust
        -- Extract the value from the shift and resolve its symbol.
        . mapM (Alpha.resolveSymbol abc . valueExtract . nextInt groupint)
        -- Take all non-negative shifts.
        . takeWhile (>= 0)
        -- Generate a list of shift values from the # bits in groupbytes, each
        -- one symbits smaller than the previous.
        $ iterateInit shiftValue mkBitLength groupbytes
  where
    -- Shifts the group right so the wanted bitgroup sits at the LSB.
    nextInt int shift = int .>>. shift
    -- Masks out the lowest bitgroup of an Int (radix - 1 is all-ones only
    -- when the radix is a power of two).
    valueExtract int = fromIntegral $ int .&. (alphaRadix abc - 1)
    -- Finds the length in bits of a ByteString.
    mkBitLength bstr = fromIntegral $ 8 * Bytes.length bstr
    -- Gives the next (smaller) shift value in the iteration.
    shiftValue bitstotal = bitstotal - symbits

-- | Implements encoder/decoder for a BitwiseSystem: input bytes are split into
-- fixed-size byte groups, and every group is cut into @symbits@-wide bitgroups
-- that each map to one symbol of the Alphabet.
instance BaseSystem BitwiseSystem where
  -- Zero-pads the input to whole groups, encodes every group with
  -- 'groupSymbols', keeps only the symbols that carry input bits and appends
  -- the padding string (if the system has a PaddingMethod).
  --
  -- NOTE(review): for empty input this appears to emit two zero-valued
  -- symbols (minimalBytes yields two zero bytes and putsymbols is at least
  -- 2), whereas RFC 4648 encodes empty input as the empty string. Confirm
  -- this is intended.
  encoder :: BitwiseSystem -> ByteString -> String
  encoder (BitwiseSystem _ abc symbits groupbytes _ padmethod) input =
    let -- Total # of bytes from input.
        bytestotal = Bytes.length input
        -- Actual # of symbols for the # of bits in bytes: the bit count
        -- divided by symbits, rounded up (integer ceiling division).
        numsymbols = (8 * bytestotal + symbits - 1) `div` symbits
        -- Minimum # of symbols to put in the resulting string.
        putsymbols = max 2 numsymbols
        -- Result from applying the padding method to the # of bytes.
        padapply = paddingResolve <$> padmethod <*> Just bytestotal
     in (++ fromMaybe "" padapply)
          -- Drop symbols that only encode the zero bytes added for alignment.
          . take putsymbols
          . concatMap (groupString . fst)
          -- Stop at the first empty chunk, i.e. when the input is used up.
          . takeWhile (not . Bytes.null . fst)
          -- Successive (chunk, rest) pairs, each chunk groupbytes long.
          . iterateInit (\(_, rest) -> Bytes.splitAt groupbytes rest) mkSplit
          $ minimalBytes bytestotal
    where
      -- Inits the iteration for Bytes.splitAt.
      mkSplit initial = (Bytes.empty, initial)
      -- Creates a string of symbols from one chunk produced by splitAt.
      groupString =
        Text.unpack . Text.concat . groupSymbols abc symbits groupbytes
      -- Gives the minimal amount of bytes to produce a correct encoding:
      -- empty and single-byte inputs are brought up to two bytes, any other
      -- input is zero-padded to a multiple of groupbytes.
      {-# INLINE minimalBytes #-}
      minimalBytes total
        | total == 0 = Bytes.pack [0, 0]
        | total == 1 = Bytes.snoc input 0
        | bytesmodulus /= 0 = Bytes.append input $ Bytes.replicate zeros 0
        | otherwise = input
        where
          bytesmodulus = total `mod` groupbytes
          zeros = groupbytes - bytesmodulus

  -- Strips trailing padding characters, accumulates symbits bits per symbol
  -- into one number, and hands 'binaryDecoder' a finalizer that corrects the
  -- value for the symbols missing from the last (incomplete) group.
  decoder :: BitwiseSystem -> String -> Maybe ByteString
  decoder (BitwiseSystem _ abc symbits _ groupsyms pm) str =
    let -- Gives just padding char if not nothing.
        padchar = paddingChar <$> pm
        -- Creates symbol from maybe padding char.
        padsymbol = Text.singleton <$> padchar
        -- Removes the trailing padding chars from the Text of str.
        nopad = Text.dropWhileEnd (\sym -> Just sym == padchar) $ Text.pack str
        -- Gives the total available characters.
        totalsyms = Text.length nopad
        -- Finds the needed # of symbols to complete the final group. Equals
        -- groupsyms when the last group is already complete.
        needsyms = groupsyms - (totalsyms `mod` groupsyms)
        -- Gives the number of bits to correct for missing final padding chars.
        needbits = needsyms * symbits
        -- Offset to Integer to byte-align the bits in the final ByteString:
        -- needbits rounded up to a whole number of bytes, expressed in bits.
        bytealign = 8 * ((needbits + 7) `div` 8)
     in binaryDecoder (finalizeBits needsyms bytealign) nopad $
          \curvalue symbol -> do
            value <- Alpha.resolveValue abc symbol
            pure $
              if Just symbol /= padsymbol
                then (curvalue .<<. symbits) .|. value
                else curvalue .<<. symbits
    where
      -- Finds the value for the expected amount of padding regardless if the
      -- BaseSystem has a PaddingMethod or not.
      {-# INLINE finalizeBits #-}
      finalizeBits missing align
        -- If the # of Symbols in last group is the groupsyms size, pass
        | missing == groupsyms = Nothing
        -- Shift in the bits of the missing symbols to complete the group,
        -- then shift out the byte-aligned filler so only input bits remain.
        | otherwise = Just $ \finalvalue ->
            (finalvalue .<<. (symbits * missing)) .>>. align