module LLVM.Core.CodeGen(
    
    newModule, newNamedModule, defineModule, createModule,
    getModuleValues, ModuleValue, castModuleValue, setTarget,
    
    Linkage(..),
    Visibility(..),
    
    Function, newFunction, newNamedFunction, defineFunction, createFunction, createNamedFunction, setFuncCallConv,
    addAttributes,
    FFI.Attribute(..),
    externFunction, staticFunction, staticNamedFunction,
    FunctionArgs, FunctionCodeGen, FunctionResult,
    TFunction,
    
    Global, newGlobal, newNamedGlobal, defineGlobal, createGlobal, createNamedGlobal, TGlobal,
    externGlobal, staticGlobal,
    
    Value(..), ConstValue(..),
    IsConst(..), valueOf, value,
    zero, allOnes, undef,
    createString, createStringNul,
    withString, withStringNul,
    constVector, constArray, constStruct, constPackedStruct,
    constCyclicVector, constCyclicArray,
    
    BasicBlock(..), newBasicBlock, newNamedBasicBlock, defineBasicBlock, createBasicBlock, getCurrentBasicBlock,
    fromLabel, toLabel,
    
    withCurrentBuilder
    ) where
import qualified LLVM.Core.UnaryVector as UnaryVector
import qualified LLVM.Core.Util as U
import qualified LLVM.Util.Proxy as LP
import LLVM.Core.CodeGenMonad
import LLVM.Core.Type
import LLVM.Core.Data
import qualified LLVM.FFI.Core as FFI
import LLVM.FFI.Core(Linkage(..), Visibility(..))
import qualified Type.Data.Num.Decimal.Proof as DecProof
import qualified Type.Data.Num.Decimal.Number as Dec
import Type.Base.Proxy (Proxy)
import qualified Foreign.Storable as St
import Foreign.C.String (withCString)
import Foreign.StablePtr (StablePtr, castStablePtrToPtr)
import Foreign.Ptr (Ptr, minusPtr, nullPtr, FunPtr, castFunPtrToPtr)
import System.IO.Unsafe (unsafePerformIO)
import Control.Monad.IO.Class (liftIO)
import Control.Monad (liftM, when)
import qualified Data.NonEmpty as NonEmpty
import qualified Data.Foldable as Fold
import Data.Bits ((.|.))
import Data.Typeable (Typeable)
import Data.Int (Int8, Int16, Int32, Int64)
import Data.Word (Word8, Word16, Word32, Word64)
import Data.Maybe.HT (toMaybe)
import Data.Maybe (fromMaybe)
import Text.Printf (printf)
-- | Create a fresh, empty module with the default name @_module@.
newModule :: IO U.Module
newModule =
    newNamedModule "_module"

-- | Create a fresh, empty module with the given name
-- (delegates to 'U.createModule').
newNamedModule :: String       -- ^ module name
               -> IO U.Module
newNamedModule name =
    U.createModule name

-- | Run a module-level code generator against an existing module and
-- return the generator's result.
defineModule :: U.Module         -- ^ module to populate
             -> CodeGenModule a  -- ^ code generator to run
             -> IO a
defineModule m gen =
    runCodeGenModule m gen

-- | Create a new module with 'newModule' and run the given generator on it.
--
-- Only the generator's result is returned, not the module itself.
createModule :: CodeGenModule a  -- ^ code generator to run
             -> IO a
createModule gen = do
    m <- newModule
    defineModule m gen
-- | Set the target triple string of the module currently being generated.
setTarget :: String -> CodeGenModule ()
setTarget triple = do
    modul <- getModule
    liftIO $
        U.withModule modul $ \m ->
            withCString triple (FFI.setTarget m)
-- | An untyped handle to a value found in a module.
-- Use 'castModuleValue' to recover a typed 'Value'.
newtype ModuleValue = ModuleValue FFI.ValueRef
    deriving (Show, Typeable)

-- | List the named values of a module.  Each name is paired with an
-- untyped 'ModuleValue' wrapping the reference returned by
-- 'U.getModuleValues'.
getModuleValues :: U.Module -> IO [(String, ModuleValue)]
getModuleValues modul = do
    entries <- U.getModuleValues modul
    return [ (name, ModuleValue ref) | (name, ref) <- entries ]

