-- |
-- Module      : Data.Fin
-- Description : Bounded natural numbers
--
-- This file re-exports definitions from [fin](https://hackage.haskell.org/package/fin)'s
-- [Data.Fin](https://hackage.haskell.org/package/fin-0.3.2/docs/Data-Fin.html), while adding a few more
-- that are relevant to this context. Like [Data.Fin](https://hackage.haskell.org/package/fin-0.3.2/docs/Data-Fin.html),
-- it is meant to be used qualified.
--
-- @
-- import Data.Fin ('Fin' (..))
-- import qualified Data.Fin as Fin
-- @
{-# LANGUAGE PackageImports #-}
module Data.Fin(
  Nat(..), SNat(..),
  Fin(..),
  toNat, fromNat, toInteger,
  mirror,
  absurd,
  universe,
  f0,f1,f2,f3,
  invert,
  shiftN,
  shift1,
  weakenFin,
  weakenFinRight,
  weaken1Fin,
  weaken1FinRight,
  strengthen1Fin,
  strengthenRecFin
 ) where

import Data.Nat
import Data.SNat
import "fin" Data.Fin hiding (cata)
import Data.Proxy (Proxy (..))
-- for efficient rescoping
import Unsafe.Coerce (unsafeCoerce)

-------------------------------------------------------------------------------
-- toInt
-------------------------------------------------------------------------------

-- | The `toInteger` instance for Fin has an unnecessary
-- type class constraint (NatI n) for Fin. So we
-- also include this class for simple conversion.
--
-- The conversion counts the 'FS' constructors wrapped around the
-- final 'FZ', so it needs no constraint on @n@.
instance ToInt (Fin n) where
  toInt :: Fin n -> Int
  toInt FZ = 0
  toInt (FS x) = 1 + toInt x

-- >>> [minBound .. maxBound] :: [Fin N3]
-- [0,1,2]

-- | List all numbers up to some size
-- >>> universe :: [Fin N3]
-- [0,1,2]

-- | Convert an "index" Fin to a "level" Fin and vice versa.
--
-- The result is @'maxBound' - f@, i.e. the position of @f@ counted from
-- the other end of the scope (in @Fin N3@, whose elements are @[0,1,2]@,
-- @0@ and @2@ swap while @1@ stays put).
invert :: forall n. (SNatI n) => Fin n -> Fin n
invert f = case snat @n of
  -- There are no values of type @Fin Z@, so this branch is unreachable
  -- and the empty case discharges it.
  SZ -> case f of {}
  -- Matching on 'SS' tells the type checker that @n@ is a successor,
  -- which is what makes 'maxBound' available at this type.
  SS -> maxBound - f

-------------------------------------------------------------------------------
-- * Shifting
-------------------------------------------------------------------------------

-- We use the term "Weakening" to mean: Adding a new binding to the front of
-- the typing context without changing existing indices.
-- In contrast, "Shifting" means: Adjusting the indices of free variables
-- within a term to reflect a new binding added to the end of the context.
--
-- Shifting functions add some specified amount to the given
-- `Fin` value, also incrementing its type.
--
-- Shifting is implemented in the Data.Fin library using the `weakenRight`
-- function, which changes the value of a Fin and its type.
-- >>> :t weakenRight
-- weakenRight :: SNatI n => Proxy n -> Fin m -> Fin (Plus n m)
--
-- >>> weakenRight (Proxy :: Proxy N1) (f1 :: Fin N2) :: Fin N3
-- 2
--
-- In this module, we call the same operation `shiftN` and give
-- it a slightly more convenient interface.
-- >>> shiftN s1 (f1 :: Fin N2)
-- 2
--
-- | Increment by a fixed amount (on the left).
--
-- Both the value and the bound grow by @n@: the singleton argument is
-- turned into the @SNatI n@ instance that 'weakenRight' requires, and the
-- 'Proxy' merely tells 'weakenRight' which @n@ to use.
shiftN :: forall n m. SNat n -> Fin m -> Fin (n + m)
shiftN p f = withSNat p $ weakenRight (Proxy :: Proxy n) f

-- | Increment by one.
--
-- This is 'shiftN' specialised to the singleton for one ('s1'), so both
-- the value and its bound go up by one.
shift1 :: Fin m -> Fin (S m)
shift1 = shiftN s1

-- We could also include a dual function, which increments on the right
-- but we haven't needed that operation anywhere.

-------------------------------------------------------------------------------
-- * Weakening
-------------------------------------------------------------------------------

-- | Weaken the bound of a 'Fin' by an arbitrary amount, without
-- changing its index.
--
-- Weakening changes the bound of a nat-indexed type without changing
-- its value.
-- These operations can either be defined for the n-ary case (as in Fin below)
-- or be defined in terms of a single-step operation.
-- However, as both of these operations are identity functions,
-- it is justified to use unsafeCoerce.
--
-- The corresponding function in the Data.Fin library is `weakenLeft`.
--
-- @
-- -- >>> :t weakenLeft
-- weakenLeft :: SNatI n => Proxy m -> Fin n -> Fin (Plus n m)
-- @
--
-- This function does not change the value, it only changes its type.
--
-- @
-- -- >>> weakenLeft (Proxy :: Proxy N1) (f1 :: Fin N2) :: Fin N3
-- 1
-- @
--
-- We could use the following definition:
--
-- @
-- weakenFin m f = withSNat m $ weakenLeft (Proxy :: Proxy m) f
-- @
--
-- But, by using an 'unsafeCoerce' implementation, we can avoid the
-- @'SNatI' n@ constraint in the type of this operation.
--
-- @
-- -- >>> weakenFin (Proxy :: Proxy N1) (f1 :: Fin N2) :: Fin N3
-- 1
-- @
weakenFin :: proxy m -> Fin n -> Fin (m + n)
-- The proxy argument only selects @m@ at the type level and is never
-- inspected; the 'Fin' itself is returned unchanged at the wider type.
weakenFin _ f = unsafeCoerce f

-- | Weaken the bound of a 'Fin' by 1.
--
-- This is 'weakenFin' with the singleton for one ('s1') as the proxy, so
-- the value is unchanged and only the bound in the type grows.
weaken1Fin :: Fin n -> Fin (S n)
weaken1Fin = weakenFin s1

-- | Weaken the bound of a 'Fin' by an arbitrary amount on the right.
-- This is also an identity function
-- >>> weakenFinRight (s1 :: SNat N1) (f1 :: Fin N2) :: Fin N3
-- 1
weakenFinRight :: proxy m -> Fin n -> Fin (n + m)
-- As in 'weakenFin', the proxy is only a type-level witness for @m@ and
-- is never inspected, so it is matched with a wildcard.
weakenFinRight _ f = unsafeCoerce f

-- | Weaken the bound of a 'Fin' by 1 on the right.
--
-- This is 'weakenFinRight' with the singleton for one ('s1') as the
-- proxy; the result type is @Fin (n + N1)@ rather than @Fin (S n)@.
weaken1FinRight :: Fin n -> Fin (n + N1)
weaken1FinRight = weakenFinRight s1

-------------------------------------------------------------------------------
-- * Aliases
-------------------------------------------------------------------------------

-- Convenient names for fin values. These have polymorphic types so they
-- will work in any scope. (These are also called fin0, fin1, fin2, etc
-- in Data.Fin)

-- | 0. Usable at any bound of the form @S n@.
f0 :: Fin (S n)
f0 = FZ

-- | 1. Usable at any bound of at least two.
f1 :: Fin (S (S n))
f1 = FS f0

-- | 2. Usable at any bound of at least three.
f2 :: Fin (S (S (S n)))
f2 = FS f1

-- | 3. Usable at any bound of at least four.
f3 :: Fin (S (S (S (S n))))
f3 = FS f2

-- >>> f2
-- 2

-------------------------------------------------------------------------------
-- * Strengthening
-------------------------------------------------------------------------------

-- | With strengthening, we make sure that variable f0 is not used,
-- and we decrement all other indices by 1. This allows us to
-- also decrement the scope by one.
--
-- >>> strengthen1Fin (f0 :: Fin (S N3)) :: Maybe (Fin N3)
-- Nothing
-- >>> strengthen1Fin (f1 :: Fin (S N3)) :: Maybe (Fin N3)
-- Just 0
-- >>> strengthen1Fin (f2 :: Fin (S N3)) :: Maybe (Fin N3)
-- Just 1
strengthen1Fin :: forall n. SNatI n => Fin (S n) -> Maybe (Fin n)
-- Instance of 'strengthenRecFin' with nothing kept in front (k = 0) and a
-- single binding removed (m = 1). The third argument is only a type-level
-- witness for @n@, so a 'Proxy' is passed instead of a bottom value.
strengthen1Fin = strengthenRecFin s0 s1 (Proxy :: Proxy n)

-- | We implement strengthening with the following operation that
-- generalizes the induction hypothesis, so that we can strengthen
-- in the middle of the scope. The scope of the Fin should have the form
-- @k + (m + n)@
--
-- Indices in the middle part of the scope @m@ are "strengthened" away.
--
-- >>> strengthenRecFin s1 s1 s2 (f1 :: Fin (N1 + N1 + N2)) :: Maybe (Fin (N1 + N2))
-- Nothing
--
-- Variables that are in the first part of the scope @k@ (the ones that have
-- most recently entered the context) do not change when strengthening.
--
-- >>> strengthenRecFin s1 s1 s2 (f0 :: Fin (N1 + N1 + N2))
-- Just 0
--
-- Variables in the last part of the scope @n@ are decremented by strengthening
--
-- >>> strengthenRecFin s1 s1 s2 (f2 :: Fin (N1 + N1 + N2)) :: Maybe (Fin N3)
-- Just 1
--
-- >>> strengthenRecFin s1 s1 s2 (f3 :: Fin (N1 + N1 + N2)) :: Maybe (Fin N3)
-- Just 2
--
strengthenRecFin ::
   SNat k -> SNat m -> proxy n -> Fin (k + (m + n)) -> Maybe (Fin (k + n))
-- Base case: k = 0, m = 0. Nothing is left to remove, so x is kept as is.
strengthenRecFin SZ SZ _ x = Just x
-- Case: k = 0, m > 0, and x is in the `m` range: it refers to a binding
-- that is being removed, so strengthening fails.
strengthenRecFin SZ (snat_ -> SS_ _) _ FZ = Nothing
-- Case: k = 0, m > 0, and x is a successor: step past one of the removed
-- bindings. The result is not re-wrapped in FS, which is what decrements
-- the indices that point into `n`.
strengthenRecFin SZ (snat_ -> SS_ m) n (FS x) =
    strengthenRecFin SZ m n x
-- Case: x < k, leave it alone
strengthenRecFin (snat_ -> SS_ _) _ _ FZ = Just FZ
-- Case: k > 0 and x is a successor: step past one kept binding, strengthen
-- the rest, and put the FS back so the kept prefix is preserved.
strengthenRecFin (snat_ -> SS_ k) m n (FS x) =
    FS <$> strengthenRecFin k m n x