{-# LANGUAGE DeriveFoldable #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.KeyedPool
    ( KeyedPool
    , createKeyedPool
    , takeKeyedPool
    , Managed
    , managedResource
    , managedReused
    , managedRelease
    , keepAlive
    , Reuse (..)
    , dummyManaged
    ) where
import Control.Concurrent (forkIOWithUnmask, threadDelay)
import Control.Concurrent.STM
import Control.Exception (mask_, catch, SomeException)
import Control.Monad (join, unless, void)
import Data.Map (Map)
import Data.Maybe (isJust)
import qualified Data.Map.Strict as Map
import Data.Time (UTCTime, getCurrentTime, addUTCTime)
import Data.IORef (IORef, newIORef, mkWeakIORef, readIORef)
import qualified Data.Foldable as F
import GHC.Conc (unsafeIOToSTM)
import System.IO.Unsafe (unsafePerformIO)
-- | A pool of idle resources grouped by key.
--
-- Bundles the user-supplied create/destroy callbacks, the two limits on
-- idle resources, and the shared mutable state of the pool.
data KeyedPool key resource = KeyedPool
    { kpCreate :: !(key -> IO resource)
    -- ^ Creates a fresh resource for a key. Used by 'takeKeyedPool' when
    -- no idle resource is stored for that key.
    , kpDestroy :: !(resource -> IO ())
    -- ^ Destroys a resource that is not going to be reused.
    , kpMaxPerKey :: !Int
    -- ^ Limit on the number of idle resources kept for a single key
    -- (handed to 'addToList' by 'putResource').
    , kpMaxTotal :: !Int
    -- ^ Limit on the number of idle resources kept across all keys
    -- (compared against the idle count by 'putResource').
    , kpVar :: !(TVar (PoolMap key resource))
    -- ^ The current pool contents, or 'PoolClosed'.
    , kpAlive :: !(IORef ())
    -- ^ The 'IORef' on which 'createKeyedPool' installs the finalizer that
    -- closes the pool. It is never read in this module; it is only held
    -- here so that it stays reachable together with the pool value.
    }
-- | The state stored in a pool's 'TVar': either closed, or open together
-- with its idle resources.
--
-- The derived 'F.Foldable' instance visits every stored resource, which is
-- how @destroyKeyedPool'@ finds everything it has to destroy.
data PoolMap key resource
    -- The pool has been shut down: nothing is handed out from it any more
    -- and resources returned to it are destroyed instead of stored.
    = PoolClosed
    | PoolOpen
        -- Total number of idle resources currently stored, across all keys
        {-# UNPACK #-} !Int
        -- Idle resources per key. A key with no idle resources is absent
        -- from the map (a PoolList is never empty).
        !(Map key (PoolList resource))
    deriving F.Foldable
-- | A non-empty list of idle resources for one key. Each element carries
-- the time at which it was put into the pool; new elements are added at the
-- front (see 'addToList').
data PoolList a
    -- The last (or only) element of the list and its timestamp
    = One a {-# UNPACK #-} !UTCTime
    | Cons
        a
        -- Number of elements in the list that starts at this cell, this
        -- cell included
        {-# UNPACK #-} !Int
        {-# UNPACK #-} !UTCTime
        !(PoolList a)
    deriving F.Foldable
-- | Flatten a 'PoolList' into @(timestamp, resource)@ pairs, in the same
-- order as the list itself (front element first).
plistToList :: PoolList a -> [(UTCTime, a)]
plistToList plist =
    case plist of
        One x stamp -> [(stamp, x)]
        Cons x _ stamp older -> (stamp, x) : plistToList older
-- | Build a 'PoolList' from @(timestamp, resource)@ pairs, keeping their
-- order. Returns 'Nothing' for an empty input, because a 'PoolList' always
-- holds at least one element.
--
-- Every 'Cons' cell receives the number of elements from that cell to the
-- end of the list, which is the same invariant 'addToList' maintains.
plistFromList :: [(UTCTime, a)] -> Maybe (PoolList a)
plistFromList [] = Nothing
plistFromList ((t0, x0) : rest0) = Just (snd (build t0 x0 rest0))
  where
    -- The head is passed separately from the tail, so the empty case cannot
    -- arise and no partial pattern (or 'error' call) is needed. The result
    -- pairs the list with its size.
    build t x [] = (1, One x t)
    build t x ((t', x') : more) =
        let (tailSize, tailList) = build t' x' more
            size = tailSize + 1
         in size `seq` (size, Cons x size t tailList)
-- | Create a new, empty 'KeyedPool' and fork its background reaper thread.
--
-- A finalizer that closes the pool and destroys its idle resources is
-- attached to an 'IORef' that is stored in the returned pool.
createKeyedPool
    :: Ord key
    => (key -> IO resource) -- ^ create a new resource for a key
    -> (resource -> IO ())
       -- ^ destroy a resource. Exceptions thrown by this action are
       -- ignored when it is run by the reaper thread or by the pool's
       -- finalizer.
    -> Int -- ^ limit on idle resources kept per key
    -> Int -- ^ limit on idle resources kept in the whole pool
    -> (SomeException -> IO ()) -- ^ called with any exception that escapes the reaper, which is then run again
    -> IO (KeyedPool key resource)
createKeyedPool create destroy maxPerKey maxTotal onReaperException = do
    var <- newTVarIO $ PoolOpen 0 Map.empty
    -- The finalizer is attached to a separate IORef rather than to var:
    -- the reaper thread forked below holds on to var, so var stays
    -- reachable for as long as that thread is running.
    alive <- newIORef ()
    void $ mkWeakIORef alive $ destroyKeyedPool' destroy var
    -- The reaper is forked only after the finalizer above has been
    -- installed, so there is no window in which the thread exists but
    -- the cleanup handler does not.
    --
    -- forkIOWithUnmask provides restore, which is applied to reap so
    -- that the reaper body starts with asynchronous exceptions unmasked,
    -- whatever the masking state of the caller is. keepRunning wraps it
    -- from the outside.
    _ <- forkIOWithUnmask $ \restore -> keepRunning $ restore $ reap destroy var
    return KeyedPool
        { kpCreate = create
        , kpDestroy = destroy
        , kpMaxPerKey = maxPerKey
        , kpMaxTotal = maxTotal
        , kpVar = var
        , kpAlive = alive
        }
  where
    -- Run the action; if it throws, report the exception through
    -- onReaperException and run the action again. The loop ends when the
    -- action returns normally, which reap does once it sees PoolClosed.
    keepRunning action =
        loop
      where
        loop = action `catch` \e -> onReaperException e >> loop
-- | Close the pool and destroy every idle resource it was holding.
--
-- The stored state is atomically replaced with 'PoolClosed'; each resource
-- found in the previous state is then destroyed, with exceptions from the
-- destroy action ignored.
destroyKeyedPool' :: (resource -> IO ())
                  -> TVar (PoolMap key resource)
                  -> IO ()
destroyKeyedPool' destroy var =
    atomically (swapTVar var PoolClosed) >>= F.mapM_ shutDown
  where
    shutDown = ignoreExceptions . destroy
-- | Body of the reaper thread: repeatedly remove idle resources that have
-- been in the pool for too long and destroy them.
--
-- Returns once the pool is found to be 'PoolClosed'. Exceptions thrown by
-- the destroy action are ignored.
reap :: forall key resource.
        Ord key
     => (resource -> IO ())
     -> TVar (PoolMap key resource)
     -> IO ()
reap destroy var =
    loop
  where
    loop = do
        -- Pause for five seconds between passes.
        threadDelay (5 * 1000 * 1000)
        -- The transaction returns the IO action to run afterwards.
        join $ atomically $ do
            m'' <- readTVar var
            case m'' of
                -- Pool closed: the returned action does nothing, which
                -- ends the loop.
                PoolClosed -> return (return ())
                PoolOpen idleCount m
                    -- Nothing is idle: block until var changes.
                    | Map.null m -> retry
                    | otherwise -> do
                        (m', toDestroy) <- findStale idleCount m
                        writeTVar var m'
                        -- Stale resources are destroyed outside the
                        -- transaction, with asynchronous exceptions
                        -- masked, before the next pass starts.
                        return $ do
                            mask_ (mapM_ (ignoreExceptions . destroy) toDestroy)
                            loop
    -- Split the pool into the state to keep and the resources to destroy.
    findStale :: Int
              -> Map key (PoolList resource)
              -> STM (PoolMap key resource, [resource])
    findStale idleCount m = do
        -- The clock is read inside the transaction so that the time is
        -- taken after any wait caused by the retry above, not before it.
        -- getCurrentTime only reads the clock, so running it again when
        -- the transaction is re-executed is harmless.
        now <- unsafeIOToSTM getCurrentTime
        -- An entry is kept while its timestamp plus 30 seconds has not
        -- yet passed.
        let isNotStale time = 30 `addUTCTime` time >= now
        -- toKeep and toDestroy are accumulated as list-building functions
        -- and applied to [] at the end.
        let findStale' toKeep toDestroy [] =
                (Map.fromList (toKeep []), toDestroy [])
            findStale' toKeep toDestroy ((key, plist):rest) =
                findStale' toKeep' toDestroy' rest
              where
                -- span keeps the leading run of fresh entries and treats
                -- everything from the first stale entry onwards as stale.
                -- This relies on each list being ordered newest first
                -- (addToList puts new entries at the front).
                (notStale, stale) = span (isNotStale . fst) $ plistToList plist
                toDestroy' = toDestroy . (map snd stale++)
                toKeep' =
                    case plistFromList notStale of
                        -- No fresh entries left: drop the key entirely.
                        Nothing -> toKeep
                        Just x -> toKeep . ((key, x):)
        let (toKeep, toDestroy) = findStale' id id (Map.toList m)
        let idleCount' = idleCount - length toDestroy
        return (PoolOpen idleCount' toKeep, toDestroy)
-- | Obtain a resource for the given key.
--
-- If the pool holds an idle resource for the key, the one at the front of
-- its list is removed and returned ('managedReused' is then 'True');
-- otherwise a new one is created with the pool's create action. The whole
-- operation runs with asynchronous exceptions masked.
--
-- The returned 'Managed' releases the resource at most once: either through
-- 'managedRelease', or, if that was never called, through a finalizer that
-- destroys the resource.
takeKeyedPool :: Ord key => KeyedPool key resource -> key -> IO (Managed resource)
takeKeyedPool kp key = mask_ $ join $ atomically $ do
    (poolAfter, recycled) <- fmap checkOut $ readTVar poolVar
    writeTVar poolVar $! poolAfter
    return (wrap recycled)
  where
    poolVar = kpVar kp

    -- Pure step: remove the front idle resource for the key, if there is
    -- one, and adjust the idle count.
    checkOut PoolClosed = (PoolClosed, Nothing)
    checkOut open@(PoolOpen idleCount byKey) =
        case Map.lookup key byKey of
            Nothing -> (open, Nothing)
            Just (One found _) -> taken found (Map.delete key byKey)
            Just (Cons found _ _ older) -> taken found (Map.insert key older byKey)
      where
        taken found remaining = (PoolOpen (idleCount - 1) remaining, Just found)

    -- IO step, run after the transaction has committed: create a resource
    -- if none was recycled and wrap it with its release machinery.
    wrap recycled = do
        resource <- maybe (kpCreate kp key) return recycled
        alive <- newIORef ()
        releasedVar <- newTVarIO False
        let release how = mask_ $ do
                -- The flag makes every release after the first a no-op.
                alreadyReleased <- atomically $ swapTVar releasedVar True
                unless alreadyReleased $
                    case how of
                        Reuse -> putResource kp key resource
                        DontReuse -> ignoreExceptions $ kpDestroy kp resource
        -- Finalizer for the case where the Managed is dropped unreleased.
        _ <- mkWeakIORef alive $ release DontReuse
        return Managed
            { _managedResource = resource
            , _managedReused = isJust recycled
            , _managedRelease = release
            , _managedAlive = alive
            }
-- | Hand a resource back to the pool.
--
-- The resource is stored as the newest idle entry for its key, stamped with
-- the current time. It is destroyed instead when the pool is closed, when
-- the pool-wide idle limit has been reached, or when the per-key limit
-- leaves no room for it (including a per-key limit of zero or less, in
-- which case nothing is ever stored for any key).
putResource :: Ord key => KeyedPool key resource -> key -> resource -> IO ()
putResource kp key resource = do
    now <- getCurrentTime
    -- The transaction decides what to do and returns the IO action (a
    -- destroy, or nothing) to run once it has committed.
    join $ atomically $ do
        (m, action) <- fmap (go now) (readTVar (kpVar kp))
        writeTVar (kpVar kp) $! m
        return action
  where
    go _ PoolClosed = (PoolClosed, kpDestroy kp resource)
    go now pc@(PoolOpen idleCount m)
        | idleCount >= kpMaxTotal kp = (pc, kpDestroy kp resource)
        | otherwise = case Map.lookup key m of
            Nothing
                -- addToList only enforces the per-key limit once a key has
                -- an entry, so a limit that allows no entries at all has
                -- to be checked here.
                | kpMaxPerKey kp <= 0 -> (pc, kpDestroy kp resource)
                | otherwise ->
                    let cnt' = idleCount + 1
                        m' = PoolOpen cnt' (Map.insert key (One resource now) m)
                     in (m', return ())
            Just l ->
                let (l', mx) = addToList now (kpMaxPerKey kp) resource l
                    -- The idle count grows only if the resource was stored.
                    cnt' = idleCount + maybe 1 (const 0) mx
                    m' = PoolOpen cnt' (Map.insert key l' m)
                 in (m', maybe (return ()) (kpDestroy kp) mx)
-- | Try to put a resource at the front of an existing 'PoolList'.
--
-- Arguments: the timestamp for the new entry, the maximum list size, the
-- resource, and the current list. On success the result is the extended
-- list and 'Nothing'; if the limit leaves no room, the result is the
-- unchanged list and the rejected resource in 'Just'.
addToList :: UTCTime -> Int -> a -> PoolList a -> (PoolList a, Maybe a)
addToList now maxCount x l
    -- The list already holds at least one element, so a limit of one or
    -- less can never admit another.
    | maxCount <= 1 = rejected
    | otherwise =
        case l of
            One{} -> (Cons x 2 now l, Nothing)
            Cons _ currCount _ _
                | maxCount > currCount -> (Cons x (currCount + 1) now l, Nothing)
                | otherwise -> rejected
  where
    rejected = (l, Just x)
-- | A resource together with the means to release it.
data Managed resource = Managed
    { _managedResource :: !resource
    -- ^ The resource itself.
    , _managedReused :: !Bool
    -- ^ 'True' when 'takeKeyedPool' found the resource idle in the pool,
    -- 'False' when it was newly created (and for 'dummyManaged').
    , _managedRelease :: !(Reuse -> IO ())
    -- ^ Release action; the argument says whether the resource may be put
    -- back into the pool.
    , _managedAlive :: !(IORef ())
    -- ^ For values made by 'takeKeyedPool', the 'IORef' that carries the
    -- finalizer which destroys an unreleased resource. Read by 'keepAlive'.
    }
-- | The resource held by a 'Managed'.
managedResource :: Managed resource -> resource
managedResource managed = _managedResource managed
-- | Whether the resource was taken from the pool ('True') rather than
-- newly created ('False').
managedReused :: Managed resource -> Bool
managedReused managed = _managedReused managed
-- | Release a 'Managed' resource, stating whether it may be reused.
managedRelease :: Managed resource -> Reuse -> IO ()
managedRelease managed how = _managedRelease managed how
-- | What to do with a resource when it is released: 'Reuse' hands it back
-- to the pool, 'DontReuse' destroys it.
data Reuse = Reuse | DontReuse
-- | Wrap a resource in a 'Managed' that is not tied to any pool: it reports
-- itself as not reused and its release action does nothing.
dummyManaged :: resource -> Managed resource
dummyManaged resource = Managed
    { _managedResource = resource
    , _managedReused = False
    , _managedRelease = \_ -> return ()
    , _managedAlive = placeholder
    }
  where
    -- unsafePerformIO is acceptable here because this IORef exists only to
    -- fill the field: no finalizer is attached to it and this module never
    -- writes to it, so neither the moment it is created nor whether it ends
    -- up shared between calls can be observed.
    placeholder = unsafePerformIO (newIORef ())
-- | Run an action and discard any exception it throws.
ignoreExceptions :: IO () -> IO ()
ignoreExceptions action = catch action swallow
  where
    swallow :: SomeException -> IO ()
    swallow _ = return ()
-- | Read the 'IORef' stored in a 'Managed' and discard its (unit) contents.
--
-- NOTE(review): judging by the name and by the finalizer 'takeKeyedPool'
-- attaches to that 'IORef', this is presumably meant to be called after the
-- last use of the resource so that the reference stays reachable until
-- then; confirm against callers.
keepAlive :: Managed resource -> IO ()
keepAlive managed = readIORef (_managedAlive managed)