-- | Recover a typed 'Value' from a 'ModuleValue'.
--
-- Returns 'Just' only if 'U.valueHasType' reports that the underlying
-- value has the LLVM type corresponding to @a@; otherwise 'Nothing'.
castModuleValue :: forall a . (IsType a) => ModuleValue -> Maybe (Value a)
castModuleValue (ModuleValue ref)
    | U.valueHasType ref expected = Just (Value ref)
    | otherwise                   = Nothing
  where
    expected = unsafeTypeRef (LP.Proxy :: LP.Proxy a)
-- | A typed handle to an LLVM value.  The type parameter @a@ is a phantom
-- that records the Haskell-side type of the value; at runtime this is
-- just the raw 'FFI.ValueRef'.
newtype Value a = Value { unValue :: FFI.ValueRef }
    deriving (Show, Typeable)
-- | A typed handle to an LLVM constant.  As with 'Value', @a@ is a phantom
-- type.  Use 'value' to view a constant as an ordinary 'Value'.
newtype ConstValue a = ConstValue { unConstValue :: FFI.ValueRef }
    deriving (Show, Typeable)
-- | Haskell values that can be turned into LLVM constants of the
-- corresponding type.
class IsConst a where
    constOf :: a -> ConstValue a

-- Booleans are encoded through 'constEnum', i.e. via their 'fromEnum' index.
instance IsConst Bool where
    constOf = constEnum (typeRef (LP.Proxy :: LP.Proxy Bool))

-- Fixed-width machine integers go through 'constI'.
instance IsConst Word8 where
    constOf = constI
instance IsConst Word16 where
    constOf = constI
instance IsConst Word32 where
    constOf = constI
instance IsConst Word64 where
    constOf = constI
instance IsConst Int8 where
    constOf = constI
instance IsConst Int16 where
    constOf = constI
instance IsConst Int32 where
    constOf = constI
instance IsConst Int64 where
    constOf = constI

-- Floating-point numbers go through 'constF'.
instance IsConst Float where
    constOf = constF
instance IsConst Double where
    constOf = constF

-- Arbitrary-width integers are handed to LLVM via 'constInteger' (a decimal
-- string), so the value is never narrowed to a 64-bit machine integer first.
instance (Dec.Positive n) => IsConst (WordN n) where
    constOf (WordN w) = constInteger w
instance (Dec.Positive n) => IsConst (IntN n) where
    constOf (IntN k) = constInteger k
-- | Build a pointer-typed constant from a raw Haskell pointer.
--
-- The address of @p@ becomes an integer constant of the host's pointer
-- width (32 or 64 bits).  That constant is then converted with
-- 'FFI.constIntToPtr' to the LLVM type of @proto@; only the type of
-- @proto@ matters.  Any other pointer size raises an error.
constOfPtr :: (IsType ptr) =>
    ptr -> Ptr b -> ConstValue ptr
constOfPtr proto p =
    case St.sizeOf p of
        4 -> castTo $ constOf (fromIntegral addr :: Word32)
        8 -> castTo $ constOf (fromIntegral addr :: Word64)
        _ -> error "constOf Ptr: pointer size not 4 or 8"
  where
    -- numeric address of the pointer
    addr = p `minusPtr` nullPtr
    -- reinterpret an integer constant as a constant of proto's type
    castTo :: ConstValue int -> ConstValue res
    castTo (ConstValue v) =
        unsafeConstValue $
        FFI.constIntToPtr v $ unsafeTypeRef $ LP.fromValue proto
-- | Data pointers are embedded by address via 'constOfPtr'.
instance (IsType a) => IsConst (Ptr a) where
    constOf ptr = constOfPtr ptr ptr

-- | Function pointers are cast to a data pointer only to read the address;
-- the resulting constant keeps the 'FunPtr' type.
instance (IsFunction a) => IsConst (FunPtr a) where
    constOf fptr = constOfPtr fptr $ castFunPtrToPtr fptr

-- | Stable pointers are likewise embedded by their raw address.
instance IsConst (StablePtr a) where
    constOf sptr = constOfPtr sptr $ castStablePtrToPtr sptr
-- | Vectors: each element is converted with 'constOf' and the results are
-- assembled by 'constVectorGen'.
instance (IsPrimitive a, IsConst a, Dec.Positive n) => IsConst (Vector n a) where
    constOf (Vector elems) = constVectorGen constOf elems

