-----------------------------------------------------------------------------
-- |
-- Module    : Documentation.SBV.Examples.TP.UpDown
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Proves @reverse (down n) = up n@.
--
-- This problem is motivated by a question from an ACL2 midterm exam given in Fall 2011.
-- See: <https://www.cs.utexas.edu/~moore/classes/cs389r/midterm-answers.lisp>.
-----------------------------------------------------------------------------

{-# LANGUAGE CPP              #-}
{-# LANGUAGE DataKinds        #-}
{-# LANGUAGE OverloadedLists  #-}
{-# LANGUAGE QuasiQuotes      #-}
{-# LANGUAGE TypeAbstractions #-}
{-# LANGUAGE TypeApplications #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Documentation.SBV.Examples.TP.UpDown where

import Prelude hiding (reverse, (++))

import Data.SBV
import Data.SBV.TP
import Data.SBV.List

import Documentation.SBV.Examples.TP.Lists
import Documentation.SBV.Examples.TP.Peano

#ifdef DOCTEST
-- $setup
-- >>> import Data.SBV.TP
#endif

-- | Construct a list of size @n@, containing numbers @1@ to @n@.
--
-- This is 'upAcc' started with an empty accumulator.
--
-- >>> up 0
-- [] :: [SInteger]
-- >>> up 5
-- [1,2,3,4,5] :: [SInteger]
up :: SNat -> SList Integer
up n = upAcc n []

-- | Accumulating helper for 'up'. While the first argument is a successor, cons its integer
-- value (via 'n2i') onto the accumulator and recurse on its predecessor. Once the first argument
-- reaches zero, return the accumulator. Consequently @upAcc n acc@ equals @[1 .. n] ++ acc@.
--
-- Normally, we'd define this as a local function, but the definition needs to be visible for the proofs.
--
-- The SMT-level name given to 'smtFunction' matches the Haskell name (@upAcc@), so it is not
-- confused with 'up' (and cannot clash with a future symbolic definition of 'up').
upAcc :: SNat -> SList Integer -> SList Integer
upAcc = smtFunction "upAcc" $ \n acc -> [sCase|Nat n of
                                           Zero   -> acc
                                           Succ p -> upAcc p (n2i n .: acc)
                                         |]

-- | Construct a list of size @n@, containing numbers @n@ down to @1@.
--
-- Defined with 'smtFunction' so that its recursive definition is visible to the proofs below.
--
-- >>> down 0
-- [] :: [SInteger]
-- >>> down 5
-- [5,4,3,2,1] :: [SInteger]
down :: SNat -> SList Integer
down = smtFunction "down" $ \n -> [sCase|Nat n of
                                     Zero   -> []
                                     Succ p -> n2i n .: down p
                                  |]

-- | Prove that @reverse (down n)@ is the same as @up n@.
--
-- The statement is not directly inductive, because 'up' hides the accumulator of 'upAcc'.
-- We therefore first prove the generalization @reverse (down n) ++ xs == upAcc n xs@ by
-- strong induction, using @n2i n@ as the measure (non-negativity supplied by 'n2iNonNeg').
-- The original claim is then the special case @xs = []@.
--
-- >>> runTP upDown
-- Lemma: n2iNonNeg                        Q.E.D.
-- Lemma: revCons                          Q.E.D.
-- Inductive lemma (strong): upDownGen
--   Step: Measure is non-negative         Q.E.D.
--   Step: 1 (2 way case split)
--     Step: 1.1                           Q.E.D.
--     Step: 1.2.1                         Q.E.D.
--     Step: 1.2.2                         Q.E.D.
--     Step: 1.2.3                         Q.E.D.
--     Step: 1.2.4                         Q.E.D.
--     Step: 1.Completeness                Q.E.D.
--   Result:                               Q.E.D.
-- Lemma: upDown                           Q.E.D.
-- [Proven] upDown :: Ɐn ∷ Nat → Bool
upDown :: TP (Proof (Forall "n" Nat -> SBool))
upDown = do
   n2inn <- recall "n2iNonNeg" n2iNonNeg
   rc    <- recall "revCons"   (revCons @Integer)

   -- We first generalize the theorem, to make it inductive: the extra list xs
   -- plays the role of the accumulator of upAcc.
   upDownGen <- sInduct "upDownGen"
                        (\(Forall @"n" n) (Forall @"xs" xs) -> reverse (down n) ++ xs .== upAcc n xs)
                        -- Measure: the integer value of n; n2iNonNeg shows it is non-negative.
                        (\n _ -> n2i n, [proofOf n2inn]) $
                        \ih n xs -> [] |- cases [ isZero n ==> trivial
                                                , isSucc n ==> let p = getSucc_1 n
                                                               in reverse (down (sSucc p)) ++ xs
                                                               -- unfold one step of down (n is sSucc p)
                                                               =: reverse (n2i n .: down p) ++ xs
                                                               -- move n2i n out of the reversed prefix onto xs
                                                               ?? rc
                                                               =: reverse (down p) ++ (n2i n .: xs)
                                                               -- induction hypothesis at the smaller p, with a larger accumulator
                                                               ?? ih `at` (Inst @"n" p, Inst @"xs" (n2i n .: xs))
                                                               =: upAcc p (n2i n .: xs)
                                                               -- fold one step of upAcc (n is sSucc p)
                                                               =: upAcc n xs
                                                               =: qed
                                                ]

   -- The theorem we want follows by instantiating upDownGen with xs = []: since
   -- up n is upAcc n [], the SMT solver can discharge it by itself.
   lemma "upDown"
         (\(Forall n) -> reverse (down n) .== up n)
         [proofOf upDownGen]