module Data.Packed.TH.Utils (
    Tag,
    getParentTypeFromConstructorType,
    resolveAppliedType,
    getNameAndBangTypesFromCon,
    sanitizeConName,
    getBranchesTyList,
    getBranchTyList,
    typeIsFieldSize,
    getConFieldsIdxAndNeedsFS,
) where

import Control.Monad
import Data.Char
import Data.Functor
import Data.Packed.FieldSize (FieldSize)
import Data.Packed.TH.Flag
import Data.Word (Word8)
import Language.Haskell.TH

-- | Byte in a "Data.Packed" value used to identify which data constructor was serialised
type Tag = Word8

-- | Extract the result (parent) type from the type of a data constructor.
--
-- Leading @forall@s (and their contexts) and every argument of the arrow
-- chain are dropped:
--
-- @
--  forall a. a -> Tree a -> Tree a   ==>  Tree a
--  Int -> Tree Int                   ==>  Tree Int   (GADT-style result)
--  Int -> Foo                        ==>  Foo
-- @
--
-- A type that is neither quantified nor a function type is returned unchanged.
getParentTypeFromConstructorType :: Type -> Type
getParentTypeFromConstructorType (ForallT _ _ t) =
    getParentTypeFromConstructorType t
-- Plain function arrow: @arg -> rest@
getParentTypeFromConstructorType (AppT (AppT ArrowT _) rest) =
    getParentTypeFromConstructorType rest
-- Multiplicity-annotated arrow: @arg %m -> rest@
getParentTypeFromConstructorType (AppT (AppT (AppT MulArrowT _) _) rest) =
    getParentTypeFromConstructorType rest
-- Only the result type is left. Its own type arguments are NOT inspected, so
-- concrete indices (e.g. @Tree Int@) are kept intact.
getParentTypeFromConstructorType t = t

-- | From a type name, build the fully applied type together with the names
-- of its type parameters.
--
-- For @data Tree a = ...@, @resolveAppliedType ''Tree@ returns
-- @(Tree a, [a])@, where the parameter names are the ones reported by 'reify'.
--
-- Both @data@ and @newtype@ declarations are accepted. Any other kind of
-- declaration aborts the splice with an explanatory 'fail' message.
resolveAppliedType :: Name -> Q (Type, [Name])
resolveAppliedType tyName = do
    info <- reify tyName
    boundTypeVars <- case info of
        TyConI (DataD _ _ tvs _ _ _) -> return tvs
        TyConI (NewtypeD _ _ tvs _ _ _) -> return tvs
        _ ->
            fail $
                "resolveAppliedType: expected a data or newtype declaration for "
                    ++ show tyName
    -- Reuse the type variable names already bound by the source declaration
    let typeParameterNames = tyVarName <$> boundTypeVars
        -- Rebuild 'Tree a' by applying each parameter, in declaration order,
        -- to the type constructor
        sourceType = foldl AppT (ConT tyName) (VarT <$> typeParameterNames)
    return (sourceType, typeParameterNames)
  where
    -- Both binder forms carry a name; the kind (if any) is not needed here
    tyVarName (PlainTV n _) = n
    tyVarName (KindedTV n _ _) = n

-- | Get a data constructor's name and the (strictness, type) pair of each of
-- its fields, whatever syntax the constructor was declared with.
--
-- Record field names are discarded and @forall@s are looked through. For GADT
-- constructors declared with several names, only the first name is kept. Any
-- other shape (e.g. a GADT constructor with an empty name list) is an 'error'.
getNameAndBangTypesFromCon :: Con -> (Name, [BangType])
getNameAndBangTypesFromCon con = case con of
    NormalC name fields -> (name, fields)
    RecC name fields -> (name, map dropFieldName fields)
    InfixC lhs name rhs -> (name, [lhs, rhs])
    ForallC _ _ inner -> getNameAndBangTypesFromCon inner
    GadtC (name : _) fields _ -> (name, fields)
    RecGadtC (name : _) fields _ -> (name, map dropFieldName fields)
    other -> error $ "unhandled data constructor: " ++ show other
  where
    -- Turn a record field (name, bang, type) into a positional (bang, type)
    dropFieldName (_, bang, ty) = (bang, ty)