-- | Arrays: element-wise 'constOf', then 'constArray', which checks that the
-- number of elements matches @n@.
instance (IsConst a, IsSized a, Dec.Natural n) => IsConst (Array n a) where
    constOf (Array elems) = constArray $ map constOf elems

-- | Structs: field constants come from 'constFieldsOf'.  The flag passed to
-- 'U.constStruct' is 'False' here and 'True' for 'PackedStruct'.
instance (IsConstFields a) => IsConst (Struct a) where
    constOf (Struct fields) =
        unsafeConstValue (U.constStruct (constFieldsOf fields) False)

-- | Packed structs: like 'Struct', but with the packed flag set to 'True'.
instance (IsConstFields a) => IsConst (PackedStruct a) where
    constOf (PackedStruct fields) =
        unsafeConstValue (U.constStruct (constFieldsOf fields) True)
-- | Struct field values written as right-nested pairs terminated by @()@,
-- e.g. @(a, (b, ()))@.  Each field can itself be turned into a constant.
class IsConstFields a where
    -- | Raw constant references of all fields, in order.
    constFieldsOf :: a -> [FFI.ValueRef]

instance (IsConst a, IsConstFields as) => IsConstFields (a, as) where
    constFieldsOf (field, rest) =
        unConstValue (constOf field) : constFieldsOf rest

instance IsConstFields () where
    constFieldsOf = const []
-- | Wrap the result of an IO action that builds an LLVM constant.
--
-- The action is run through 'unsafePerformIO' when the wrapped reference
-- is first demanded.  NOTE(review): this is only sound if the action has no
-- observable side effects and may safely be run any number of times
-- (e.g. constant creation that LLVM uniques).  Confirm this for each caller.
unsafeConstValue :: IO FFI.ValueRef -> ConstValue a
unsafeConstValue =
    ConstValue . unsafePerformIO
-- | Like 'unsafeConstValue', but first obtains the LLVM type of @a@ with
-- 'typeRef' and passes it to the constant-building function @f@.
-- The same 'unsafePerformIO' caveat applies.
unsafeWithConstValue ::
    forall a.
    (IsType a) =>
    (FFI.TypeRef -> IO FFI.ValueRef) ->
    ConstValue a
unsafeWithConstValue f =
    unsafePerformIO $ fmap ConstValue $
        f =<< typeRef (LP.Proxy :: LP.Proxy a)
-- | Integer constant for an 'Enum' value, using its 'fromEnum' index.
-- The LLVM type comes from the caller-supplied action @mt@.  The
-- sign-extension flag passed to 'FFI.constInt' is 'FFI.false'.
constEnum :: (Enum a) => IO FFI.TypeRef -> a -> ConstValue a
constEnum mt i =
    unsafeConstValue $ do
        t <- mt
        FFI.constInt t (fromIntegral (fromEnum i)) FFI.false

-- | Integer constant for an arbitrary-width integer type.  The value is
-- rendered in decimal and parsed by 'FFI.constIntOfString' (radix 10), so
-- it never passes through a fixed-width machine integer.
constInteger :: (IsType (intN n)) => Integer -> ConstValue (intN n)
constInteger i =
    unsafeWithConstValue $ \typ ->
        withCString (show i) $ \digits ->
            FFI.constIntOfString typ digits 10

-- | Integer constant for a fixed-width machine integer.  The sign-extension
-- flag given to 'FFI.constInt' is taken from 'isSigned' for the type.
constI :: (IsInteger a, Integral a) => a -> ConstValue a
constI i =
    unsafeWithConstValue build
  where
    signed = FFI.consBool $ isSigned $ LP.fromValue i
    build typ = FFI.constInt typ (fromIntegral i) signed

-- | Floating-point constant.  The value is converted with 'realToFrac'
-- before being passed to 'FFI.constReal'.
constF :: (IsFloating a, Real a) => a -> ConstValue a
constF i =
    unsafeWithConstValue (\typ -> FFI.constReal typ (realToFrac i))
-- | Turn a Haskell value into an LLVM value via its constant ('constOf').
valueOf :: (IsConst a) => a -> Value a
valueOf x = value (constOf x)

-- | View a constant as an ordinary value; the underlying reference is the same.
value :: ConstValue a -> Value a
value = Value . unConstValue

-- | The null constant ('FFI.constNull') of type @a@.
zero :: forall a . (IsType a) => ConstValue a
zero = unsafeWithConstValue (\typ -> FFI.constNull typ)

