#if !MIN_VERSION_base(4,8,0)
#endif
module Foreign.Lua.FunctionCalling
  ( FromLuaStack (..)
  , LuaCallFunc (..)
  , ToHaskellFunction (..)
  , HaskellFunction
  , ToLuaStack (..)
  , PreCFunction
  , toHaskellFunction
  , callFunc
  , freeCFunction
  , newCFunction
  , pushHaskellFunction
  , registerHaskellFunction
  ) where
import Control.Monad (when)
import Data.ByteString.Char8 (unpack)
import Foreign.C (CInt (..))
import Foreign.Lua.Api
import Foreign.Lua.Types
import Foreign.Lua.Util (getglobal')
import Foreign.Ptr (castPtr, freeHaskellFunPtr)
import Foreign.StablePtr (deRefStablePtr, freeStablePtr, newStablePtr)
import qualified Foreign.Storable as F
-- | Raw Haskell function operating directly on a 'LuaState'; this is the
-- shape that gets wrapped into a 'CFunction' (see 'mkWrapper').
type PreCFunction = LuaState -> IO NumResults
-- | Haskell function running in the 'Lua' monad and reporting the number of
-- results it produced.
type HaskellFunction = Lua NumResults
-- | Values that can be converted into a 'HaskellFunction'. The class is a
-- helper for 'toHaskellFunction', which is the intended entry point.
class ToHaskellFunction a where
  -- | Worker behind 'toHaskellFunction'. The 'StackIndex' is the stack slot
  -- from which the next function argument is read; 'toHaskellFunction'
  -- starts the conversion at slot 1.
  toHsFun :: StackIndex -> a -> Lua NumResults
-- | A ready-made 'HaskellFunction' is used unchanged; the argument index is
-- ignored.
--
-- This instance is strictly more specific than the general @Lua a@ instance,
-- so it has to be marked as overlapping. Without the pragma, resolving the
-- constraint for a @Lua NumResults@ value is rejected because both instances
-- match.
#if MIN_VERSION_base(4,8,0)
instance {-# OVERLAPPING #-} ToHaskellFunction HaskellFunction where
#else
-- NOTE(review): compilers shipping base < 4.8 do not support the
-- per-instance pragma; they presumably need the module-wide
-- OverlappingInstances extension instead -- confirm it is enabled for that
-- configuration.
instance ToHaskellFunction HaskellFunction where
#endif
  toHsFun _ = id
-- | A Lua action without further arguments: run it, push its result onto
-- the stack and report a single return value.
instance ToLuaStack a => ToHaskellFunction (Lua a) where
  toHsFun _ action = do
    result <- action
    push result
    return 1
-- | A function taking one more argument: read that argument from stack slot
-- @narg@, apply the function, and continue converting the remainder with
-- the next slot. A failure to read the argument is rethrown with the
-- argument position prepended.
instance (FromLuaStack a, ToHaskellFunction b) =>
         ToHaskellFunction (a -> b) where
  toHsFun narg f = do
    arg <- peek narg `catchLuaError` badArgument
    toHsFun (narg + 1) (f arg)
    where
      -- Re-raise the read error, naming the offending argument.
      badArgument err = throwLuaError $
        "could not read argument "
          ++ show (fromStackIndex narg) ++ ": " ++ show err
-- | Convert a Haskell value into a 'HaskellFunction'. Arguments are read
-- starting at stack slot 1. A Lua error raised during the conversion or the
-- call is caught; its message is pushed onto the stack (with a fixed prefix)
-- and 'lerror' is invoked to signal the error.
toHaskellFunction :: ToHaskellFunction a => a -> HaskellFunction
toHaskellFunction a = catchLuaError (toHsFun 1 a) signalError
  where
    -- Push the prefixed message, then hand control to 'lerror'.
    signalError err = do
      push ("Error while calling haskell function: " ++ show err)
      fmap fromIntegral lerror
-- | Create a new 'CFunction' from a Haskell value. The function pointer is
-- produced by 'mkWrapper'; it can be released with 'freeCFunction'.
newCFunction :: ToHaskellFunction a => a -> Lua CFunction
newCFunction hsFn = liftIO (mkWrapper preCFn)
  where
    luaAction = toHaskellFunction hsFn
    -- Run the converted function against whichever state calls it.
    preCFn l = runLuaWith l luaAction
-- | Turn a 'PreCFunction' into an actual 'CFunction'. This is an FFI
-- @"wrapper"@ import, so each call allocates a fresh function pointer that
-- stays alive until it is freed with 'freeHaskellFunPtr'.
foreign import ccall "wrapper"
  mkWrapper :: PreCFunction -> IO CFunction
-- | Release a function pointer via 'freeHaskellFunPtr'; intended for
-- pointers obtained from 'newCFunction'.
freeCFunction :: CFunction -> Lua ()
freeCFunction fnPtr = liftIO (freeHaskellFunPtr fnPtr)
-- | Helper class that makes global Lua functions callable from Haskell with
-- a variable number of arguments; use 'callFunc' as the entry point.
class LuaCallFunc a where
  -- | Worker behind 'callFunc'. Takes the name of the global Lua function,
  -- an action pushing the arguments collected so far, and the number of
  -- those arguments.
  callFunc' :: String -> Lua () -> NumArgs -> a
-- | Base case: all arguments have been collected. Looks up the global
-- function, pushes the arguments, performs a protected call requesting one
-- result, and converts that result. If the call does not return 'OK', the
-- message on top of the stack is thrown as a Lua error.
instance (FromLuaStack a) => LuaCallFunc (Lua a) where
  callFunc' fnName pushArgs nargs = do
    getglobal' fnName
    pushArgs
    status <- pcall nargs 1 Nothing
    -- The protected call leaves its single result (or the error message) on
    -- top of the stack. It must be addressed relative to the top (-1):
    -- absolute slot 1 is only the same element when the stack happened to be
    -- empty before the call.
    if status /= OK
      then tostring (-1) >>= throwLuaError . unpack
      else peek (-1) <* pop 1
-- | Recursive case: accept one more argument, schedule it to be pushed
-- after the arguments collected so far, and bump the argument count.
instance (ToLuaStack a, LuaCallFunc b) => LuaCallFunc (a -> b) where
  callFunc' fnName pushEarlier argCount arg =
    callFunc' fnName pushAll (argCount + 1)
    where
      pushAll = pushEarlier >> push arg
-- | Call the global Lua function with the given name. The result type
-- determines how many Haskell arguments are accepted; collection starts
-- with no arguments pushed and an argument count of zero.
callFunc :: (LuaCallFunc a) => String -> a
callFunc fnName = callFunc' fnName nothingToPush 0
  where
    nothingToPush :: Lua ()
    nothingToPush = return ()
-- | Convert a Haskell value with 'toHaskellFunction' and push it onto the
-- Lua stack through 'pushPreCFunction'.
pushHaskellFunction :: ToHaskellFunction a => a -> Lua ()
pushHaskellFunction hsFn = pushPreCFunction preCFn
  where
    luaAction = toHaskellFunction hsFn
    -- Run the converted function against whichever state calls it.
    preCFn l = runLuaWith l luaAction
-- | Push a 'PreCFunction' onto the Lua stack as a callable userdata.
--
-- The function is stored behind a 'StablePtr' whose value is written into a
-- freshly allocated userdata. The userdata receives the
-- @HaskellImportedFunction@ metatable, whose @__call@ entry invokes the
-- stored function and whose @__gc@ entry frees the stable pointer.
pushPreCFunction :: PreCFunction -> Lua ()
pushPreCFunction f = do
  stableptr <- liftIO $ newStablePtr f
  p <- newuserdata (F.sizeOf stableptr)
  liftIO $ F.poke (castPtr p) stableptr
  v <- newmetatable "HaskellImportedFunction"
  when v $ do
    -- The metatable was newly created: fill in its two entries. After each
    -- pushcfunction the metatable sits directly below the top, hence -2.
    -- An absolute index would only hit it on an otherwise empty stack.
    pushcfunction hsmethod__gc_addr
    setfield (-2) "__gc"
    pushcfunction hsmethod__call_addr
    setfield (-2) "__call"
  -- Stack is now [.., userdata, metatable]; setmetatable pops the table and
  -- attaches it to the userdata right below it.
  setmetatable (-2)
  return ()
-- | Push the Haskell function onto the stack and bind it to the global Lua
-- variable of the given name.
registerHaskellFunction :: ToHaskellFunction a => String -> a -> Lua ()
registerHaskellFunction n f = pushHaskellFunction f >> setglobal n
-- Export the two metatable handlers as C symbols and re-import the addresses
-- of those symbols, so they can be passed to 'pushcfunction' as plain
-- 'CFunction' pointers.
foreign export ccall hsMethodGc :: PreCFunction
foreign import ccall "&hsMethodGc" hsmethod__gc_addr :: CFunction
foreign export ccall hsMethodCall :: PreCFunction
foreign import ccall "&hsMethodCall" hsmethod__call_addr :: CFunction
-- | Handler installed as the @__gc@ metatable entry: reads the userdata
-- pointer from stack slot 1, loads the 'StablePtr' stored in it and frees
-- that stable pointer. Reports zero results.
hsMethodGc :: LuaState -> IO NumResults
hsMethodGc l = do
  udPtr <- runLuaWith l (peek 1)
  F.peek (castPtr udPtr) >>= freeStablePtr
  return 0
-- | Handler installed as the @__call@ metatable entry: reads the userdata
-- pointer from stack slot 1 and removes that slot, so the remaining stack
-- holds only the call arguments. It then dereferences the stored
-- 'StablePtr' and runs the recovered function on the same Lua state.
hsMethodCall :: LuaState -> IO NumResults
hsMethodCall l = do
  udPtr <- runLuaWith l $ do
    selfPtr <- peek 1
    remove 1
    return selfPtr
  callback <- F.peek (castPtr udPtr) >>= deRefStablePtr
  callback l