{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
module Test.HUnit.Lang (
  Assertion,
  assertFailure,
  assertEqual,
  Result (..),
  performTestCase,
  HUnitFailure (..),
  FailureReason (..),
  formatFailureReason
) where
import           Control.DeepSeq
import           Control.Exception as E
import           Control.Monad
import           Data.List
import           Data.Typeable
import           Data.CallStack
-- | An assertion is an 'IO' action. In this module, failing assertions throw
-- an 'HUnitFailure' (see 'assertFailure' and 'assertEqual'), and an action
-- that completes normally is reported as 'Success' by 'performTestCase'.
type Assertion = IO ()
-- | The exception thrown by a failed assertion ('assertFailure',
-- 'assertEqual') and recognised by 'performTestCase' as a 'Failure'.
data HUnitFailure
  = HUnitFailure
      (Maybe SrcLoc)  -- source location of the failing assertion, if known
      FailureReason   -- why the assertion failed
  deriving (Typeable, Show, Eq)

-- | Default methods only: lets the failure be thrown with 'E.throwIO' and
-- matched by an 'E.Handler' for 'HUnitFailure'.
instance Exception HUnitFailure
-- | Why an assertion failed.
data FailureReason
  = Reason String
    -- ^ A free-form failure message, as produced by 'assertFailure'.
  | ExpectedButGot (Maybe String) String String
    -- ^ A value mismatch, as produced by 'assertEqual': an optional message
    -- prefix, then the rendered expected value, then the rendered actual value.
  deriving (Typeable, Show, Eq)
-- | The source location of the last entry of 'callStack', or 'Nothing' when
-- the call stack is empty.
--
-- Call stacks are listed most-recent call first, so the last entry is the
-- outermost call site. The fold simply overwrites its accumulator with each
-- frame's location and therefore ends on that last entry.
location :: HasCallStack => Maybe SrcLoc
location = foldl' (\_ (_, site) -> Just site) Nothing callStack
-- | Unconditionally fail: throw an 'HUnitFailure' carrying the given message
-- as a 'Reason', tagged with the caller's source location (see 'location').
--
-- The message is fully evaluated before the exception is thrown.
assertFailure ::
     HasCallStack =>
     String -- ^ A message that is displayed with the assertion failure
  -> IO a
assertFailure msg =
  let reason = Reason msg
  in rnf msg `seq` E.throwIO (HUnitFailure location reason)
-- | Assert that the actual value equals the expected one.
--
-- On a mismatch, throws an 'HUnitFailure' whose reason is 'ExpectedButGot'.
-- That reason holds the message prefix ('Nothing' when the prefix is empty)
-- and the 'show'n expected and actual values.
assertEqual :: (HasCallStack, Eq a, Show a)
                              => String -- ^ The message prefix
                              -> a      -- ^ The expected value
                              -> a      -- ^ The actual value
                              -> Assertion
assertEqual preface expected actual
  | actual == expected = return ()
  | otherwise =
      -- Fully evaluate every rendered piece before throwing, so the
      -- exception carries no unevaluated thunks.
      rnf (maybePreface, shownExpected, shownActual)
        `seq` E.throwIO (HUnitFailure location mismatch)
  where
    maybePreface = if null preface then Nothing else Just preface
    shownExpected = show expected
    shownActual = show actual
    mismatch = ExpectedButGot maybePreface shownExpected shownActual
-- | Render a 'FailureReason' as human-readable text.
--
-- A 'Reason' is returned verbatim. An 'ExpectedButGot' becomes
-- newline-separated lines: the prefix (only if present), then an
-- @expected:@ line and a @but got:@ line. There is no trailing newline.
formatFailureReason :: FailureReason -> String
formatFailureReason (Reason reason) = reason
formatFailureReason (ExpectedButGot preface expected actual) =
  intercalate "\n" (header ++ comparison)
  where
    header = maybe [] pure preface
    comparison = ["expected: " ++ expected, " but got: " ++ actual]
-- | Outcome of running a single test case with 'performTestCase'.
data Result
  = Success
    -- ^ The assertion completed without throwing.
  | Failure (Maybe SrcLoc) String
    -- ^ An assertion failed: its location (if known) and rendered reason.
  | Error (Maybe SrcLoc) String
    -- ^ Some other exception escaped: a location (if any) and its text.
  deriving (Show, Eq)
-- | Run an 'Assertion' and classify how it finished.
--
-- * 'Success' if the action returns normally.
-- * 'Failure' if it throws an 'HUnitFailure'. The failure's location is kept
--   and its reason is rendered with 'formatFailureReason'.
-- * 'Error' (with no location) for any other synchronous exception, rendered
--   with 'show'.
--
-- Asynchronous exceptions are re-thrown rather than reported as an 'Error'.
-- Otherwise a SIGINT (ctrl-c) would not stop the run, and an interruption
-- from another thread would be silently turned into a test result.
-- Matching the whole 'E.SomeAsyncException' family, not just
-- 'E.AsyncException', also covers async exceptions defined outside that type,
-- for example the one 'System.Timeout.timeout' uses.
performTestCase :: Assertion -- ^ an assertion to be made during the test case run
                -> IO Result
performTestCase action =
  (action >> return Success)
    `E.catches`
      [ E.Handler (\(HUnitFailure loc reason) ->
          return $ Failure loc (formatFailureReason reason))
        -- Propagate async exceptions; throwIO (not throw) keeps the raise
        -- ordered within IO.
      , E.Handler (\e -> E.throwIO (e :: E.SomeAsyncException))
      , E.Handler (\e -> return $ Error Nothing $ show (e :: E.SomeException))
      ]