-- | The all-ones constant ('FFI.constAllOnes') of integer type @a@.
allOnes :: forall a . (IsInteger a) => ConstValue a
allOnes = unsafeWithConstValue (\typ -> FFI.constAllOnes typ)

-- | The @undef@ value ('FFI.getUndef') of type @a@.
undef :: forall a . (IsType a) => ConstValue a
undef = unsafeWithConstValue (\typ -> FFI.getUndef typ)
-- | An LLVM function is represented by a value of function-pointer type.
type Function a = Value (FunPtr a)
-- | Add a function with the given linkage and name to the current module.
-- The LLVM function type is derived from the Haskell type @a@.
newNamedFunction :: forall a . (IsFunction a)
                 => Linkage
                 -> String   -- ^ function name
                 -> CodeGenModule (Function a)
newNamedFunction linkage name = do
    modul <- getModule
    liftIO $ do
        typ <- typeRef (LP.Proxy :: LP.Proxy a)
        fmap Value (U.addFunction modul linkage name typ)
-- | Like 'newNamedFunction', but with a fresh name generated by
-- @genMSym \"fun\"@.
newFunction :: forall a . (IsFunction a)
            => Linkage
            -> CodeGenModule (Function a)
newFunction linkage = do
    name <- genMSym "fun"
    newNamedFunction linkage name
-- | Define the body of an already-declared function from a
-- 'Parameterized' body.
--
-- A fresh builder is created with 'U.createBuilder'.  Inside the function,
-- a new basic block is appended and made the insertion point, and then the
-- body is run with the function's parameters.
defineFunctionParam ::
       Function f          -- ^ function to define
    -> Parameterized r f   -- ^ body, still awaiting the parameters
    -> CodeGenModule ()
defineFunctionParam fn p = do
    bld <- liftIO U.createBuilder
    runCodeGenFunction bld (unValue fn) $ do
        newBasicBlock >>= defineBasicBlock
        defineParameterized fn p
-- | Define the body of an already-declared function.  The body is a
-- curried Haskell function that receives one 'Value' per argument
-- (see 'FunctionArgs').
defineFunction :: forall f . (FunctionArgs f)
               => Function f         -- ^ function to define
               -> FunctionCodeGen f  -- ^ body
               -> CodeGenModule ()
defineFunction fn body =
    defineFunctionParam fn (paramFunc body)
-- | Declare a function with a generated name (see 'newFunction'), define
-- its body, and return it.
createFunction :: (FunctionArgs f)
               => Linkage
               -> FunctionCodeGen f  -- ^ function body
               -> CodeGenModule (Function f)
createFunction linkage body =
    newFunction linkage >>= \fn -> defineFunction fn body >> return fn
-- | Like 'createFunction', but with an explicit function name.
createNamedFunction :: (FunctionArgs f)
                    => Linkage
                    -> String             -- ^ function name
                    -> FunctionCodeGen f  -- ^ function body
                    -> CodeGenModule (Function f)
createNamedFunction linkage name body =
    newNamedFunction linkage name >>= \fn -> defineFunction fn body >> return fn
-- | Set the calling convention of a function.
setFuncCallConv :: Function a
                -> FFI.CallingConvention
                -> CodeGenModule ()
setFuncCallConv (Value f) cc =
    liftIO (FFI.setFunctionCallConv f (FFI.fromCallingConvention cc))
-- | Attach attributes to the instruction value @f@ at attribute index @i@.
-- The index is passed unchanged to 'FFI.addInstrAttribute'.
--
-- The attributes are merged into one mask with bitwise OR.  The LLVM C
-- API's attribute values are bit flags.  Summing them, as before, turns a
-- duplicated attribute into a different flag (@x + x@ sets the next bit)
-- and can carry into unrelated bits; OR is idempotent.
addAttributes :: Value a -> Int -> [FFI.Attribute] -> CodeGenFunction r ()
addAttributes (Value f) i attrs =
    liftIO $ FFI.addInstrAttribute f (fromIntegral i) mask
  where
    mask = Fold.foldl' (.|.) 0 (map FFI.fromAttribute attrs)