-- | Sanitize a constructor name so that it can be used as a symbol name.
--
-- Only the unqualified part of the name is used. Alphanumeric characters are
-- kept as-is, and every other character is replaced by the decimal digits of
-- its code point (e.g. @:+@ becomes @5843@).
sanitizeConName :: Name -> String
sanitizeConName = concatMap encodeChar . nameBase
  where
    encodeChar :: Char -> String
    encodeChar c
        | isAlphaNum c = [c]
        | otherwise = show (ord c)

-- | For a given type and the packing flags, gives back, for each branch (data
-- constructor, in declaration order), the list of types it is packed as.
-- Each branch's list is built by 'getBranchTyList'.
--
-- @
--  data Tree a = Leaf Int | Node (Tree a) (Tree a)
--
--  getBranchesTyList ''Tree ['InsertFieldSize']
--
--  > [[FieldSize, Int], [FieldSize, Tree a, FieldSize, Tree a]]
-- @
--
-- Both @data@ and @newtype@ declarations are accepted (a newtype has a single
-- branch). Any other kind of declaration aborts the splice with an
-- explanatory 'fail' message.
getBranchesTyList :: Name -> [PackingFlag] -> Q [[Type]]
getBranchesTyList tyName flags = do
    info <- reify tyName
    cs <- case info of
        TyConI (DataD _ _ _ _ cons _) -> return cons
        TyConI (NewtypeD _ _ _ _ con _) -> return [con]
        _ ->
            fail $
                "getBranchesTyList: expected a data or newtype declaration for "
                    ++ show tyName
    forM cs (`getBranchTyList` flags)

-- | List, in order, the types a single data constructor's fields are packed as.
-- Each field contributes its own type, preceded by the 'FieldSize' type when
-- 'getConFieldsIdxAndNeedsFS' reports that the field needs one.
getBranchTyList :: Con -> [PackingFlag] -> Q [Type]
getBranchTyList con flags =
    concat <$> mapM withFieldSize (getConFieldsIdxAndNeedsFS con flags)
  where
    withFieldSize (fieldTy, _, False) = return [fieldTy]
    withFieldSize (fieldTy, _, True) = do
        fieldSizeTy <- [t|FieldSize|]
        return [fieldSizeTy, fieldTy]

-- | For each field of a data constructor, gives back its type, its 0-based
-- position, and whether a 'FieldSize' must be written for it.
--
-- A field needs a 'FieldSize' when 'InsertFieldSize' is among the flags. The
-- one exception is the constructor's last field when 'SkipLastFieldSize' is
-- also set.
getConFieldsIdxAndNeedsFS :: Con -> [PackingFlag] -> [(Type, Int, Bool)]
getConFieldsIdxAndNeedsFS con flags =
    [ (fieldTy, idx, needsFieldSize idx)
    | (idx, fieldTy) <- zip [0 ..] fieldTypes
    ]
  where
    fieldTypes = snd <$> snd (getNameAndBangTypesFromCon con)
    lastIdx = length fieldTypes - 1
    insertFS = InsertFieldSize `elem` flags
    skipLastFS = SkipLastFieldSize `elem` flags
    -- Only the last field can be exempted, and only under SkipLastFieldSize
    needsFieldSize idx = insertFS && (not skipLastFS || idx /= lastIdx)

-- | Whether the given type is exactly the 'FieldSize' type constructor
-- (i.e. @ConT ''FieldSize@).
typeIsFieldSize :: Type -> Bool
typeIsFieldSize (ConT name) = name == ''FieldSize
typeIsFieldSize _ = False