-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.TP.Kadane
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Proving the correctness of Kadane's algorithm for computing the maximum
-- sum of any contiguous sublist (the maximum segment sum problem).
--
-- Kadane's algorithm is a classic dynamic programming algorithm that solves
-- the maximum segment sum problem in O(n) time. Given a list of integers,
-- it finds the maximum sum of any contiguous sublist (segment), where the
-- empty segment has sum 0.
-----------------------------------------------------------------------------

{-# LANGUAGE CPP                 #-}
{-# LANGUAGE DataKinds           #-}
{-# LANGUAGE OverloadedLists     #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications    #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.TP.Kadane where

import Prelude hiding (length, maximum, null, head, tail, (++))

import Data.SBV
import Data.SBV.List
import Data.SBV.TP

#ifdef DOCTEST
-- $setup
-- >>> import Data.SBV
-- >>> import Data.SBV.TP
-- >>> :set -XOverloadedLists
#endif

-- * Problem specification

-- | The maximum segment sum problem: find the maximum sum of any contiguous
-- sublist (segment). The empty segment (with sum 0) counts as a valid segment,
-- so the result is never negative.
--
-- This is the obvious specification: the empty list maps to 0. Otherwise, we
-- take the best segment starting at the current position (see 'mssBegin'), and
-- take the maximum of that with the recursive result on the tail. This is
-- clearly correct, but takes O(n^2) time.
--
-- We have:
--
-- >>> mss [1, -2, 3, 4, -1, 2]  -- the segment: [3, 4, -1, 2]
-- 8 :: SInteger
-- >>> mss [-2, -3, -1]          -- empty segment
-- 0 :: SInteger
-- >>> mss [1, 2, 3]             -- the whole list
-- 6 :: SInteger
mss :: SList Integer -> SInteger
mss = smtFunction "mss" $ \xs ->
  ite (null xs)
      0
      -- Best segment starting here, versus best segment anywhere in the tail.
      (mssBegin xs `smax` mss (tail xs))

-- | Maximum sum of the segments that start at the beginning of the given list,
-- i.e., the best prefix sum. The empty prefix is allowed, so the result is never
-- negative: it is 0 when no non-empty prefix has a positive sum.
--
-- We have:
--
-- >>> mssBegin [1, -2, 3, 4, -1, 2]  -- the segment: [1, -2, 3, 4, -1, 2]
-- 7 :: SInteger
-- >>> mssBegin [-2, -3, -1]          -- empty segment
-- 0 :: SInteger
-- >>> mssBegin [1, 2, 3]             -- the whole list
-- 6 :: SInteger
mssBegin :: SList Integer -> SInteger
mssBegin = smtFunction "mssBegin" $ \xs ->
  ite (null xs)
      0
      -- Best of: the empty prefix (0), just the head, or the head extended by the
      -- best prefix of the tail. Since mssBegin t >= 0, the lone-head option is
      -- never larger than h + mssBegin t. It is left in place deliberately, so the
      -- SMT definition the proofs below are checked against stays unchanged.
      (let (h, t) = uncons xs
       in 0 `smax` (h `smax` (h + mssBegin t)))

-- * Kadane's algorithm implementation

-- | Kadane's algorithm: run 'kadaneHelper' over the whole list, with both of its
-- accumulators (the max-ending-here value and the max-so-far value) starting at 0.
--
-- >>> kadane [1, -2, 3, 4, -1, 2]  -- the segment: [3, 4, -1, 2]
-- 8 :: SInteger
-- >>> kadane [-2, -3, -1]          -- empty segment
-- 0 :: SInteger
-- >>> kadane [1, 2, 3]             -- the whole list
-- 6 :: SInteger
kadane :: SList Integer -> SInteger
kadane xs = kadaneHelper xs 0 0

-- | Helper for Kadane's algorithm. Along with the remaining list, we carry two
-- accumulators:
--
--   * @maxEndingHere@: the maximum sum of a segment ending just before the
--     remaining list (0 when the empty segment is best);
--   * @maxSoFar@: the maximum segment sum seen so far.
--
-- When the list is exhausted, @maxSoFar@ is the result.
kadaneHelper :: SList Integer -> SInteger -> SInteger -> SInteger
kadaneHelper = smtFunction "kadaneHelper" $ \xs maxEndingHere maxSoFar ->
  ite (null xs)
      maxSoFar -- end of the list: the best value seen so far is the answer
      (let (h, t) = uncons xs
           -- Either extend the best segment ending here with h, or restart with the empty segment.
           newMaxEndingHere = 0 `smax` (h + maxEndingHere)
           -- Best of the result so far and the best segment ending just after h.
           newMaxSoFar = maxSoFar `smax` newMaxEndingHere
       in kadaneHelper t newMaxEndingHere newMaxSoFar)

-- * Correctness proof

-- | The key insight is that we need a generalized invariant that characterizes
-- @kadaneHelper@ for arbitrary accumulator values, not just the initial @(0, 0)@.
--
-- The invariant states: for @kadaneHelper xs meh msf@ where:
--
--   * @meh@ (max-ending-here) is the maximum sum of a segment ending at the boundary
--   * @msf@ (max-so-far) is the best segment sum seen in the already-processed prefix
--   * Preconditions: @meh >= 0@ and @msf >= meh@
--
-- @
--   kadaneHelper xs meh msf == msf `smax` mss xs `smax` (meh + mssBegin xs)
-- @
--
-- This captures that the result is the maximum of:
--
--   * @msf@ - the best segment entirely in the already-processed prefix
--   * @mss xs@ - the best segment entirely in the remaining suffix
--   * @meh + mssBegin xs@ - the best segment crossing the boundary
--
-- The invariant is proved by induction on @xs@; correctness of 'kadane' is then
-- its special case @meh = msf = 0@.
--
-- >>> runTPWith cvc5 correctness
-- Inductive lemma: kadaneHelperInvariant
--   Step: Base                            Q.E.D.
--   Step: 1                               Q.E.D.
--   Step: 2                               Q.E.D.
--   Result:                               Q.E.D.
-- Lemma: correctness
--   Step: 1                               Q.E.D.
--   Step: 2                               Q.E.D.
--   Step: 3                               Q.E.D.
--   Step: 4                               Q.E.D.
--   Result:                               Q.E.D.
-- [Proven] correctness :: Ɐxs ∷ [Integer] → Bool
correctness :: TP (Proof (Forall "xs" [Integer] -> SBool))
correctness = do

  -- First, prove the generalized invariant. This is the heart of the proof: it relates
  -- kadaneHelper with arbitrary accumulators to the specification functions mss and mssBegin.
  -- The preconditions carry over to the recursive call. newMeh is an smax with 0, so it is
  -- non-negative. newMsf is an smax with newMeh, so newMsf >= newMeh.
  invariant <- induct "kadaneHelperInvariant"
      (\(Forall xs) (Forall meh) (Forall msf) ->
         (meh .>= 0 .&& msf .>= meh) .=> kadaneHelper xs meh msf .== (msf `smax` mss xs `smax` (meh + mssBegin xs))) $
      \ih (a, as) meh msf ->
        [meh .>= 0, msf .>= meh]
          |- let newMeh, newMsf :: SInteger
                 newMeh = 0 `smax` (a + meh)
                 newMsf = msf `smax` newMeh
             in kadaneHelper (a .: as) meh msf
             =: kadaneHelper as newMeh newMsf
             ?? ih `at` (Inst @"meh" newMeh, Inst @"msf" newMsf)
             =: newMsf `smax` mss as `smax` (newMeh + mssBegin as)
             =: qed

  -- Now the main theorem follows easily: kadane xs = kadaneHelper xs 0 0,
  -- and with meh = 0, msf = 0, the invariant gives us:
  --   kadaneHelper xs 0 0 = 0 `smax` mss xs `smax` (0 + mssBegin xs)
  --                       = mss xs `smax` mssBegin xs  (since mss xs >= mssBegin xs >= 0)
  --                       = mss xs                     (since mss xs >= mssBegin xs by definition)
  calc "correctness"
       (\(Forall xs) -> mss xs .== kadane xs) $
       \xs -> [] |- kadane xs
                 =: kadaneHelper xs 0 0
                 ?? invariant `at` (Inst @"xs" xs, Inst @"meh" (0 :: SInteger), Inst @"msf" (0 :: SInteger))
                 =: 0 `smax` mss xs `smax` (0 + mssBegin xs)
                 =: mss xs `smax` mssBegin xs
                 -- mss xs >= mssBegin xs by definition (mss considers all segments)
                 =: mss xs
                 =: qed