-- | Function types whose bodies can be written as curried Haskell
-- functions over 'Value's.
--
-- For @a -> b@, the body takes a @'Value' a@ and continues with the body
-- for @b@.  The recursion ends at @IO a@, whose body is a
-- @'CodeGenFunction' a ()@.  'FunctionResult' is the @a@ of that final
-- @IO a@.
class IsFunction f => FunctionArgs f where
    type FunctionCodeGen f :: *
    type FunctionResult  f :: *
    paramFunc :: FunctionCodeGen f -> Parameterized (FunctionResult f) f

instance (FunctionArgs b, IsFirstClass a) => FunctionArgs (a -> b) where
    type FunctionCodeGen (a -> b) = Value a -> FunctionCodeGen b
    type FunctionResult  (a -> b) = FunctionResult b
    -- bind the next parameter, then recurse on the remaining arguments
    paramFunc body = param (paramFunc . body)

instance IsFirstClass a => FunctionArgs (IO a) where
    type FunctionCodeGen (IO a) = CodeGenFunction a ()
    type FunctionResult (IO a) = a
    paramFunc = parameterized
-- | A function body that still has to be fed the function's parameters.
-- The wrapped function receives the index of the next parameter to bind
-- and the raw reference of the function being defined.
newtype
   Parameterized r f =
      Parameterized (Int -> FFI.ValueRef -> CodeGenFunction r ())
-- | A body that takes no further parameters; the index and the function
-- reference are ignored.
parameterized :: CodeGenFunction r () -> Parameterized r (IO r)
parameterized code = Parameterized (\_ _ -> code)
-- | Bind the next parameter.  Fetch parameter number @n@ of the function
-- with 'U.getParam', hand it to the continuation, and continue with index
-- @n + 1@.
param :: (Value a -> Parameterized r b) -> Parameterized r (a -> b)
param pf =
   Parameterized $ \n f ->
      let Parameterized rest = pf (Value (U.getParam f n))
      in  rest (n + 1) f
-- | Run a parameterized body against a function, starting at parameter
-- index 0.
defineParameterized :: Function f -> Parameterized r f -> CodeGenFunction r ()
defineParameterized fn (Parameterized body) = body 0 (unValue fn)
-- | A handle to an LLVM basic block (wraps the raw 'FFI.BasicBlockRef').
newtype BasicBlock = BasicBlock FFI.BasicBlockRef
    deriving (Show, Typeable)
-- | Append a new basic block (see 'newBasicBlock'), immediately make it the
-- builder's insertion point, and return it.
createBasicBlock :: CodeGenFunction r BasicBlock
createBasicBlock =
    newBasicBlock >>= \b -> defineBasicBlock b >> return b
-- | Append a new basic block with a generated name ('genFSym') to the
-- current function.  The builder is not touched.
newBasicBlock :: CodeGenFunction r BasicBlock
newBasicBlock = do
    name <- genFSym
    newNamedBasicBlock name
-- | Append a new basic block with the given name to the current function.
newNamedBasicBlock :: String -> CodeGenFunction r BasicBlock
newNamedBasicBlock name = do
    fn <- getFunction
    liftIO (fmap BasicBlock (U.appendBasicBlock fn name))
-- | Position the current builder at the end of the given block.
defineBasicBlock :: BasicBlock -> CodeGenFunction r ()
defineBasicBlock (BasicBlock block) =
    getBuilder >>= \bld -> liftIO (U.positionAtEnd bld block)
-- | The block the current builder is inserting into ('U.getInsertBlock').
getCurrentBasicBlock :: CodeGenFunction r BasicBlock
getCurrentBasicBlock = do
    bld <- getBuilder
    liftIO (fmap BasicBlock (U.getInsertBlock bld))
-- | View a basic block as a value of type 'Label'
-- ('FFI.basicBlockAsValue', run via 'unsafePerformIO').
toLabel :: BasicBlock -> Value Label
toLabel (BasicBlock bb) =
    Value $ unsafePerformIO (FFI.basicBlockAsValue bb)
-- | Inverse of 'toLabel' ('FFI.valueAsBasicBlock', run via
-- 'unsafePerformIO').
fromLabel :: Value Label -> BasicBlock
fromLabel (Value v) =
    BasicBlock $ unsafePerformIO (FFI.valueAsBasicBlock v)
