-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.ProofTools.Sum
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Example inductive proof to show partial correctness of the traditional
-- loop-based sum algorithm:
--
-- @
--     s = 0
--     i = 0
--     while i <= n:
--        s += i
--        i++
-- @
--
-- We prove the loop invariant, and use it to establish partial correctness:
-- upon termination, @s@ is the sum of all numbers up to and including @n@.
-----------------------------------------------------------------------------

{-# LANGUAGE DeriveTraversable     #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NamedFieldPuns        #-}
{-# LANGUAGE TypeFamilies          #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.ProofTools.Sum where

import Data.SBV
import Data.SBV.Tools.Induction

-- * System state

-- | System state. We simply have three components, parameterized
-- over the type so we can put in both concrete and symbolic values
-- (e.g. @S Integer@ for query results and @S SInteger@ during the proof).
data S a = S { s :: a  -- ^ The running sum accumulated by the loop.
             , i :: a  -- ^ The loop counter.
             , n :: a  -- ^ The bound; the loop body runs while @i <= n@.
             }
             deriving (Show, Functor, Foldable, Traversable)

-- | 'Queriable' instance for our state. Every field of the symbolic state
-- is created as a fresh symbolic integer via 'freshVar_'; a concrete
-- query result for @S SInteger@ is an @S Integer@.
instance Queriable IO (S SInteger) where
  type QueryResult (S SInteger) = S Integer
  create = S <$> freshVar_ <*> freshVar_ <*> freshVar_

-- | Encoding partial correctness of the sum algorithm. We have:
--
-- >>> sumCorrect
-- Q.E.D.
sumCorrect :: IO (InductionResult (S Integer))
sumCorrect = induct chatty setup initial trans strengthenings inv goal
  where -- Set this to True for SBV to print steps as it proceeds
        -- through the inductive proof
        chatty :: Bool
        chatty = False

        -- This is where we would put solver options, typically via
        -- calls to 'Data.SBV.setOption'. We do not need any for this problem,
        -- so we simply do nothing.
        setup :: Symbolic ()
        setup = return ()

        -- Initially, @s@ and @i@ are both @0@. We also require @n@ to be at least @0@.
        initial :: S SInteger -> SBool
        initial S{s, i, n} = s .== 0 .&& i .== 0 .&& n .>= 0

        -- We code the algorithm almost literally in SBV notation. The
        -- post-state @(s', i', n')@ is constrained to be either one
        -- iteration of the loop body, or (once the guard fails) the
        -- unchanged pre-state.
        trans :: S SInteger -> S SInteger -> SBool
        trans S{s, i, n} S{s = s', i = i', n = n'} =
          (s', i', n') .== ite (i .<= n)
                               (s + i, i + 1, n)   -- loop body: s += i; i++
                               (s    , i    , n)   -- guard false: state stays put

        -- No strengthenings needed for this problem!
        strengthenings :: [(String, S SInteger -> SBool)]
        strengthenings = []

        -- Loop invariant: @i@ remains at most @n+1@ and @s@ is the sum of
        -- all the numbers up to @i-1@. Since @i * (i - 1)@ is a product of
        -- two consecutive integers it is always even, so the division by 2
        -- is exact.
        inv :: S SInteger -> SBool
        inv S{s, i, n} =     i .<= n + 1
                         .&& s .== (i * (i - 1)) `sDiv` 2

        -- Final goal. When the termination condition holds, the sum is
        -- equal to all the numbers up to and including @n@. Note that
        -- SBV does not prove the termination condition; it simply is
        -- the indication that the loop has ended as specified by the user.
        -- The first component is that termination condition; the second
        -- is the property that must hold at that point.
        goal :: S SInteger -> (SBool, SBool)
        goal S{s, i, n} = ( i .== n + 1
                          , s .== (n * (n + 1)) `sDiv` 2
                          )