{-# LANGUAGE ScopedTypeVariables #-}

module Data.Packed.TH.Case (caseFName, genCase) where

import Data.Packed.Reader hiding (return)
import Data.Packed.TH.Flag
import Data.Packed.TH.Utils (Tag, getBranchesTyList, getNameAndBangTypesFromCon, resolveAppliedType, sanitizeConName)
import Language.Haskell.TH

-- | Name of the case function that 'genCase' generates for a type:
-- the prefix @case@ followed by the sanitized type name
-- (e.g. @caseTree@ for the @Tree@ type).
caseFName :: Name -> Name
caseFName = mkName . ("case" ++) . sanitizeConName

-- | Generates a function to allow pattern matching a packed data type using the data constructors
--
-- The generated function takes one 'PackedReader' continuation per data
-- constructor, in declaration order. It reads a 'Tag' from the packed value,
-- then runs the continuation whose index equals that tag on the remaining
-- bytes. Any other tag value makes the reader fail with @Bad Tag@.
--
-- Fails at compile time (in 'Q') if the given name does not refer to a type
-- declared with @data@.
--
--  __Example:__
--
-- For the 'Tree' data type, it generates the following function:
--
-- @
-- caseTree ::
--     ('Data.Packed.PackedReader' '[a] r b) ->
--     ('Data.Packed.PackedReader' '[Tree a, Tree a] r b) ->
--     'Data.Packed.PackedReader' '[Tree a] r b
-- caseTree cLeafCase cNodeCase = 'Data.Packed.Reader.mkPackedReader' $ \\packed l' -> do
--    (flag :: 'Tag', b, l) <- 'Data.Packed.Reader.runPackedReader' 'Data.Packed.reader' packed l'
--    case flag of
--        0 -> 'Data.Packed.Reader.runPackedReader' cLeafCase b l
--        1 -> 'Data.Packed.Reader.runPackedReader' cNodeCase b l
--        _ -> fail \"Bad Tag\"
-- @
genCase ::
    [PackingFlag] ->
    -- | The name of the type to generate the function for
    Name ->
    Q [Dec]
genCase flags tyName = do
    info <- reify tyName
    -- Only plain 'data' declarations are supported; report anything else
    -- with a message naming the offending type instead of an opaque
    -- pattern-match failure.
    constructors <- case info of
        TyConI (DataD _ _ _ _ cs _) -> return cs
        _ ->
            Prelude.fail $
                "genCase: "
                    ++ show tyName
                    ++ " is not a type declared with 'data'"
    packedName <- newName "packed"
    -- For each data constructor, we build the name of its continuation parameter,
    -- in declaration order.
    -- Example: cLeafCase, cNodeCase, etc.
    let casePatterns = buildCaseFunctionName <$> constructors
    body <- buildBody casePatterns packedName
    signature <- genCaseSignature flags tyName
    return
        [ signature
        , FunD
            (caseFName tyName)
            [Clause (VarP <$> casePatterns) (NormalB body) []]
        ]
  where
    -- Build the body: read the tag with 'reader', then dispatch on it.
    buildBody :: [Name] -> Name -> Q Exp
    buildBody casePatterns packedName = do
        -- Fresh names, so the generated binders can never capture or be
        -- captured by other names in the generated code.
        flagVarName <- newName "flag"
        bytes1VarName <- newName "b"
        length1VarName <- newName "l"
        caseExpression <- buildCaseExpression flagVarName casePatterns bytes1VarName length1VarName
        [|
            mkPackedReader $ \($(varP packedName)) l' -> do
                ($(varP flagVarName), $(varP bytes1VarName), $(varP length1VarName)) <- runPackedReader reader $(varE packedName) l'
                $(return caseExpression)
            |]

    -- For data constructor Leaf, the continuation parameter is named 'cLeafCase'.
    buildCaseFunctionName :: Con -> Name
    buildCaseFunctionName = conNameToCaseFunctionName . fst . getNameAndBangTypesFromCon

    conNameToCaseFunctionName :: Name -> Name
    conNameToCaseFunctionName conName = mkName $ 'c' : sanitizeConName conName ++ "Case"

    -- Build the `case <flag> :: Tag of ...` expression. Branch i runs the i-th
    -- continuation on the bytes that follow the tag. Any other tag fails.
    buildCaseExpression :: Name -> [Name] -> Name -> Name -> Q Exp
    buildCaseExpression e casePatterns bytesVarName lengthVarName =
        caseE [|$(varE e) :: Tag|] (matches ++ [fallbackMatch])
      where
        matches :: [Q Match]
        matches = zipWith buildMatch [0 ..] casePatterns

        buildMatch :: Integer -> Name -> Q Match
        buildMatch conIndex caseFuncName = do
            body <- [|runPackedReader $(varE caseFuncName) $(varE bytesVarName) $(varE lengthVarName)|]
            return $ Match (LitP $ IntegerL conIndex) (NormalB body) []

        fallbackMatch :: Q Match
        fallbackMatch = do
            fallbackBody <- [|Prelude.fail "Bad Tag"|]
            return $ Match WildP (NormalB fallbackBody) []

-- Builds the type signature of the function produced by 'genCase'.
--
-- Each entry returned by 'getBranchesTyList' becomes one continuation
-- argument of type @PackedReader '[field types] r b@. The final result
-- is a @PackedReader '[T ...] r b@ over the whole value. Here @r@ and @b@
-- are type variables created with 'newName'.
--
-- For a type 'Tree', generates the following signature
-- caseTree ::
--     ('Data.Packed.PackedReader' '[a] r b) ->
--     ('Data.Packed.PackedReader' '[Tree a, Tree a] r b) ->
--     'Data.Packed.PackedReader' '[Tree a] r b
genCaseSignature :: [PackingFlag] -> Name -> Q Dec
genCaseSignature flags tyName = do
    (sourceType, _) <- resolveAppliedType tyName
    retVar <- newName "b"
    restVar <- newName "r"
    branches <- getBranchesTyList tyName flags
    let retT = varT retVar
        restT = varT restVar
        resultT = [t|PackedReader '[$(return sourceType)] $restT $retT|]
        continuationTs = map (\fields -> readerOver fields restT retT) branches
    SigD (caseFName tyName) <$> foldr arrow resultT continuationTs
  where
    -- Function arrow between two quoted types.
    arrow :: Q Type -> Q Type -> Q Type
    arrow argT resT = [t|$argT -> $resT|]

    -- From a constructor's field types (say Leaf a), build PackedReader '[a] r b.
    readerOver :: [Type] -> Q Type -> Q Type -> Q Type
    readerOver fields restT retT =
        [t|PackedReader $(typeList fields) $restT $retT|]

    -- Promote a list of types to a type-level list: '[t1, t2, ...].
    typeList :: [Type] -> Q Type
    typeList = foldr (\ty acc -> [t|$(return ty) ': $acc|]) [t|'[]|]