-- | From inside a function body, obtain an external function declaration
-- by name.  On first use it is declared in the module with
-- 'ExternalLinkage'; later requests for the same name reuse the cached
-- reference (see 'externCore').
externFunction :: forall a r . (IsFunction a) => String -> CodeGenFunction r (Function a)
externFunction name =
    externCore name $ \nm ->
        fmap (unValue :: Function a -> FFI.ValueRef)
             (newNamedFunction ExternalLinkage nm)
-- | Like 'externFunction', but for an external global.  @isConst@ is
-- passed through to 'newNamedGlobal'.
externGlobal :: forall a r . (IsType a) => Bool -> String -> CodeGenFunction r (Global a)
externGlobal isConst name =
    externCore name $ \nm ->
        fmap (unValue :: Global a -> FFI.ValueRef)
             (newNamedGlobal isConst ExternalLinkage nm)
-- | Shared caching logic for 'externFunction' and 'externGlobal'.
--
-- Looks @name@ up in the extern table ('getExterns').  On a miss, it runs
-- the module-level declaring action, records the result with
-- 'putExterns', and returns it.
--
-- NOTE(review): the cache is keyed by name only and the result type is
-- unconstrained.  Requesting the same name at two different types (or as
-- both a function and a global) silently returns the first declaration
-- under the second type.
externCore ::
    String -> (String -> CodeGenModule FFI.ValueRef) ->
    CodeGenFunction r (Value ptr)
externCore name act = do
    es <- getExterns
    ref <- case lookup name es of
        Just f -> return f
        Nothing -> do
            f <- liftCodeGenModule (act name)
            putExterns ((name, f) : es)
            return f
    return (Value ref)
-- | 'staticNamedFunction' with an empty name.
staticFunction :: forall f r. (IsFunction f) => FunPtr f -> CodeGenFunction r (Function f)
staticFunction func = staticNamedFunction "" func
-- | Declare an external function with the given name.  Register a mapping
-- ('addFunctionMapping') from that declaration to the given Haskell
-- function pointer, and return the declaration.
staticNamedFunction :: forall f r. (IsFunction f) => String -> FunPtr f -> CodeGenFunction r (Function f)
staticNamedFunction name func = liftCodeGenModule $ do
    decl <- newNamedFunction ExternalLinkage name
    addFunctionMapping (unValue (decl :: Function f)) func
    return decl
-- | Declare an external global with an empty name.  Register a mapping
-- ('addGlobalMapping') from it to the given Haskell pointer, and return
-- the declaration.
staticGlobal :: forall a r. (IsType a) => Bool -> Ptr a -> CodeGenFunction r (Global a)
staticGlobal isConst gbl = liftCodeGenModule $ do
    decl <- newNamedGlobal isConst ExternalLinkage ""
    addGlobalMapping (unValue (decl :: Global a)) gbl
    return decl
-- | Run an IO action with the raw builder of the current function
-- (via 'U.withBuilder').
withCurrentBuilder :: (FFI.BuilderRef -> IO a) -> CodeGenFunction r a
withCurrentBuilder body =
    getBuilder >>= \bld -> liftIO (U.withBuilder bld body)
-- | A global variable is represented by a pointer to its contents.
type Global a = Value (Ptr a)
-- | Add a global variable of type @a@ with the given linkage and name to
-- the current module.  When @isConst@ is 'True' the global is marked
-- constant.  No initializer is set here; see 'defineGlobal'.
newNamedGlobal :: forall a . (IsType a)
               => Bool         -- ^ mark the global constant?
               -> Linkage      -- ^ linkage
               -> String       -- ^ name
               -> TGlobal a
newNamedGlobal isConst linkage name = do
    modul <- getModule
    liftIO $ do
        typ <- typeRef (LP.Proxy :: LP.Proxy a)
        g <- U.addGlobal modul linkage name typ
        when isConst $ FFI.setGlobalConstant g FFI.true
        return (Value g)
-- | Like 'newNamedGlobal', but with a fresh name generated by
-- @genMSym \"glb\"@.
newGlobal :: forall a . (IsType a) => Bool -> Linkage -> TGlobal a
newGlobal isConst linkage = do
    name <- genMSym "glb"
    newNamedGlobal isConst linkage name
-- | Set the initializer of a global to the given constant.
defineGlobal :: Global a -> ConstValue a -> CodeGenModule ()
defineGlobal (Value g) (ConstValue v) =
    liftIO (FFI.setInitializer g v)
