module LLVM.Internal.OrcJIT.CompileLayer
  ( module LLVM.Internal.OrcJIT.CompileLayer
  , FFI.ModuleHandle
  ) where
import LLVM.Prelude
import Control.Exception
import Control.Monad.AnyCont
import Control.Monad.IO.Class
import Data.IORef
import Foreign.Ptr
import LLVM.Internal.Coding
import qualified LLVM.Internal.FFI.DataLayout as FFI
import qualified LLVM.Internal.FFI.OrcJIT as FFI
import qualified LLVM.Internal.FFI.OrcJIT.CompileLayer as FFI
import LLVM.Internal.Module hiding (getDataLayout)
import LLVM.Internal.OrcJIT
-- | Common accessors for ORC JIT compile-layer wrappers. Every operation in
-- this module is written against this class.
class CompileLayer layer where
  -- | Raw FFI handle of the compile layer itself.
  getCompileLayer :: layer -> Ptr FFI.CompileLayer
  -- | Data layout used by 'mangleSymbol' and passed to the FFI by 'addModule'.
  getDataLayout :: layer -> Ptr FFI.DataLayout
  -- | Cleanup actions accumulated for this layer. 'addModule' hands this
  -- reference to the encoded symbol resolver, and 'disposeCompileLayer' runs
  -- the actions after disposing the layer.
  getCleanups :: layer -> IORef [IO ()]
-- | Mangle a symbol name using the compile layer's data layout.
--
-- The FFI writes the mangled name into a temporary out-parameter. That
-- result is released with 'FFI.disposeMangledSymbol' only after it has been
-- decoded into a Haskell 'MangledSymbol'.
mangleSymbol :: CompileLayer l => l -> ShortByteString -> IO MangledSymbol
mangleSymbol layer name = runAnyContT mangle return
  where
    mangle = do
      outPtr <- alloca
      rawName <- encodeM name
      let acquire = FFI.getMangledSymbol outPtr rawName (getDataLayout layer)
          release _ = peek outPtr >>= FFI.disposeMangledSymbol
      -- Everything after this line runs inside the bracket, so decoding
      -- happens before the release action fires.
      anyContToM (bracket acquire release)
      peek outPtr >>= decodeM
-- | Look up a mangled symbol through the compile layer.
--
-- @exportedSymbolsOnly@ is encoded and forwarded unchanged to
-- 'FFI.findSymbol'. Presumably it restricts the search to exported symbols;
-- confirm against the FFI binding.
--
-- The raw FFI symbol is released with 'FFI.disposeSymbol' only after
-- 'decodeM' has produced the Haskell 'JITSymbol'.
findSymbol :: CompileLayer l => l -> MangledSymbol -> Bool -> IO JITSymbol
findSymbol compileLayer symbol exportedSymbolsOnly = flip runAnyContT return $ do
  symbol' <- encodeM symbol
  exportedSymbolsOnly' <- encodeM exportedSymbolsOnly
  -- Bound under a distinct name: this is the raw FFI result, not the
  -- 'MangledSymbol' parameter (which it previously shadowed).
  rawSymbol <- anyContToM $ bracket
    (FFI.findSymbol (getCompileLayer compileLayer) symbol' exportedSymbolsOnly')
    FFI.disposeSymbol
  decodeM rawSymbol
-- | Look up a mangled symbol through 'FFI.findSymbolIn', using the given
-- module handle (presumably limiting the search to that module; confirm
-- against the FFI binding).
--
-- @exportedSymbolsOnly@ is encoded and forwarded unchanged to the FFI.
-- The raw FFI symbol is released with 'FFI.disposeSymbol' only after
-- 'decodeM' has produced the Haskell 'JITSymbol'.
findSymbolIn :: CompileLayer l => l -> FFI.ModuleHandle -> MangledSymbol -> Bool -> IO JITSymbol
findSymbolIn compileLayer handle symbol exportedSymbolsOnly = flip runAnyContT return $ do
  symbol' <- encodeM symbol
  exportedSymbolsOnly' <- encodeM exportedSymbolsOnly
  -- Distinct name so the raw FFI result does not shadow the 'MangledSymbol'
  -- parameter.
  rawSymbol <- anyContToM $ bracket
    (FFI.findSymbolIn (getCompileLayer compileLayer) handle symbol' exportedSymbolsOnly')
    FFI.disposeSymbol
  decodeM rawSymbol
-- | Add a module to the compile layer and return the FFI handle for it.
--
-- The symbol resolver is encoded into an action and given the layer's
-- cleanup list ('getCleanups'). Presumably the action registers its own
-- teardown there for 'disposeCompileLayer' to run; confirm.
--
-- The raw module pointer is read, and then 'deleteModule' is called on the
-- 'Module' before the FFI call. This presumably hands ownership to the JIT,
-- so the caller should not reuse the 'Module' afterwards; confirm.
--
-- NOTE(review): the error-message out-parameter is allocated and passed to
-- the FFI but never read afterwards, so any error reported through it is
-- ignored.
addModule :: CompileLayer l => l -> Module -> SymbolResolver -> IO FFI.ModuleHandle
addModule layer llvmModule symResolver = flip runAnyContT return $ do
  mkResolver <- encodeM symResolver
  (resolverPtr, modulePtr) <- liftIO $ do
    r <- mkResolver (getCleanups layer)
    m <- readModule llvmModule
    deleteModule llvmModule
    return (r, m)
  errMsgPtr <- alloca
  liftIO $
    FFI.addModule (getCompileLayer layer) (getDataLayout layer) modulePtr resolverPtr errMsgPtr
-- | Remove a module from the compile layer, identified by the handle that
-- 'addModule' returned.
removeModule :: CompileLayer l => l -> FFI.ModuleHandle -> IO ()
removeModule = FFI.removeModule . getCompileLayer
-- | Run an action with a module temporarily added to the compile layer.
--
-- The module is added with 'addModule' before the action runs. It is
-- removed with 'removeModule' when the action finishes, including when the
-- action throws, because 'bracket' guarantees the release step.
withModule :: CompileLayer l => l -> Module -> SymbolResolver -> (FFI.ModuleHandle -> IO a) -> IO a
withModule layer llvmModule symResolver body =
  bracket acquire release body
  where
    acquire = addModule layer llvmModule symResolver
    release = removeModule layer
-- | Dispose of the compile layer's FFI object, then run every action stored
-- in 'getCleanups', left to right.
--
-- The cleanup list is not cleared afterwards.
disposeCompileLayer :: CompileLayer l => l -> IO ()
disposeCompileLayer layer = do
  FFI.disposeCompileLayer (getCompileLayer layer)
  -- The cleanups run only after the layer is gone. Presumably they free
  -- resources (e.g. resolvers) that the layer could still reference until
  -- then; confirm before reordering.
  pendingCleanups <- readIORef (getCleanups layer)
  sequence_ pendingCleanups