{-# LANGUAGE TypeApplications #-}

module Streamly.External.LMDB.Tests (tests) where

import Control.Concurrent.Async (asyncBound, wait)
import Control.Exception (SomeException, onException, try)
import Control.Monad (forM_)
import Data.ByteString (ByteString, pack, unpack)
import Data.ByteString.Unsafe (unsafeUseAsCStringLen)
import Data.List (find, foldl', nubBy, sort)
import Data.Word (Word8)
import Foreign (castPtr, nullPtr, with)
import Streamly.Prelude (fromList, toList, unfold)
import Test.QuickCheck (NonEmptyList (..), Gen, choose, elements, frequency)
import Test.QuickCheck.Monadic (PropertyM, monadicIO, pick, run)
import Test.Tasty (TestTree)
import Test.Tasty.QuickCheck (arbitrary, testProperty)

import qualified Data.ByteString as B
import qualified Streamly.Prelude as S

import Streamly.External.LMDB (Mode, ReadWrite, OverwriteOptions (..), ReadDirection (..), ReadOptions (..),
    WriteOptions(..), clearDatabase, defaultReadOptions, defaultWriteOptions, readLMDB, unsafeReadLMDB, writeLMDB)
import Streamly.External.LMDB.Internal (Database (..))
import Streamly.External.LMDB.Internal.Foreign (MDB_val (..), combineOptions,
    mdb_nooverwrite, mdb_put, mdb_txn_begin, mdb_txn_commit)

-- | All test trees of this module. The database-backed tests each receive
-- the given database resource; 'testBetween' is a pure property and needs none.
tests :: IO (Database ReadWrite) -> [TestTree]
tests db = map ($ db) dbTests ++ [testBetween]
  where
    dbTests =
        [ testReadLMDB
        , testUnsafeReadLMDB
        , testWriteLMDB
        , testWriteLMDB_2
        , testWriteLMDB_3
        ]

-- | Property: after clearing the database and writing generated pairs with
-- the low-level 'writeChunk' helper (overwriting allowed), reading with
-- 'readLMDB' under randomly generated read options yields exactly the
-- expected pairs computed by 'readOptionsAndResults'.
testReadLMDB :: (Mode mode) => IO (Database mode) -> TestTree
testReadLMDB res = testProperty "readLMDB" . monadicIO $ do
    db <- run res
    pairs <- arbitraryKeyValuePairs''
    run (clearDatabase db)
    run (writeChunk db False pairs)

    -- Later writes win for repeated keys; sorting gives the order used to
    -- compute the expected read results.
    let stored = sort (removeDuplicateKeys pairs)
    (readOpts, expected) <- pick (readOptionsAndResults stored)

    actual <- run . toList $ unfold (readLMDB db readOpts) undefined
    pure (actual == expected)

-- | Like 'testReadLMDB', but exercises 'unsafeReadLMDB': instead of copying
-- keys and values out, only the length component of each key and value is
-- extracted and compared against the lengths of the expected pairs.
testUnsafeReadLMDB :: (Mode mode) => IO (Database mode) -> TestTree
testUnsafeReadLMDB res = testProperty "unsafeReadLMDB" . monadicIO $ do
    db <- run res
    pairs <- arbitraryKeyValuePairs''
    run (clearDatabase db)
    run (writeChunk db False pairs)

    let stored = sort (removeDuplicateKeys pairs)
    (readOpts, expected) <- pick (readOptionsAndResults stored)

    let expectedLengths = [ (B.length k, B.length v) | (k, v) <- expected ]
        lengthOnly = pure . snd
    actualLengths <- run . toList $ unfold (unsafeReadLMDB db readOpts lengthOnly lengthOnly) undefined

    pure (actualLengths == expectedLengths)

-- | Property: after clearing the database and writing generated pairs with
-- 'writeLMDB' (overwriting allowed, arbitrary transaction size), reading
-- everything back with 'readLMDB' (itself covered by 'testReadLMDB') yields
-- the sorted pairs with only the last value kept for each key.
testWriteLMDB :: IO (Database ReadWrite) -> TestTree
testWriteLMDB res = testProperty "writeLMDB" . monadicIO $ do
    db <- run res
    pairs <- arbitraryKeyValuePairs
    run (clearDatabase db)

    txnSize <- pick arbitrary
    let writeOpts = defaultWriteOptions { writeTransactionSize = txnSize
                                        , overwriteOptions = OverwriteAllow }
        writeFold = writeLMDB db writeOpts

    -- TODO: Run with new "bound" functionality in streamly.
    run $ wait =<< asyncBound (S.fold writeFold (fromList pairs))

    stored <- run . toList $ unfold (readLMDB db defaultReadOptions) undefined
    pure $ sort (removeDuplicateKeys pairs) == stored

-- | Property for 'writeLMDB' with 'OverwriteDisallow': an exception is raised
-- iff the generated pairs contain a duplicate key, and the database afterwards
-- holds exactly the pairs that preceded the first duplicate key.
testWriteLMDB_2 :: IO (Database ReadWrite) -> TestTree
testWriteLMDB_2 res = testProperty "writeLMDB_2" . monadicIO $ do
    db <- run res
    pairs <- arbitraryKeyValuePairs'
    run (clearDatabase db)

    txnSize <- pick arbitrary
    let writeOpts = defaultWriteOptions { writeTransactionSize = txnSize
                                        , overwriteOptions = OverwriteDisallow }
        writeFold = writeLMDB db writeOpts

    -- TODO: Run with new "bound" functionality in streamly.
    outcome <- run . try @SomeException $ wait =<< asyncBound (S.fold writeFold (fromList pairs))

    -- The write must have thrown exactly when a duplicate key was present.
    let threw = either (const True) (const False) outcome
        exceptionAsExpected = threw == hasDuplicateKeys pairs

    stored <- run . toList $ unfold (readLMDB db defaultReadOptions) undefined
    let pairsAsExpected = sort (prefixBeforeDuplicate pairs) == stored

    pure (exceptionAsExpected && pairsAsExpected)

-- | Property for 'writeLMDB' with 'OverwriteAllowSame': an exception is raised
-- iff the generated pairs contain a duplicate key with a different value, and
-- the database afterwards holds exactly the (de-duplicated) pairs that preceded
-- the first such conflicting pair.
testWriteLMDB_3 :: IO (Database ReadWrite) -> TestTree
testWriteLMDB_3 res = testProperty "writeLMDB_3" . monadicIO $ do
    db <- run res
    pairs <- arbitraryKeyValuePairs'
    run (clearDatabase db)

    txnSize <- pick arbitrary
    let writeOpts = defaultWriteOptions { writeTransactionSize = txnSize
                                        , overwriteOptions = OverwriteAllowSame }
        writeFold = writeLMDB db writeOpts

    -- TODO: Run with new "bound" functionality in streamly.
    outcome <- run . try @SomeException $ wait =<< asyncBound (S.fold writeFold (fromList pairs))

    -- The write must have thrown exactly when a conflicting duplicate was present.
    let threw = either (const True) (const False) outcome
        exceptionAsExpected = threw == hasDuplicateKeysWithDiffVals pairs

    stored <- run . toList $ unfold (readLMDB db defaultReadOptions) undefined
    let expected = sort (removeDuplicateKeys (prefixBeforeDuplicateWithDiffVal pairs))
        pairsAsExpected = expected == stored

    pure (exceptionAsExpected && pairsAsExpected)

-- | Picks an arbitrary list of key-value pairs, dropping every pair whose key
-- is empty (LMDB does not allow empty keys).
arbitraryKeyValuePairs :: PropertyM IO [(ByteString, ByteString)]
arbitraryKeyValuePairs = do
    rawPairs <- pick arbitrary
    return [ (pack keyBytes, pack valBytes)
           | (keyBytes, valBytes) <- rawPairs
           , not (null keyBytes) ]

-- A variation that makes duplicate keys more likely: with a non-empty base
-- list and a coin flip, the first pair's key is inserted once more (with
-- either the same or a fresh value) at a randomly chosen split position.
arbitraryKeyValuePairs' :: PropertyM IO [(ByteString, ByteString)]
arbitraryKeyValuePairs' = do
    base <- arbitraryKeyValuePairs
    addDuplicate <- pick arbitrary
    case base of
        (k, v) : _ | addDuplicate -> do
            keepValue <- pick arbitrary
            v' <- if keepValue then return v else pack <$> pick arbitrary
            -- The range deliberately exceeds the list bounds; 'splitAt'
            -- treats out-of-range positions as the front or the end.
            let n = length base
            pos <- pick $ choose (negate n, 2 * n)
            let (before, after) = splitAt pos base
            return $ before ++ (k, v') : after
        _ -> return base

-- A variation that makes keys sharing a prefix and differing only in the
-- number of trailing zero bytes more likely.
arbitraryKeyValuePairs'' :: PropertyM IO [(ByteString, ByteString)]
arbitraryKeyValuePairs'' = do
    base <- arbitraryKeyValuePairs
    case base of
        [] -> return base
        (k, v) : _ -> do
            -- Generator that splices in copies of the first key extended by
            -- 1, 2, ... zero bytes, all carrying the same value.
            let withZeroPaddedKeys = do
                    keepValue <- arbitrary
                    v' <- if keepValue then return v else pack <$> arbitrary
                    pos <- choose (0, length base - 1)
                    let (before, after) = splitAt pos base
                        padded = [ (k `B.append` B.replicate extra 0, v') | extra <- [1 .. pos + 1] ]
                    return $ before ++ padded ++ after
            pick $ frequency [(1, return base), (3, withZeroPaddedKeys)]

-- | Removes pairs whose key occurs again later in the list, so the last value
-- for each key is retained. The surviving pairs keep their relative order.
removeDuplicateKeys :: (Eq a) => [(a, b)] -> [(a, b)]
removeDuplicateKeys = foldr keepIfNew []
  where
    -- Walking from the right, a pair is kept only if its key has not been
    -- seen among the pairs to its right.
    keepIfNew pair@(key, _) kept
        | any ((== key) . fst) kept = kept
        | otherwise = pair : kept

-- | Whether some key occurs more than once in the given pairs.
hasDuplicateKeys :: (Eq a) => [(a, b)] -> Bool
hasDuplicateKeys pairs = length distinctByKey < length pairs
  where
    sameKey (k1, _) (k2, _) = k1 == k2
    -- 'nubBy' can only shrink the list, so a shorter result means a
    -- repeated key was dropped.
    distinctByKey = nubBy sameKey pairs

-- | Whether some key occurs more than once with differing values.
hasDuplicateKeysWithDiffVals :: (Eq a, Eq b) => [(a, b)] -> Bool
hasDuplicateKeysWithDiffVals pairs = length survivors < length pairs
  where
    conflicts (k1, v1) (k2, v2) = k1 == k2 && v1 /= v2
    -- 'nubBy' drops a pair as soon as it conflicts with an earlier kept one,
    -- so any conflict shortens the list.
    survivors = nubBy conflicts pairs

-- | The longest prefix of the list in which no key repeats: everything before
-- the first pair whose key already occurred earlier, or the whole list if
-- there is no such pair.
prefixBeforeDuplicate :: (Eq a) => [(a, b)] -> [(a, b)]
prefixBeforeDuplicate xs = maybe xs (\(i, _) -> take i xs) firstRepeat
  where
    keys = map fst xs
    -- First position whose key is already among the keys before it.
    firstRepeat = find (\(i, (k, _)) -> k `elem` take i keys) (zip [0 ..] xs)

-- | Everything before the first pair that shares its key with an earlier pair
-- but carries a different value, or the whole list if there is no such pair.
prefixBeforeDuplicateWithDiffVal :: (Eq a, Eq b) => [(a, b)] -> [(a, b)]
prefixBeforeDuplicateWithDiffVal xs = maybe xs (\(i, _) -> take i xs) firstConflict
  where
    clashesWith (k, v) (k', v') = k == k' && v /= v'
    -- First position whose pair clashes with some pair before it.
    firstConflict = find (\(i, p) -> any (clashesWith p) (take i xs)) (zip [0 ..] xs)

-- | Tries to find a byte list strictly between the first and second argument
-- under lexicographic ordering. The third argument accumulates, in reverse,
-- the prefix both inputs have been found to share so far; callers pass @[]@.
--
-- Assumes first < second; calls 'error' when that is detectably violated.
--
-- Returns 'Nothing' only when the second list is exactly the first list
-- followed by a single zero byte, which is the one case in which nothing fits
-- in between.
between :: [Word8] -> [Word8] -> [Word8] -> Maybe [Word8]
between [] [] _ = error "first = second"
between _ [] _ = error "first > second"
between [] (w:ws) commonPrefixRev
    -- second == first ++ [0]: there is no byte list strictly in between.
    | w == 0 && null ws = Nothing
    -- Otherwise first ++ [0] works: it extends first, and it is either a proper
    -- prefix of second (w == 0, ws non-empty) or smaller at this byte (w > 0).
    | otherwise = Just $ reverse (0:commonPrefixRev)
between (w1:ws1) (w2:ws2) commonPrefixRev
    | w1 == w2 = between ws1 ws2 (w1:commonPrefixRev)
    | w1 > w2 = error "first > second"
    -- The inputs diverge here with w1 < w2, so first ++ [0] lies in between.
    | otherwise = Just $ reverse commonPrefixRev ++ [w1] ++ ws1 ++ [0]

-- | Property for 'between': for two distinct byte lists, ordered so that
-- @smaller < bigger@, either a value strictly in between is returned, or
-- 'Nothing' is returned and @bigger@ is @smaller@ followed only by zero bytes.
testBetween :: TestTree
testBetween = testProperty "testBetween" $ \ws1 ws2 ->
    (ws1 == ws2) ||
        let (smaller, bigger) = if ws1 < ws2 then (ws1, ws2) else (ws2, ws1)
        in case between smaller bigger [] of
            -- Check the ordered pair rather than the raw arguments; with the
            -- raw arguments the check is vacuous whenever ws1 > ws2.
            Nothing ->
                let (sharedPart, rest) = splitAt (length smaller) bigger
                in sharedPart == smaller && all (== 0) rest
            Just betw -> smaller < betw && betw < bigger

-- | 'between' lifted to 'ByteString's: unpacks both arguments, starts with an
-- empty common prefix, and packs the result (if any).
betweenBs :: ByteString -> ByteString -> Maybe ByteString
betweenBs bs1 bs2 = pack <$> between (unpack bs1) (unpack bs2) []

-- | Key-value pairs taken to be the current contents of the database.
type PairsInDatabase = [(ByteString, ByteString)]
-- | Key-value pairs a read is expected to produce, in the order produced.
type ExpectedReadResult = [(ByteString, ByteString)]

-- | Given database pairs, randomly generates read options and the results a
-- read with those options is expected to produce.
--
-- The expected results are computed by slicing the given list, so the pairs
-- must be in the database's key order without duplicate keys (callers in this
-- file pass @sort . removeDuplicateKeys@ of what they wrote).
readOptionsAndResults :: PairsInDatabase -> Gen (ReadOptions, ExpectedReadResult)
readOptionsAndResults pairsInDb = do
    forw <- arbitrary
    readAll <- frequency [ (1, return True), (3, return False) ]
    let dir = if forw then Forward else Backward
        baseOpts = defaultReadOptions { readDirection = dir }
        startingAt key = baseOpts { readStart = Just key }
        total = length pairsInDb
    if readAll then
        return (baseOpts { readStart = Nothing }, if forw then pairsInDb else reverse pairsInDb)
    else if total == 0 then do
        -- Any non-empty start key reads nothing from an empty database.
        ws <- getNonEmpty <$> arbitrary
        return (startingAt (pack ws), [])
    else do
        -- Favor the first and last positions when there are at least three pairs.
        idx <-
            if total < 3
                then choose (0, total - 1)
                else frequency [ (1, choose (1, total - 2)), (3, elements [0, total - 1]) ]
        ord <- arbitrary @Ordering -- Proximity to the key at idx (if possible).
        let keyAt i = fst (pairsInDb !! i)
            pivot = keyAt idx
            -- A key just after the pivot: between it and its successor, or
            -- the last key extended by a zero byte when there is no successor.
            nextKey
                | idx + 1 <= total - 1 = betweenBs pivot (keyAt (idx + 1))
                | otherwise = Just $ keyAt (total - 1) `B.append` B.singleton 0
            -- A key just before the pivot: between its predecessor and it, or
            -- the single zero byte when the pivot is the first key.
            prevKey
                | idx == 0 && pivot /= B.singleton 0 = Just $ B.singleton 0 -- Keys are known to be non-empty.
                | idx == 0 = Nothing
                | otherwise = betweenBs (keyAt (idx - 1)) pivot
            uptoPivotRev = reverse (take (idx + 1) pairsInDb)
            forwEq = (startingAt pivot, drop idx pairsInDb)
            backwEq = (startingAt pivot, uptoPivotRev)
        -- When no neighboring key exists, fall back to starting at the pivot itself.
        return $ case (ord, dir) of
            (EQ, Forward) -> forwEq
            (EQ, Backward) -> backwEq
            (GT, Forward) -> maybe forwEq (\k -> (startingAt k, drop (idx + 1) pairsInDb)) nextKey
            (GT, Backward) -> maybe backwEq (\k -> (startingAt k, uptoPivotRev)) nextKey
            (LT, Forward) -> maybe forwEq (\k -> (startingAt k, drop idx pairsInDb)) prevKey
            (LT, Backward) -> maybe backwEq (\k -> (startingAt k, reverse (take idx pairsInDb))) prevKey

-- Writes the given key-value pairs to the given database in a single
-- transaction, run on a bound thread via 'asyncBound'. When the 'Bool' is
-- 'True', the puts are issued with the 'mdb_nooverwrite' flag.
writeChunk :: (Foldable t, Mode mode) => Database mode -> Bool -> t (ByteString, ByteString) -> IO ()
writeChunk (Database penv dbi) noOverwrite' keyValuePairs = asyncBound writeAll >>= wait
  where
    flags = combineOptions [mdb_nooverwrite | noOverwrite']

    -- Marshals one pair and puts it into the database within the given transaction.
    putPair ptxn (k, v) =
        marshalOut k $ \kVal -> marshalOut v $ \vVal ->
            with kVal $ \kPtr -> with vVal $ \vPtr ->
                mdb_put ptxn dbi kPtr vPtr flags

    writeAll = do
        ptxn <- mdb_txn_begin penv nullPtr 0
        -- Make sure the key-value pairs we have so far are committed.
        mapM_ (putPair ptxn) keyValuePairs `onException` mdb_txn_commit ptxn
        mdb_txn_commit ptxn

-- Runs the given action on an 'MDB_val' built from the byte string's
-- underlying pointer and length (no copy is made, per 'unsafeUseAsCStringLen').
{-# INLINE marshalOut #-}
marshalOut :: ByteString -> (MDB_val -> IO ()) -> IO ()
marshalOut bs f = unsafeUseAsCStringLen bs (f . toVal)
  where
    toVal (ptr, len) = MDB_val (fromIntegral len) (castPtr ptr)