-- | Create a global with a generated name (see 'newGlobal'), initialize
-- it, and return it.
createGlobal :: (IsType a) => Bool -> Linkage -> ConstValue a -> TGlobal a
createGlobal isConst linkage con =
    newGlobal isConst linkage >>= \g -> defineGlobal g con >> return g
-- | Create a global with the given name, initialize it, and return it.
createNamedGlobal :: (IsType a) => Bool -> Linkage -> String -> ConstValue a -> TGlobal a
createNamedGlobal isConst linkage name con =
    newNamedGlobal isConst linkage name >>= \g -> defineGlobal g con >> return g
-- | A module-level action that yields a function.
type TFunction a = CodeGenModule (Function a)
-- | A module-level action that yields a global.
type TGlobal a = CodeGenModule (Global a)
-- | Create an internal constant global holding the string contents
-- (built by 'U.constString').  Its array length is the 'length' of the
-- string.
--
-- NOTE(review): the array length @n@ in the result type is not tied to
-- the string's length; the caller picks it.  Prefer 'withString', which
-- reifies the real length.  Also, 'length' counts 'Char's.  If
-- 'U.constString' encodes non-ASCII characters as several bytes, the
-- sizes would disagree, so confirm its encoding.
createString :: String -> TGlobal (Array n Word8)
createString s =
    let len = length s
    in  string len (U.constString s)
-- | Like 'createString', but built with 'U.constStringNul' and sized one
-- element longer than the string, leaving room for a terminator.
createStringNul :: String -> TGlobal (Array n Word8)
createStringNul s =
    let len = length s + 1
    in  string len (U.constStringNul s)
-- | Create an internal constant string global from 'U.constString' and pass
-- it to the continuation, with its array length reified in the type.
withString ::
   String ->
   (forall n. (Dec.Natural n) => Global (Array n Word8) -> CodeGenModule a) ->
   CodeGenModule a
withString s act =
   fromMaybe (error "withString: length must always be non-negative")
      (Dec.reifyNatural (fromIntegral len) (\tn ->
         string len (U.constString s) >>= \arr -> act (fixArraySize tn arr)))
   where
      len = length s
-- | Like 'withString', but built with 'U.constStringNul' and one element
-- longer than the string.
withStringNul ::
   String ->
   (forall n. (Dec.Natural n) => Global (Array n Word8) -> CodeGenModule a) ->
   CodeGenModule a
withStringNul s act =
   fromMaybe (error "withStringNul: length must always be non-negative")
      (Dec.reifyNatural (fromIntegral len) (\tn ->
         string len (U.constStringNul s) >>= \arr -> act (fixArraySize tn arr)))
   where
      len = length s + 1
