-----------------------------------------------------------------------------
-- |
-- Module    : Data.SBV.Tools.KD.Utils
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Various KnuckleDragger machinery.
-----------------------------------------------------------------------------

{-# LANGUAGE DeriveAnyClass             #-}
{-# LANGUAGE DeriveGeneric              #-}
{-# LANGUAGE DerivingStrategies         #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE NamedFieldPuns             #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Data.SBV.Tools.KD.Utils (
         KD, runKD, runKDWith, Proof(..)
       , startKD, finishKD, getKDState, getKDConfig, KDState(..), KDStats(..)
       , RootOfTrust(..), calculateRootOfTrust, message, updStats
       ) where

import Control.Monad.Reader (ReaderT, runReaderT, MonadReader, ask, liftIO)
import Control.Monad.Trans  (MonadIO)

import Data.Time (NominalDiffTime)

import Data.List (intercalate, nub, sort)
import System.IO (hFlush, stdout)

import Data.SBV.Core.Data (SBool)
import Data.SBV.Core.Symbolic  (SMTConfig, KDOptions(..))
import Data.SBV.Provers.Prover (defaultSMTCfg, SMTConfig(..))

import Data.SBV.Utils.TDiff (showTDiff, timeIf)
import Control.DeepSeq (NFData(rnf))

import Data.IORef

import GHC.Generics
import Data.Dynamic

-- | Various statistics we collect while running a KD proof.
--
-- Both fields are strict. 'updStats' changes this record through @modifyIORef'@,
-- which only evaluates the record itself to weak head normal form. With lazy
-- fields, every update would leave an unevaluated thunk (such as @n + 1@) inside
-- the record. Over a long proof session these thunk chains would keep growing.
data KDStats = KDStats { noOfCheckSats :: !Int              -- ^ Number of check-sat calls recorded; reported as "Decisions" by 'runKDWith'
                       , solverElapsed :: !NominalDiffTime  -- ^ Accumulated solver time; reported as "Solver" by 'runKDWith'
                       }

-- | Extra state carried through a KD computation: a mutable statistics cell
-- and the solver configuration the run was started with.
data KDState = KDState
  { stats  :: IORef KDStats  -- ^ Statistics for the current run; modified via 'updStats'
  , config :: SMTConfig      -- ^ Solver configuration passed to 'runKDWith'
  }

-- | Monad for running KnuckleDragger proofs in.
--
-- It is a thin wrapper around a reader over 'KDState' on top of 'IO'. Every
-- instance below is inherited from that underlying 'ReaderT' via newtype deriving.
newtype KD a = KD (ReaderT KDState IO a)
  deriving newtype ( Functor
                   , Applicative
                   , Monad
                   , MonadIO
                   , MonadFail
                   , MonadReader KDState
                   )

-- | Run a KD proof with the default solver configuration, i.e. 'runKDWith'
-- applied to 'defaultSMTCfg'.
runKD :: KD a -> IO a
runKD = runKDWith defaultSMTCfg

-- | Run a KD proof using the given configuration.
--
-- A fresh, zeroed statistics record is allocated for the run. When the
-- configuration's 'measureTime' option is set, a one-line summary is emitted
-- after the proof finishes, through 'message', so quiet mode suppresses it. The
-- summary lists the time outside the solver, the solver time, the total time
-- and the number of check-sats.
runKDWith :: SMTConfig -> KD a -> IO a
runKDWith cfg (KD proof) = do
    statsRef <- newIORef KDStats { noOfCheckSats = 0, solverElapsed = 0 }
    (mbElapsed, result) <- timeIf (measureTime (kdOptions cfg))
                                  (runReaderT proof KDState { config = cfg, stats = statsRef })
    maybe (pure ()) (report statsRef) mbElapsed
    pure result
  where
    -- Print the timing summary; "SBV" is the total minus the time spent in the solver.
    report statsRef total = do
        KDStats decisions inSolver <- readIORef statsRef
        let entries = [ ("SBV",       showTDiff (total - inSolver))
                      , ("Solver",    showTDiff inSolver)
                      , ("Total",     showTDiff total)
                      , ("Decisions", show decisions)
                      ]
        message cfg $ '[' : intercalate ", " [k ++ ": " ++ v | (k, v) <- entries] ++ "]\n"

-- | Retrieve the 'KDState' carried by the current KD computation.
getKDState :: KD KDState
getKDState = KD ask

-- | Retrieve the solver configuration of the current KD computation.
getKDConfig :: KD SMTConfig
getKDConfig = fmap config getKDState

-- | Apply an update function to the statistics record held in the given 'KDState'.
--
-- The write goes through @modifyIORef'@, so the new record is evaluated to weak
-- head normal form right away instead of being stored as a thunk.
updStats :: MonadIO m => KDState -> (KDStats -> KDStats) -> m ()
updStats st upd = liftIO (modifyIORef' (stats st) upd)

-- | Print the given text to standard output unless the configuration's 'quiet'
-- option is set. No newline is appended, so include one in the text if needed.
message :: MonadIO m => SMTConfig -> String -> m ()
message cfg s
  | quiet (kdOptions cfg) = pure ()
  | otherwise             = liftIO (putStr s)

-- | Start a proof by printing its header line.
--
-- The header is indented two spaces for every name after the first in @nms@.
-- It consists of @what@, a colon, and the non-empty names joined with dots. A
-- newline is appended only when @newLine@ is set. Standard output is flushed
-- afterwards so the header appears immediately.
--
-- Returns the length of the header, excluding the optional newline, so that
-- 'finishKD' can align the result.
startKD :: SMTConfig -> Bool -> String -> [String] -> IO Int
startKD cfg newLine what nms = do
    message cfg (header ++ terminator)
    hFlush stdout
    pure (length header)
  where
    -- Nesting depth: each name beyond the first adds one indentation level.
    depth      = max 0 (length nms - 1)
    header     = replicate (2 * depth) ' ' ++ what ++ ": " ++ intercalate "." [n | n <- nms, not (null n)]
    terminator = if newLine then "\n" else ""

-- | Finish a proof by printing the result text @what@, aligned to the configured
-- ribbon length and followed by a newline.
--
-- The third argument pairs the column reached by 'startKD' (its return value)
-- with an optional elapsed time. That time is shown after a space, in square
-- brackets. Each entry of the final list is an extra timing, also printed in
-- square brackets.
finishKD :: SMTConfig -> String -> (Int, Maybe NominalDiffTime) -> [NominalDiffTime] -> IO ()
finishKD cfg@SMTConfig{kdOptions = KDOptions{ribbonLength}} what (skip, mbT) extraTiming =
   message cfg $ padding ++ what ++ timing ++ extras ++ "\n"
 where -- Pad out to the ribbon column, but always leave at least one space. Without
       -- that, a header at or past the ribbon length would be glued to the result text.
       padding    = replicate (max 1 (ribbonLength - skip)) ' '
       timing     = maybe "" ((' ' :) . mkTiming) mbT
       extras     = concatMap mkTiming extraTiming
       mkTiming t = '[' : showTDiff t ++ "]"

-- | Records where a proof's reliance on @sorry@ comes from; used when
-- displaying dependencies.
data RootOfTrust = None        -- ^ Trusts nothing (aside from SBV, the underlying solver, etc.)
                 | Self        -- ^ Trusts itself, i.e., established by a call to sorry
                 | Prop String -- ^ Trusts a parent that itself trusts something else. The name stored here is
                               --   the name of this proposition itself, not of the parent that is trusted.
                deriving stock    Generic
                deriving anyclass NFData

-- | Proof for a property.
--
-- This type is meant to be abstract to users. The only way to create one is via
-- a call to lemma/theorem etc., which ensures soundness. This internal module
-- does export the constructor, so the abstraction must be enforced by the
-- public interface. The trusted code base is still large: the underlying
-- solver, SBV, and the KnuckleDragger kernel itself. But this mechanism keeps
-- proven facts from being created out of thin air, following the standard LCF
-- methodology.
data Proof = Proof
  { rootOfTrust :: RootOfTrust  -- ^ Where this proof's trust ultimately rests; see 'RootOfTrust'
  , isUserAxiom :: Bool         -- ^ Whether this was an axiom supplied by the user
  , getProof    :: SBool        -- ^ The underlying boolean
  , getProp     :: Dynamic      -- ^ The actual proposition, stored dynamically
  , proofName   :: String       -- ^ User-given name
  }

-- | Fully evaluates every field of a proof except 'getProp', which is skipped
-- on purpose.
instance NFData Proof where
  rnf (Proof trust axiom prf _prop name) =
    rnf trust `seq` rnf axiom `seq` rnf prf `seq` rnf name

-- | Renders a proof as its status in square brackets followed by its name,
-- e.g. @[Proven] foo@, @[Axiom] bar@, @[Sorry] baz@ or @[Modulo: p] qux@.
instance Show Proof where
  show Proof{rootOfTrust = trust, isUserAxiom = axiom, proofName = name} =
      "[" ++ status ++ "] " ++ name
    where
      status
        | axiom     = "Axiom"
        | otherwise = case trust of
                        None   -> "Proven"
                        Self   -> "Sorry"
                        Prop s -> "Modulo: " ++ s

-- | Compute the root of trust for a proof named @nm@ that relies on the proofs
-- in @by@, together with the "modulo" annotation to display (empty if none).
--
-- * If any parent was established by @sorry@, the result is @'Prop' nm@ with
--   the annotation @ [Modulo: sorry]@.
-- * Otherwise, if some parents carry a 'Prop' root, the result is @'Prop' nm@.
--   The annotation lists those root names, sorted and deduplicated.
-- * Otherwise nothing is trusted and the result is @('None', "")@.
calculateRootOfTrust :: String -> [Proof] -> (RootOfTrust, String)
calculateRootOfTrust nm by
  | hasSelf       = (Prop nm, modulo "sorry")
  | null depNames = (None, "")
  | otherwise     = (Prop nm, modulo (intercalate ", " depNames))
  where
    roots    = map rootOfTrust by
    hasSelf  = any isSelf roots
    depNames = nub (sort [p | Prop p <- roots])

    isSelf Self = True
    isSelf _    = False

    modulo reason = " [Modulo: " ++ reason ++ "]"