{-# LANGUAGE CPP #-}
#ifdef TRUSTWORTHY
{-# LANGUAGE Trustworthy #-}
#endif
#include "lens-common.h"
module Control.Lens.Internal.PrismTH
  ( makePrisms
  , makeClassyPrisms
  , makeDecPrisms
  ) where
import Control.Applicative
import Control.Lens.Fold
import Control.Lens.Getter
import Control.Lens.Internal.TH
import Control.Lens.Lens
import Control.Lens.Setter
import Control.Monad
import Data.Char (isUpper)
import Data.List
import Data.Set.Lens
import Data.Traversable
import Language.Haskell.TH
import qualified Language.Haskell.TH.Datatype as D
import Language.Haskell.TH.Lens
import qualified Data.Map as Map
import qualified Data.Set as Set
import Prelude
-- | Generate optics for every constructor of the named data type as
-- top-level definitions.
--
-- This is @makePrisms'@ with its flag set to 'True', which selects the
-- non-classy code path (no @As...@ class is produced).
makePrisms :: Name -> DecsQ
makePrisms typeName = makePrisms' True typeName
-- | Generate a class named @As@ followed by the type name, together with an
-- instance of it for the named data type.
--
-- This is @makePrisms'@ with its flag set to 'False', which selects the
-- classy code path.
makeClassyPrisms :: Name -> DecsQ
makeClassyPrisms typeName = makePrisms' False typeName
-- | Reify the named data type and hand its normalised constructors to
-- 'makeConsPrisms'.
--
-- When @normal@ is 'True' no class name is passed on (top-level
-- definitions are generated); otherwise the reified data type name is
-- passed on so that a class and an instance are generated instead.
makePrisms' :: Bool -> Name -> DecsQ
makePrisms' normal typeName = do
  info <- D.reifyDatatype typeName
  let classTarget = if normal then Nothing else Just (D.datatypeName info)
      constructors = normalizeCon <$> D.datatypeCons info
  makeConsPrisms (D.datatypeType info) constructors classTarget
-- | Like @makePrisms'@, but starting from a declaration that is already in
-- hand rather than from a name that has to be reified.
--
-- The 'Bool' selects top-level definitions ('True') or the classy
-- class-plus-instance output ('False').
makeDecPrisms :: Bool -> Dec -> DecsQ
makeDecPrisms normal dec =
  D.normalizeDec dec >>= \info ->
    makeConsPrisms
      (D.datatypeType info)
      (map normalizeCon (D.datatypeCons info))
      (if normal then Nothing else Just (D.datatypeName info))
-- | Generate the declarations for a data type, given its type, its
-- normalised constructors and, for the classy variant, the type name the
-- class is derived from.
makeConsPrisms :: Type -> [NCon] -> Maybe Name -> DecsQ
-- Special case: exactly one constructor that binds no type variables of its
-- own and carries an empty context, in non-classy mode, is handled by
-- makeConIso instead of the per-constructor path below.
makeConsPrisms t [con@(NCon _ [] [] _)] Nothing = makeConIso t con
-- Non-classy mode: one signature, one definition and whatever inlinePragma
-- yields, for every constructor.
makeConsPrisms t cons Nothing = concat <$> traverse topLevel cons
  where
  topLevel con =
    do stab <- computeOpticType t cons con
       let n          = prismName (view nconName con)
           signature  = sigD n (close (stabToType stab))
           definition = valD (varP n) (normalB (makeConOpticExp stab cons con)) []
       sequenceA (signature : definition : inlinePragma n)
-- Classy mode: the class declaration followed by its instance.
makeConsPrisms t cons (Just typeName) =
  traverse
    (\build -> build t className methodName cons)
    [makeClassyPrismClass, makeClassyPrismInstance]
  where
  typeNameBase  = nameBase typeName
  className     = mkName ("As" ++ typeNameBase)
  -- When a constructor shares its base name with the type, the method for
  -- the whole type gets a doubled prefix (see prismName') to avoid a clash.
  sameNameAsCon = typeNameBase `elem` map (nameBase . view nconName) cons
  methodName    = prismName' sameNameAsCon typeName
-- | The kind of optic generated for a constructor: 'PrismType' is produced
-- by 'computePrismType' and 'ReviewType' by 'computeReviewType'.
data OpticType
  = PrismType
  | ReviewType
-- | Everything needed to write down the type of a generated optic: a
-- context, the kind of optic, and the four types that the rest of this
-- module calls @s@, @t@, @a@ and @b@ (in that order).
data Stab = Stab
  Cxt        -- context attached to the optic type
  OpticType  -- prism or review
  Type       -- s
  Type       -- t
  Type       -- a
  Type       -- b
-- | Collapse a 'Stab' to its non-type-changing form by discarding @s@ and
-- @a@ and using @t@ and @b@ in their place.
simplifyStab :: Stab -> Stab
simplifyStab stab =
  case stab of
    Stab cx ty _ t _ b -> Stab cx ty t t b b
  
  
-- | 'True' when the optic is not type-changing, i.e. @s@ equals @t@ and
-- @a@ equals @b@.
stabSimple :: Stab -> Bool
stabSimple (Stab _ _ s t a b) = (s, a) == (t, b)
-- | Render a 'Stab' as a type.
--
-- The result is a @forall@ that binds the type variables occurring in the
-- context (in order of first occurrence) and carries that context. The
-- body applies the two-argument prism type name for a simple prism, the
-- four-argument prism type name for a type-changing prism, and the review
-- type name for a review.
stabToType :: Stab -> Type
stabToType stab@(Stab cx ty s t a b) = ForallT binders cx body
  where
  -- nub keeps the first occurrence of each variable, so the order is stable
  binders = [ PlainTV v | v <- nub (toListOf typeVars cx) ]
  body
    | ReviewType <- ty = reviewTypeName `conAppsT` [t,b]
    | stabSimple stab  = prism'TypeName `conAppsT` [t,b]
    | otherwise        = prismTypeName  `conAppsT` [s,t,a,b]
-- | Extract the 'OpticType' stored in a 'Stab'.
stabType :: Stab -> OpticType
stabType stab =
  case stab of
    Stab _ opticKind _ _ _ _ -> opticKind
-- | Work out the 'Stab' for one constructor of a data type.
--
-- A constructor that binds no type variables of its own gets a prism type,
-- computed against the constructor list with that constructor removed.
-- Any other constructor gets a review type.
computeOpticType :: Type -> [NCon] -> NCon -> Q Stab
computeOpticType t cons con
  | null (_nconVars con) = computePrismType t cx (delete con cons) con
  | otherwise            = computeReviewType t cx (view nconTypes con)
  where
  cx = view nconCxt con
-- | Build the 'Stab' for a review.
--
-- @t@ is the given outer type and @b@ is the tuple of the field types;
-- @s@ and @a@ are filled with fresh type variables.
computeReviewType :: Type -> Cxt -> [Type] -> Q Stab
computeReviewType s' cx tys =
  do s <- VarT <$> newName "s"
     a <- VarT <$> newName "a"
     b <- toTupleT (fmap return tys)
     return (Stab cx ReviewType s s' a b)
-- | Build the 'Stab' for a (possibly type-changing) prism.
--
-- Type variables of the outer type that do not occur in any constructor of
-- the supplied list are given fresh names. @s@ and @a@ are the outer type
-- and the field tuple after that renaming; @t@ and @b@ are the same two
-- types before it.
computePrismType :: Type -> Cxt -> [NCon] -> NCon -> Q Stab
computePrismType t cx cons con =
  do let fieldTys  = view nconTypes con
         changeable = setOf typeVars t `Set.difference` setOf typeVars cons
     renaming <- sequenceA (fromSet (newName . nameBase) changeable)
     b <- toTupleT (map return fieldTys)
     a <- toTupleT (map return (substTypeVars renaming fieldTys))
     return (Stab cx PrismType (substTypeVars renaming t) t a b)
-- | Compute the closed type of the optic generated for a data type with a
-- single constructor, given the outer type and the constructor's field
-- types.
--
-- Every type variable of the outer type is mapped to a fresh name with the
-- same base. When there are no such variables the two-argument form named
-- by @iso'TypeName@ is used; otherwise the four-argument form named by
-- @isoTypeName@ is used, with the renamed types in the first and third
-- positions. The result is passed through 'close'.
computeIsoType :: Type -> [Type] -> TypeQ
computeIsoType t' fields =
  do sub <- sequenceA (fromSet (newName . nameBase) (setOf typeVars t'))
     -- t and b use the original variables; s and a use the renamed ones.
     let t = return                    t'
         s = return (substTypeVars sub t')
         b = toTupleT (map return                    fields)
         a = toTupleT (map return (substTypeVars sub fields))
#ifndef HLINT
         ty | Map.null sub = appsT (conT iso'TypeName) [t,b]
            | otherwise    = appsT (conT isoTypeName) [s,t,a,b]
#endif
     -- Quantify over every type variable that occurs in the result.
     close =<< ty
-- | Build the expression for a constructor's optic: a prism expression or a
-- review expression, depending on the optic kind recorded in the 'Stab'.
makeConOpticExp :: Stab -> [NCon] -> NCon -> ExpQ
makeConOpticExp stab cons con = build (stabType stab)
  where
  build PrismType  = makeConPrismExp stab cons con
  build ReviewType = makeConReviewExp con
-- | Build the declarations for the single-constructor case: a type
-- signature, a value definition, and whatever 'inlinePragma' yields for the
-- generated name.
makeConIso :: Type -> NCon -> DecsQ
makeConIso s con = sequenceA (signature : definition : inlinePragma defName)
  where
  defName    = prismName (view nconName con)
  signature  = sigD defName (computeIsoType s (view nconTypes con))
  definition = valD (varP defName) (normalB (makeConIsoExp con)) []
-- | Build a prism expression of the shape
--
-- > prism <reviewer> <remitter>
--
-- taking the 'Stab' of the prism, all constructors of the data type, and
-- the constructor being focused on.
makeConPrismExp ::
  Stab ->
  [NCon] ->
  NCon ->
  ExpQ
makeConPrismExp stab cons con =
  varE prismValName `appE` reviewer `appE` remitter
  where
  conName  = view nconName con
  arity    = length (view nconTypes con)
  reviewer = makeReviewer conName arity
  -- A simple prism can hand back the scrutinee unchanged on a mismatch; a
  -- type-changing prism must rebuild the other constructors explicitly.
  remitter
    | stabSimple stab = makeSimpleRemitter conName (length cons) arity
    | otherwise       = makeFullRemitter cons conName
-- | Build an iso expression of the shape
--
-- > iso <remitter> <reviewer>
--
-- for the given constructor.
makeConIsoExp :: NCon -> ExpQ
makeConIsoExp con =
  varE isoValName
    `appE` makeIsoRemitter conName arity
    `appE` makeReviewer conName arity
  where
  conName = view nconName con
  arity   = length (view nconTypes con)
-- | Build a review expression: the value named by @untoValName@ applied to
-- the reviewer function for the given constructor.
makeConReviewExp :: NCon -> ExpQ
makeConReviewExp con =
  varE untoValName
    `appE` makeReviewer (view nconName con) (length (view nconTypes con))
-- | Build a lambda that takes the constructor's fields (packaged by
-- 'toTupleP') and applies the constructor to them:
--
-- > \(x1, ..., xn) -> Con x1 ... xn
makeReviewer :: Name -> Int -> ExpQ
makeReviewer conName fields =
  newNames "x" fields >>= \xs ->
    lam1E
      (toTupleP (varP <$> xs))
      (appsE1 (conE conName) (varE <$> xs))
-- | Build the matching function of a simple (non-type-changing) prism:
--
-- > \x -> case x of
-- >   Con y1 ... yn -> Right (y1, ..., yn)
-- >   _             -> Left x
--
-- The arguments are the constructor name, the total number of constructors
-- of the data type, and the number of fields of the constructor.
makeSimpleRemitter ::
  Name ->
  Int ->
  Int ->
  ExpQ
makeSimpleRemitter conName numCons fields =
  do x  <- newName "x"
     ys <- newNames "y" fields
     let hit  = match (conP conName (map varP ys))
                      (normalB (conE rightDataName `appE` toTupleE (map varE ys)))
                      []
         miss = match wildP (normalB (conE leftDataName `appE` varE x)) []
         -- The catch-all alternative is only emitted when the data type
         -- has some constructor other than the one being matched.
         alternatives
           | numCons > 1 = [hit, miss]
           | otherwise   = [hit]
     lam1E (varP x) (caseE (varE x) alternatives)
-- | Build the matching function of a type-changing prism. Every
-- constructor gets its own alternative: the target constructor yields its
-- fields in a @Right@, and every other constructor is rebuilt from its
-- fields inside a @Left@.
makeFullRemitter :: [NCon] -> Name -> ExpQ
makeFullRemitter cons target =
  do x <- newName "x"
     lam1E (varP x) (caseE (varE x) [ alternative c | c <- cons ])
  where
  alternative (NCon conName _ _ fieldTys) =
    do ys <- newNames "y" (length fieldTys)
       let rebuilt = conE conName `appsE1` map varE ys
           result
             | conName == target = conE rightDataName `appE` toTupleE (map varE ys)
             | otherwise         = conE leftDataName `appE` rebuilt
       match (conP conName (map varP ys)) (normalB result) []
-- | Build a lambda that matches directly on the constructor and returns
-- its fields (packaged by 'toTupleE'):
--
-- > \(Con x1 ... xn) -> (x1, ..., xn)
makeIsoRemitter :: Name -> Int -> ExpQ
makeIsoRemitter conName fields =
  do xs <- newNames "x" fields
     let pattern_ = conP conName (varP <$> xs)
         result   = toTupleE (varE <$> xs)
     lam1E pattern_ result
-- | Build the class declaration for the classy variant.
--
-- The class is parameterised over a fresh variable @r@ followed by the
-- type variables of the outer type; when the outer type has type variables
-- a functional dependency from @r@ to them is added. The class contains
-- one method for the outer type as a whole, plus, for each constructor, a
-- method signature and a default definition.
--
-- Arguments, in order: the outer type, the class name, the name of the
-- method for the outer type, and the constructors of the data type.
makeClassyPrismClass ::
  Type    ->
  Name    ->
  Name    ->
  [NCon]  ->
  DecQ
makeClassyPrismClass t className methodName cons =
  do r <- newName "r"
#ifndef HLINT
     let methodType = appsT (conT prism'TypeName) [varT r,return t]
#endif
     methodss <- traverse (mkMethod (VarT r)) cons'
     classD (cxt[]) className (map PlainTV (r : vs)) (fds r)
       ( sigD methodName methodType
       : map return (concat methodss)
       )
  where
  -- Signature and default definition for one constructor. Only the
  -- context, the optic kind and @b@ of the computed Stab are kept; @s@ and
  -- @t@ are both replaced by the class variable.
  mkMethod r con =
    do Stab cx o _ _ _ b <- computeOpticType t cons con
       let stab' = Stab cx o r r b b
           -- con comes from cons', so this is already the optic name
           defName = view nconName con
           -- default body: composeValName applied to the whole-type method
           -- and the per-constructor method (presumably function
           -- composition; composeValName is defined elsewhere)
           body    = appsE [varE composeValName, varE methodName, varE defName]
       sequenceA
         [ sigD defName        (return (stabToType stab'))
         , valD (varP defName) (normalB body) []
         ]
  -- the constructors with their names replaced by the generated optic names
  cons'         = map (over nconName prismName) cons
  vs            = Set.toList (setOf typeVars t)
  -- functional dependency r -> vs, omitted when there are no type variables
  fds r
    | null vs   = []
    | otherwise = [FunDep [r] vs]
-- | Build the instance of the generated class for the data type itself.
--
-- The class is applied to the outer type followed by its type variables.
-- The method for the outer type is defined as the value named by
-- @idValName@, and every constructor gets a definition built from the
-- simplified (non-type-changing) form of its 'Stab'.
--
-- Arguments, in order: the outer type, the class name, the name of the
-- method for the outer type, and the constructors of the data type.
makeClassyPrismInstance ::
  Type ->
  Name ->
  Name ->
  [NCon] ->
  DecQ
makeClassyPrismInstance s className methodName cons =
  instanceD (cxt []) (return instanceHead) (wholeTypeMethod : map conMethod cons)
  where
  tyVars          = Set.toList (setOf typeVars s)
  instanceHead    = className `conAppsT` (s : map VarT tyVars)
  wholeTypeMethod = valD (varP methodName) (normalB (varE idValName)) []
  conMethod con   =
    do stab <- computeOpticType s cons con
       valD (varP (prismName (view nconName con)))
            (normalB (makeConOpticExp (simplifyStab stab) cons con))
            []
-- | A normalised view of a data constructor, holding only what this module
-- needs. Values are built by 'normalizeCon'.
data NCon = NCon
  { -- the constructor's name
    _nconName  :: Name
    -- names of the type variables bound by the constructor itself
  , _nconVars  :: [Name]
    -- the constructor's context
  , _nconCxt   :: Cxt
    -- the types of the constructor's fields
  , _nconTypes :: [Type]
  }
  deriving (Eq)
-- | Traverse the type variables in the context and the field types of a
-- constructor. The constructor's own bound variables are added to the set
-- of excluded names before descending.
instance HasTypeVars NCon where
  typeVarsEx s f (NCon name vars ctx fields) =
    rebuild <$> typeVarsEx excluded f ctx <*> typeVarsEx excluded f fields
    where
    rebuild  = NCon name vars
    excluded = s `Set.union` Set.fromList vars
-- | Lens onto the name of a normalised constructor.
nconName :: Lens' NCon Name
nconName f con = (\name -> con { _nconName = name }) <$> f (_nconName con)
-- | Lens onto the context of a normalised constructor.
nconCxt :: Lens' NCon Cxt
nconCxt f con = (\ctx -> con { _nconCxt = ctx }) <$> f (_nconCxt con)
-- | Lens onto the field types of a normalised constructor.
nconTypes :: Lens' NCon [Type]
nconTypes f con = (\tys -> con { _nconTypes = tys }) <$> f (_nconTypes con)
-- | Convert a 'D.ConstructorInfo' into the 'NCon' representation used by
-- this module, keeping its name, the names of its bound type variables,
-- its context and its field types.
normalizeCon :: D.ConstructorInfo -> NCon
normalizeCon info =
  NCon
    { _nconName  = D.constructorName info
    , _nconVars  = map D.tvName (D.constructorVars info)
    , _nconCxt   = D.constructorContext info
    , _nconTypes = D.constructorFields info
    }
-- | Name of the optic generated for a constructor: the constructor's base
-- name with a single prefix character (see @prismName'@).
prismName :: Name -> Name
prismName n = prismName' False n
-- | Derive the name of a generated optic from a constructor or type name.
--
-- The base name is prefixed with @_@ when it starts with an upper-case
-- character and with @.@ otherwise (the operator case). When the 'Bool' is
-- 'True' the prefix character is doubled. An empty base name is reported
-- with 'error'.
prismName' ::
  Bool ->
  Name ->
  Name
prismName' sameNameAsCon n = mkName (decorate (nameBase n))
  where
  decorate :: String -> String
  decorate []         = error "prismName: empty name base?"
  decorate nb@(c : _) = replicate copies (marker c) ++ nb

  marker :: Char -> Char
  marker c
    | isUpper c = '_'
    | otherwise = '.'

  copies :: Int
  copies
    | sameNameAsCon = 2
    | otherwise     = 1
-- | Wrap a type in a @forall@ that binds every type variable found in it
-- (in 'Set' order) with an empty context.
close :: Type -> TypeQ
close t = forallT binders (cxt []) (return t)
  where
  binders = [ PlainTV v | v <- Set.toList (setOf typeVars t) ]