{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE ViewPatterns #-}

-- |
-- Module: Covenant.Type
-- Copyright: (C) MLabs 2025
-- License: Apache 2.0
-- Maintainer: koz@mlabs.city, sean@mlabs.city
--
-- Covenant's type system and various ways to construct types.
--
-- @since 1.0.0
module Covenant.Type
  ( -- * Type abstractions
    AbstractTy (..),
    Renamed (..),

    -- * Computation types
    CompT (Comp0, Comp1, Comp2, Comp3, CompN),
    CompTBody (ReturnT, (:--:>), ArgsAndResult),
    arity,

    -- * Value types
    ValT (..),
    BuiltinFlatT (..),
    byteStringT,
    integerT,
    stringT,
    tyvar,
    boolT,
    g1T,
    g2T,
    mlResultT,
    unitT,

    -- * Renaming

    -- ** Types
    RenameError (..),
    RenameM,

    -- ** Introduction
    renameValT,
    renameCompT,

    -- ** Elimination
    runRenameM,

    -- * Type application
    TypeAppError (..),
    checkApp,
  )
where

import Control.Monad (guard)
import Covenant.DeBruijn (DeBruijn)
import Covenant.Index
  ( Count,
    Index,
    count0,
    count1,
    count2,
    count3,
    intCount,
  )
import Covenant.Internal.Rename
  ( RenameError
      ( InvalidAbstractionReference,
        IrrelevantAbstraction,
        UndeterminedAbstraction
      ),
    RenameM,
    renameCompT,
    renameValT,
    runRenameM,
  )
import Covenant.Internal.Type
  ( AbstractTy (BoundAt),
    BuiltinFlatT
      ( BLS12_381_G1_ElementT,
        BLS12_381_G2_ElementT,
        BLS12_381_MlResultT,
        BoolT,
        ByteStringT,
        IntegerT,
        StringT,
        UnitT
      ),
    CompT (CompT),
    CompTBody (CompTBody),
    Renamed (Rigid, Unifiable, Wildcard),
    ValT (Abstraction, BuiltinFlat, ThunkT),
  )
import Covenant.Internal.Unification
  ( TypeAppError
      ( DoesNotUnify,
        ExcessArgs,
        InsufficientArgs,
        LeakingUnifiable,
        LeakingWildcard
      ),
    checkApp,
  )
import Data.Coerce (coerce)
import Data.Kind (Type)
import Data.Vector (Vector)
import Data.Vector qualified as Vector
import Data.Vector.NonEmpty (NonEmptyVector)
import Data.Vector.NonEmpty qualified as NonEmpty
import Optics.Core (preview)

-- | The body of a computation type that doesn't take any arguments and produces
-- a result of the given value type. Use this just as you would a data
-- constructor.
--
-- = Example
--
-- * @'ReturnT' 'integerT'@ is @!Integer@
--
-- @since 1.0.0
pattern ReturnT :: forall (a :: Type). ValT a -> CompTBody a
pattern ReturnT x <- CompTBody (returnHelper -> Just x)
  where
    -- A body with no argument types is a one-element vector holding only the
    -- result type.
    ReturnT x = CompTBody (NonEmpty.singleton x)

-- | Given a type of argument, and the body of another computation type,
-- construct a copy of the body, adding an extra argument of the argument type.
-- Use this just as you would a data constructor.
--
-- = Note
--
-- Together with 'ReturnT', these two patterns provide an exhaustive pattern
-- match.
--
-- = Example
--
-- * @'integerT' :--:> 'ReturnT' 'byteStringT'@ is @Integer -> !ByteString@
--
-- @since 1.0.0
pattern (:--:>) ::
  forall (a :: Type).
  ValT a ->
  CompTBody a ->
  CompTBody a
pattern x :--:> xs <- CompTBody (arrowHelper -> Just (x, xs))
  where
    -- The new argument type is prepended; 'coerce' exposes the underlying
    -- non-empty vector of the existing body so it can be extended in place.
    x :--:> xs = CompTBody (NonEmpty.cons x (coerce xs))

infixr 1 :--:>

-- | A view of a computation type as a 'Vector' of its argument types, together
-- with its result type. Can be used as a data constructor, and is an exhaustive
-- match.
--
-- = Example
--
-- * @'ArgsAndResult' ('Vector.fromList' ['integerT', 'integerT']) 'integerT'@
--   is @Integer -> Integer -> !Integer@
--
-- @since 1.0.0
pattern ArgsAndResult ::
  forall (a :: Type).
  Vector (ValT a) ->
  ValT a ->
  CompTBody a
pattern ArgsAndResult args result <- (argsAndResultHelper -> (args, result))
  where
    -- The result type is stored as the last element, after all argument types.
    ArgsAndResult args result = CompTBody (NonEmpty.snocV args result)

{-# COMPLETE ArgsAndResult #-}

{-# COMPLETE ReturnT, (:--:>) #-}

-- | Determine the arity of a computation type: that is, how many arguments a
-- function of this type must be given.
--
-- @since 1.0.0
arity :: forall (a :: Type). CompT a -> Int
-- The body holds every argument type followed by the result type, so the
-- number of arguments is one less than its length.
arity (CompT _ (CompTBody xs)) = NonEmpty.length xs - 1

-- | A computation type that does not bind any type variables. Use this like a
-- data constructor.
--
-- @since 1.0.0
pattern Comp0 ::
  forall (a :: Type).
  CompTBody a ->
  CompT a
pattern Comp0 xs <- (countHelper 0 -> Just xs)
  where
    Comp0 xs = CompT count0 xs

-- | A computation type that binds one type variable (that
-- is, something whose type is @forall a . ... -> ...@). Use this like a data
-- constructor.
--
-- @since 1.0.0
pattern Comp1 ::
  forall (a :: Type).
  CompTBody a ->
  CompT a
pattern Comp1 xs <- (countHelper 1 -> Just xs)
  where
    Comp1 xs = CompT count1 xs

-- | A computation type that binds two type variables (that
-- is, something whose type is @forall a b . ... -> ...@). Use this like a data
-- constructor.
--
-- @since 1.0.0
pattern Comp2 ::
  forall (a :: Type).
  CompTBody a ->
  CompT a
pattern Comp2 xs <- (countHelper 2 -> Just xs)
  where
    Comp2 xs = CompT count2 xs

-- | A computation type that binds three type variables
-- (that is, something whose type is @forall a b c . ... -> ...@). Use this like
-- a data constructor.
--
-- @since 1.0.0
pattern Comp3 ::
  forall (a :: Type).
  CompTBody a ->
  CompT a
pattern Comp3 xs <- (countHelper 3 -> Just xs)
  where
    Comp3 xs = CompT count3 xs

-- | A general way to construct and deconstruct computations which bind an
-- arbitrary number of type variables. Use this like a data constructor. Unlike
-- the other @Comp@ patterns, 'CompN' is exhaustive if matched on.
--
-- @since 1.0.0
pattern CompN ::
  Count "tyvar" ->
  CompTBody AbstractTy ->
  CompT AbstractTy
pattern CompN count xs <- CompT count xs
  where
    CompN count xs = CompT count xs

{-# COMPLETE CompN #-}

-- | Helper for defining type variables: builds an 'Abstraction' wrapping
-- @'BoundAt' db ix@ for the given 'DeBruijn' value @db@ and index @ix@.
--
-- @since 1.0.0
tyvar :: DeBruijn -> Index "tyvar" -> ValT AbstractTy
tyvar db = Abstraction . BoundAt db

-- | Helper for defining the value type of builtin bytestrings.
--
-- @since 1.0.0
byteStringT :: forall (a :: Type). ValT a
byteStringT = BuiltinFlat ByteStringT

-- | Helper for defining the value type of builtin integers.
--
-- @since 1.0.0
integerT :: forall (a :: Type). ValT a
integerT = BuiltinFlat IntegerT

-- | Helper for defining the value type of builtin strings.
--
-- @since 1.0.0
stringT :: forall (a :: Type). ValT a
stringT = BuiltinFlat StringT

-- | Helper for defining the value type of builtin booleans.
--
-- @since 1.0.0
boolT :: forall (a :: Type). ValT a
boolT = BuiltinFlat BoolT

-- | Helper for defining the value type of BLS12-381 G1 curve points.
--
-- @since 1.0.0
g1T :: forall (a :: Type). ValT a
g1T = BuiltinFlat BLS12_381_G1_ElementT

-- | Helper for defining the value type of BLS12-381 G2 curve points.
--
-- @since 1.0.0
g2T :: forall (a :: Type). ValT a
g2T = BuiltinFlat BLS12_381_G2_ElementT

-- | Helper for defining the value type of BLS12-381 Miller loop results (the
-- builtin @MlResult@ type).
--
-- @since 1.0.0
mlResultT :: forall (a :: Type). ValT a
mlResultT = BuiltinFlat BLS12_381_MlResultT

-- | Helper for defining the value type of the builtin unit type.
--
-- @since 1.0.0
unitT :: forall (a :: Type). ValT a
unitT = BuiltinFlat UnitT

-- Helpers

-- | Matcher for 'ReturnT': yields the only element of the body when there are
-- no argument types, and 'Nothing' otherwise.
returnHelper ::
  forall (a :: Type).
  NonEmptyVector (ValT a) ->
  Maybe (ValT a)
returnHelper xs = case NonEmpty.uncons xs of
  -- Succeed only when nothing follows the head element; 'Vector.null' states
  -- this directly instead of comparing the length against zero.
  (y, ys) -> y <$ guard (Vector.null ys)

-- | Matcher for '(:--:>)': splits off the first argument type and re-wraps the
-- remainder as a body. Yields 'Nothing' when the remainder would be empty,
-- i.e. when the body holds only a result type.
arrowHelper ::
  forall (a :: Type).
  NonEmptyVector (ValT a) ->
  Maybe (ValT a, CompTBody a)
arrowHelper xs = case NonEmpty.uncons xs of
  (y, ys) -> (y,) . CompTBody <$> NonEmpty.fromVector ys

-- | Matcher for 'ArgsAndResult': splits a body into all its elements but the
-- last (the argument types) and the last one (the result type).
argsAndResultHelper ::
  forall (a :: Type).
  CompTBody a ->
  (Vector (ValT a), ValT a)
argsAndResultHelper (CompTBody xs) = NonEmpty.unsnoc xs

-- | Matcher shared by the fixed-count @Comp@ patterns: yields the body only if
-- the computation type binds exactly @expected@ type variables. Yields
-- 'Nothing' when @expected@ is rejected by 'intCount' or the counts differ.
countHelper ::
  forall (a :: Type).
  Int ->
  CompT a ->
  Maybe (CompTBody a)
countHelper expected (CompT actual xs) = do
  expectedCount <- preview intCount expected
  guard (expectedCount == actual)
  pure xs