-- | Identity on the global.  It only forces the array length in the
-- global's type to match the proxy.
fixArraySize :: Proxy n -> Global (Array n a) -> Global (Array n a)
fixArraySize _ g = g
-- | Add a global of type array-of-@n@-'Word8' to the current module.  It
-- gets a fresh name (@genMSym \"str\"@), internal linkage, the constant
-- flag, and @s@ as its initializer.  The phantom @n@ in the result type is
-- not checked against the runtime length @n@.
string :: Int -> FFI.ValueRef -> TGlobal (Array n Word8)
string n s = do
    modul <- getModule
    name <- genMSym "str"
    liftIO $ do
        elemTyp <- typeRef (LP.Proxy :: LP.Proxy Word8)
        typ <- FFI.arrayType elemTyp (fromIntegral n)
        g <- U.addGlobal modul InternalLinkage name typ
        FFI.setGlobalConstant g FFI.true
        FFI.setInitializer g s
        return (Value g)
-- | Build a constant vector from a fixed-length list of element constants.
-- The list length @u@ is the unary form of @n@ ('Dec.ToUnary').
constVector ::
    forall a n u.
    (Dec.Positive n, Dec.ToUnary n ~ u,
     UnaryVector.Length (FixedList u) ~ u) =>
    UnaryVector.FixedList u (ConstValue a) ->
    ConstValue (Vector n a)
constVector xs =
    constVectorGen id xs
-- | Build a constant vector from a fixed-length list, converting each
-- element to a constant with @f@.
--
-- The list length is fixed by the type: @u@ is the unary form of @n@.
-- The elements are listed in order ('Fold.toList') and passed to
-- 'U.constVector'.
constVectorGen ::
    forall a b n u.
    (Dec.Positive n, Dec.ToUnary n ~ u) =>
    (b -> ConstValue a) ->
    UnaryVector.FixedList u b ->
    ConstValue (Vector n a)
constVectorGen f xs =
    unsafeConstValue $
    U.constVector
        -- Matching on 'DecProof.UnaryNat' presumably supplies the type-level
        -- evidence that 'UnaryVector.fromFixedList' needs for @u@; verify
        -- against Type.Data.Num.Decimal.Proof.
        (case DecProof.unaryNat :: DecProof.UnaryNat n of
             DecProof.UnaryNat ->
                 map (unConstValue . f) $
                 Fold.toList
                     (UnaryVector.fromFixedList xs :: UnaryVector.T u b))
-- | Build a constant vector of length @n@ by repeating the non-empty list
-- of element constants cyclically and taking the first @n@.
constCyclicVector ::
    forall a n.
    (Dec.Positive n) =>
    NonEmpty.T [] (ConstValue a) ->
    ConstValue (Vector n a)
constCyclicVector xs =
    unsafeConstValue $ U.constVector elems
  where
    len :: Int
    len = Dec.integralFromSingleton (Dec.singleton :: Dec.Singleton n)
    elems = take len $ map unConstValue $ NonEmpty.flatten $ NonEmpty.cycle xs
-- | Build a constant array of type @Array n a@ from a list of element
-- constants.  Raises an error, when the result is forced, if the list
-- length differs from @n@.
constArray ::
    forall a n . (IsSized a, Dec.Natural n) =>
    [ConstValue a] -> ConstValue (Array n a)
constArray xs = unsafeConstValue $
    if actual /= expected
      then error $
           printf "LLVM.constArray: number of array elements (%d) mismatches typed array length (%d)"
               actual expected
      else do
        typ <- typeRef (LP.Proxy :: LP.Proxy a)
        U.constArray typ $ map unConstValue xs
  where
    actual :: Int
    actual = length xs
    expected :: Int
    expected = Dec.integralFromSingleton (Dec.singleton :: Dec.Singleton n)
-- | Build a constant array of length @n@ by repeating the non-empty list of
-- element constants cyclically and taking the first @n@.
--
-- The result is an LLVM array constant ('U.constArray'), so it is typed as
-- an 'Array', matching 'constArray'.  Typing it as a 'Vector' would let an
-- array constant be used where LLVM expects a vector.  Use
-- 'constCyclicVector' for vectors.
constCyclicArray ::
    forall a n.
    (IsSized a, Dec.Natural n) =>
    NonEmpty.T [] (ConstValue a) ->
    ConstValue (Array n a)
constCyclicArray xs = unsafeConstValue $ do
    typ <- typeRef (LP.Proxy :: LP.Proxy a)
    U.constArray typ
        (take (Dec.integralFromSingleton (Dec.singleton :: Dec.Singleton n)) $
         map unConstValue $ NonEmpty.flatten $ NonEmpty.cycle xs)
-- | Build a constant struct from nested pairs of field constants
-- (see 'IsConstStruct').  The packed flag passed to 'U.constStruct' is
-- 'False'.
constStruct ::
    (IsConstStruct c) => c -> ConstValue (Struct (ConstStructOf c))
constStruct fields =
    unsafeConstValue (U.constStruct (constValueFieldsOf fields) False)
-- | Like 'constStruct', but with the packed flag set to 'True'.
constPackedStruct ::
    (IsConstStruct c) => c -> ConstValue (PackedStruct (ConstStructOf c))
constPackedStruct fields =
    unsafeConstValue (U.constStruct (constValueFieldsOf fields) True)
-- | Nested pairs of field constants terminated by @()@, e.g.
-- @(ConstValue a, (ConstValue b, ()))@.  'ConstStructOf' strips the
-- 'ConstValue' wrappers to give the struct's field-type list.
class IsConstStruct c where
    type ConstStructOf c :: *
    constValueFieldsOf :: c -> [FFI.ValueRef]

instance (IsConst a, IsConstStruct cs) => IsConstStruct (ConstValue a, cs) where
    type ConstStructOf (ConstValue a, cs) = (a, ConstStructOf cs)
    constValueFieldsOf (ConstValue ref, rest) = ref : constValueFieldsOf rest

instance IsConstStruct () where
    type ConstStructOf () = ()
    constValueFieldsOf = const []