-----------------------------------------------------------------------------
-- |
-- Module    : Data.SBV.Utils.TDiff
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Runs an IO computation printing the time it took to run it
-----------------------------------------------------------------------------

{-# OPTIONS_GHC -Wall -Werror #-}

module Data.SBV.Utils.TDiff
  ( Timing(..)
  , timeIf
  , timeIfRNF
  , showTDiff
  , getTimeStampIf
  , getElapsedTime
  )
  where

import Data.Time (getCurrentTime, diffUTCTime, NominalDiffTime, UTCTime)
import Data.IORef (IORef)

import Data.List (intercalate)

import Data.Ratio
import GHC.Real   (Ratio((:%)))

import Numeric (showFFloat)

import Control.Monad.Trans (liftIO, MonadIO)
import Control.DeepSeq (NFData(rnf))


-- | Specify how to save timing information, if at all.
data Timing
  = NoTiming                            -- ^ No timing information is requested
  | PrintTiming                         -- ^ Timing information is to be printed
  | SaveTiming (IORef NominalDiffTime)  -- ^ Timing information is to be stored in the given reference

-- | Show 'NominalDiffTime' in human readable form. 'NominalDiffTime' is
-- essentially picoseconds (10^-12 seconds). We show it so that
-- it's represented at the day:hour:minute:second.XXX granularity.
--
-- Leading fields that are zero are omitted, so @3661.25@ is rendered as
-- @1h:1m:1.250s@ and @1.5@ as @1.500s@.
--
-- The value is rounded to the nearest milli-second /before/ it is split into
-- fields. This way a fraction that rounds up carries into the seconds (and
-- further up if necessary): @5.9996@ is rendered as @6.000s@, not @5.000s@.
--
-- Negative durations are rendered as their magnitude prefixed by a minus sign.
showTDiff :: NominalDiffTime -> String
showTDiff diff
   | denom /= 1    -- Should never happen! But just in case.
   = show diff
   | True
   = sign ++ intercalate ":" fields
   where -- The duration expressed as a whole number of pico-seconds.
         total, denom :: Integer
         total :% denom = (picoFactor % 1) * toRational diff

         -- there are 10^12 pico-seconds in a second
         picoFactor :: Integer
         picoFactor = (10 :: Integer) ^ (12 :: Integer)

         -- there are 10^9 pico-seconds in a milli-second, the finest unit we display
         milliFactor :: Integer
         milliFactor = (10 :: Integer) ^ (9 :: Integer)

         -- Magnitude of the duration, rounded (half up) to a whole number of
         -- milli-seconds but still expressed in pico-seconds. Rounding here,
         -- rather than when printing the fraction, keeps the carry.
         rounded :: Integer
         rounded = ((abs total + milliFactor `div` 2) `div` milliFactor) * milliFactor

         -- Only show a sign if something non-zero is left after rounding.
         sign :: String
         sign | total < 0 && rounded /= 0 = "-"
              | True                      = ""

         -- Number of pico-seconds in a second, minute, hour and day, respectively.
         s2p, m2s, h2m, d2h :: Integer
         s2p = picoFactor
         m2s = 60 * s2p
         h2m = 60 * m2s
         d2h = 24 * h2m

         (days,    days')    = rounded  `divMod` d2h
         (hours,   hours')   = days'    `divMod` h2m
         (minutes, seconds') = hours'   `divMod` m2s
         (seconds, picos)    = seconds' `divMod` s2p

         -- 'picos' is a whole number of milli-seconds at this point, so the
         -- three-digit rendering below can never round up to "1.000"; we only
         -- keep what follows the decimal point.
         secondsPicos :: String
         secondsPicos =  show seconds
                      ++ dropWhile (/= '.') (showFFloat (Just 3) (fromIntegral picos * (10**(-12) :: Double)) "s")

         -- Days, hours and minutes; leading zero fields are dropped.
         aboveSeconds :: [String]
         aboveSeconds = map (\(t, v) -> show v ++ [t]) $ dropWhile (\p -> snd p == 0) [('d', days), ('h', hours), ('m', minutes)]

         fields :: [String]
         fields = aboveSeconds ++ [secondsPicos]

-- | Run an action and measure how long it took. We reduce the result to weak-head-normal-form,
-- so beware of the cases where the result is lazily computed; in that case we'll stop as soon as
-- the result is in WHNF, and not necessarily fully calculated.
--
-- The first argument decides whether timing is performed at all: the start time-stamp is
-- obtained via 'getTimeStampIf' and the elapsed time via 'getElapsedTime', so the first
-- component of the result is 'Nothing' when the flag is 'False'.
timeIf :: MonadIO m => Bool -> m a -> m (Maybe NominalDiffTime, a)
timeIf measureTime act = do
  mbStart <- getTimeStampIf measureTime
  r       <- act
  -- Force the result to WHNF before we look at the clock again.
  r `seq` do mbElapsed <- getElapsedTime mbStart
             pure (mbElapsed, r)

-- | Same as 'timeIf', except we fully evaluate the result, via its 'NFData' instance.
-- The evaluation happens inside the timed action, so its cost is part of the measurement.
timeIfRNF :: (NFData a, MonadIO m) => Bool -> m a -> m (Maybe NominalDiffTime, a)
timeIfRNF measureTime act = timeIf measureTime forced
  where -- The original action, followed by reducing its result to normal form.
        forced = act >>= \r -> rnf r `seq` pure r

-- | Get a time-stamp if we're asked to do so. When the argument is 'False' the clock
-- is not consulted and 'Nothing' is returned; otherwise the result of 'getCurrentTime'
-- is returned wrapped in 'Just'.
getTimeStampIf :: MonadIO m => Bool -> m (Maybe UTCTime)
getTimeStampIf measureTime
  | not measureTime = pure Nothing
  | True            = liftIO $ Just <$> getCurrentTime

-- | Get elapsed time from the given beginning time, if any. With 'Nothing' the clock is
-- not consulted and 'Nothing' is returned; with @'Just' start@ the result is the difference
-- between the current time (as reported by 'getCurrentTime') and @start@.
getElapsedTime :: MonadIO m => Maybe UTCTime -> m (Maybe NominalDiffTime)
getElapsedTime Nothing      = pure Nothing
getElapsedTime (Just start) = liftIO $ do
  end <- getCurrentTime
  pure $ Just (diffUTCTime end start)