{-# LANGUAGE CPP #-}

module Covenant.Internal.Unification
  ( TypeAppError (..),
    checkApp,
  )
where

import Control.Monad (foldM, unless)
import Data.Ord (comparing)
#if __GLASGOW_HASKELL__==908
import Data.Foldable (foldl')
#endif
import Control.Monad.Except (catchError, throwError)
import Covenant.Index (Index, intCount, intIndex)
import Covenant.Internal.Type
  ( BuiltinFlatT,
    CompT (CompT),
    CompTBody (CompTBody),
    Renamed (Rigid, Unifiable, Wildcard),
    ValT (Abstraction, BuiltinFlat, ThunkT),
  )
import Data.Kind (Type)
import Data.Map (Map)
import Data.Map.Merge.Strict qualified as Merge
import Data.Map.Strict qualified as Map
import Data.Maybe (fromJust, mapMaybe)
import Data.Set (Set)
import Data.Set qualified as Set
import Data.Vector (Vector)
import Data.Vector qualified as Vector
import Data.Vector.NonEmpty qualified as NonEmpty
import Data.Word (Word64)
import Optics.Core (preview)

-- | Everything that can go wrong when applying a computation type to a list
-- of arguments.
--
-- @since 1.0.0
data TypeAppError
  = -- | The final type after all arguments are applied is @forall a . a@: a
    -- bare unifiable (identified by its index) that no argument fixed.
    LeakingUnifiable (Index "tyvar")
  | -- | A wildcard (thus, a skolem) escaped its scope. The fields are copied
    -- from the escaping wildcard: its scope identifier, its level and its
    -- index.
    LeakingWildcard Word64 Int (Index "tyvar")
  | -- | We were given too many arguments. The fields are the computation type
    -- being applied, and the arguments left over once every parameter had
    -- been matched.
    ExcessArgs (CompT Renamed) (Vector (Maybe (ValT Renamed)))
  | -- | We weren't given enough arguments. The field is the computation type
    -- being applied.
    InsufficientArgs (CompT Renamed)
  | -- | The expected type (first field) and actual type (second field) do not
    -- unify.
    DoesNotUnify (ValT Renamed) (ValT Renamed)
  deriving stock
    ( -- | @since 1.0.0
      Eq,
      -- | @since 1.0.0
      Show
    )

-- | Apply a computation type to a list of arguments, producing the type of
-- the result if the application is valid.
--
-- Arguments are matched against the computation's parameters from left to
-- right. Each argument is either @'Just' t@, for an argument of type @t@, or
-- 'Nothing', for an error argument. An error argument is effectively
-- @forall a . a@: it matches any parameter and requires no substitutions.
-- Whenever a parameter unifies with a 'Just' argument, the substitutions this
-- produces are applied to every later parameter, including the result type.
-- Once all arguments are consumed, the result type is checked for escaping
-- type variables; if it is a thunk, its remaining unifiables are renumbered.
--
-- The possible failures are:
--
-- * 'ExcessArgs' if there are more arguments than parameters;
-- * 'InsufficientArgs' if there are fewer arguments than parameters;
-- * 'DoesNotUnify' if some argument does not fit its parameter;
-- * 'LeakingUnifiable' or 'LeakingWildcard' if the result type is a bare
--   unifiable or wildcard.
--
-- @since 1.0.0
checkApp :: CompT Renamed -> [Maybe (ValT Renamed)] -> Either TypeAppError (ValT Renamed)
checkApp f@(CompT _ (CompTBody xs)) = matchParams firstParam (Vector.toList otherParams)
  where
    (firstParam, otherParams) = NonEmpty.uncons xs
    matchParams ::
      ValT Renamed ->
      [ValT Renamed] ->
      [Maybe (ValT Renamed)] ->
      Either TypeAppError (ValT Renamed)
    -- No parameters remain, so currParam is the resulting type after all
    -- substitutions have been applied. Any arguments still left are surplus.
    matchParams currParam [] remainingArgs
      | null remainingArgs = fixUp currParam
      | otherwise = throwError . ExcessArgs f . Vector.fromList $ remainingArgs
    -- Parameters remain, but the arguments have run out.
    matchParams _ _ [] = throwError . InsufficientArgs $ f
    -- An error argument unifies with anything, as it's effectively
    -- `forall a . a`, and needs no substitutional changes, so we skip it.
    matchParams _ (nextParam : laterParams) (Nothing : remainingArgs) =
      matchParams nextParam laterParams remainingArgs
    -- A real argument: unify it with the current parameter, then push the
    -- resulting substitutions into every parameter that follows.
    matchParams currParam laterParams (Just arg : remainingArgs) = do
      subs <-
        unify currParam arg
          `catchError` promoteUnificationError currParam arg
      case Map.foldlWithKey' applySub laterParams subs of
        nextParam : rest -> matchParams nextParam rest remainingArgs
        -- applySub only maps over the list, so it cannot become empty here;
        -- this alternative just keeps the match total.
        [] -> throwError . InsufficientArgs $ f

-- Helpers

-- Apply one substitution (the unifiable at the given index becomes the given
-- type) to every type in a list. The argument order matches what
-- 'Map.foldlWithKey'' expects of its step function.
applySub ::
  [ValT Renamed] ->
  Index "tyvar" ->
  ValT Renamed ->
  [ValT Renamed]
applySub acc index sub = map (substitute index sub) acc

-- Replace every occurrence of the unifiable with the given index by the given
-- type, descending into thunk bodies. All other abstractions (other
-- unifiables, rigids, wildcards) and flat builtins are rebuilt unchanged.
substitute ::
  Index "tyvar" ->
  ValT Renamed ->
  ValT Renamed ->
  ValT Renamed
substitute index toSub = replace
  where
    replace :: ValT Renamed -> ValT Renamed
    replace = \case
      Abstraction (Unifiable ourIndex)
        | ourIndex == index -> toSub
      Abstraction t -> Abstraction t
      ThunkT (CompT abstractions (CompTBody xs)) ->
        ThunkT (CompT abstractions (CompTBody (fmap replace xs)))
      BuiltinFlat t -> BuiltinFlat t

-- Because unification is inherently recursive, if we find an error deep within
-- a type, the message will signify only the _part_ that fails to unify, not
-- the entire type. While potentially useful, this can be quite confusing,
-- especially with generated types. Thus, we use `catchError` with this
-- function: it rethrows a 'DoesNotUnify' with its types replaced by the given
-- 'wrapping' expected and actual types, and rethrows any other error as-is.
promoteUnificationError ::
  forall (a :: Type).
  ValT Renamed ->
  ValT Renamed ->
  TypeAppError ->
  Either TypeAppError a
promoteUnificationError topLevelExpected topLevelActual err = Left $ case err of
  DoesNotUnify _ _ -> DoesNotUnify topLevelExpected topLevelActual
  _ -> err

-- Final check on the type that remains once every argument has been applied.
--
-- * A bare unifiable is rejected with 'LeakingUnifiable': the result is
--   effectively @forall a . a@, but is not an error.
-- * A bare wildcard is rejected with 'LeakingWildcard': this is the equivalent
--   of failing the @ST@ trick.
-- * A thunk may now mention fewer unifiables than we started with, so the
--   ones it still mentions may be non-contiguous (and could refer to
--   unifiables that no longer exist). They are renumbered to @0, 1, ...@ in
--   ascending order of their old indices, and the thunk's count of introduced
--   variables is set to how many there are.
-- * Any other type is returned unchanged.
fixUp :: ValT Renamed -> Either TypeAppError (ValT Renamed)
fixUp = \case
  Abstraction (Unifiable index) -> throwError (LeakingUnifiable index)
  Abstraction (Wildcard scopeId trueLevel index) ->
    throwError (LeakingWildcard scopeId trueLevel index)
  thunk@(ThunkT (CompT _ (CompTBody xs))) ->
    let remaining = collectUnifiables thunk
        required = Set.size remaining
        -- We know that the size of a set can't be negative, but GHC doesn't.
        newCount = fromJust (preview intCount required)
        -- Make enough fresh indexes for us to use in one go.
        freshIndexes = mapMaybe (preview intIndex) [0 .. required - 1]
        -- Pair each old unifiable (Set.toList yields them in ascending order)
        -- with its replacement.
        renames =
          zipWith
            (\old new -> (old, Abstraction (Unifiable new)))
            (Set.toList remaining)
            freshIndexes
        renumber t = foldl' (\acc (i, r) -> substitute i r acc) t renames
     in pure . ThunkT . CompT newCount . CompTBody . fmap renumber $ xs
  other -> pure other

-- Gather the indices of every unifiable mentioned anywhere in a type,
-- including inside (possibly nested) thunks. Rigids, wildcards and flat
-- builtins contribute nothing.
collectUnifiables :: ValT Renamed -> Set (Index "tyvar")
collectUnifiables = \case
  Abstraction (Unifiable index) -> Set.singleton index
  Abstraction _ -> Set.empty
  BuiltinFlat _ -> Set.empty
  ThunkT (CompT _ (CompTBody xs)) -> foldMap collectUnifiables xs

-- Try to unify an expected type (a parameter) with an actual type (an
-- argument). On success, returns the substitutions required for unifiables of
-- the expected type. On failure, the error is a 'DoesNotUnify' that names the
-- two types given to this call, rather than whichever inner part failed.
--
-- * A unifiable unifies with anything, and is assigned the actual type.
-- * A rigid behaves like a concrete type: it unifies with an identical rigid,
--   or with any unifiable or wildcard, but nothing else.
-- * A wildcard unifies with unifiables, and with wildcards that come from a
--   different scope or share its index, but nothing else.
-- * A thunk unifies with unifiables and wildcards, or with a thunk of the same
--   length whose components pairwise unify without conflicting substitutions.
-- * A flat builtin unifies with unifiables and wildcards, or with the same
--   flat builtin.
unify ::
  ValT Renamed ->
  ValT Renamed ->
  Either TypeAppError (Map (Index "tyvar") (ValT Renamed))
unify expected actual =
  dispatch `catchError` promoteUnificationError expected actual
  where
    dispatch :: Either TypeAppError (Map (Index "tyvar") (ValT Renamed))
    dispatch = case expected of
      -- Unifiables unify with everything, and require a substitutional rewrite.
      Abstraction (Unifiable index1) -> pure (Map.singleton index1 actual)
      Abstraction (Rigid level1 index1) -> expectRigid level1 index1
      Abstraction (Wildcard scopeId1 _ index1) -> expectWildcard scopeId1 index1
      ThunkT t1 -> expectThunk t1
      BuiltinFlat t1 -> expectFlatBuiltin t1
    -- The one failure unification can produce.
    unificationError :: forall (a :: Type). Either TypeAppError a
    unificationError = Left (DoesNotUnify expected actual)
    -- Success that needs no substitutional rewrites.
    noSubUnify :: forall (k :: Type) (a :: Type). Either TypeAppError (Map k a)
    noSubUnify = pure Map.empty
    expectRigid ::
      Int -> Index "tyvar" -> Either TypeAppError (Map (Index "tyvar") (ValT Renamed))
    expectRigid level1 index1 = case actual of
      Abstraction (Rigid level2 index2)
        | level1 == level2 && index1 == index2 -> noSubUnify
        | otherwise -> unificationError
      Abstraction _ -> noSubUnify
      _ -> unificationError
    expectWildcard ::
      Word64 -> Index "tyvar" -> Either TypeAppError (Map (Index "tyvar") (ValT Renamed))
    expectWildcard scopeId1 index1 = case actual of
      Abstraction (Unifiable _) -> noSubUnify
      Abstraction (Wildcard scopeId2 _ index2)
        | scopeId1 == scopeId2 && index1 /= index2 -> unificationError
        | otherwise -> noSubUnify
      _ -> unificationError
    expectThunk :: CompT Renamed -> Either TypeAppError (Map (Index "tyvar") (ValT Renamed))
    expectThunk (CompT _ (CompTBody t1)) = case actual of
      Abstraction (Rigid _ _) -> unificationError
      Abstraction _ -> noSubUnify
      ThunkT (CompT _ (CompTBody t2)) -> do
        -- Both thunks must have the same number of components.
        unless (comparing NonEmpty.length t1 t2 == EQ) unificationError
        -- Any failure in here is promoted to (expected, actual) by the
        -- catchError wrapping this whole call.
        foldM unifyPair Map.empty (NonEmpty.zip t1 t2)
      _ -> unificationError
    -- Unify one pair of corresponding thunk components, merging the result
    -- into the substitutions gathered so far.
    unifyPair ::
      Map (Index "tyvar") (ValT Renamed) ->
      (ValT Renamed, ValT Renamed) ->
      Either TypeAppError (Map (Index "tyvar") (ValT Renamed))
    unifyPair acc (l, r) = unify l r >>= reconcile acc
    expectFlatBuiltin :: BuiltinFlatT -> Either TypeAppError (Map (Index "tyvar") (ValT Renamed))
    expectFlatBuiltin t1 = case actual of
      Abstraction (Rigid _ _) -> unificationError
      Abstraction _ -> noSubUnify
      BuiltinFlat t2
        | t1 == t2 -> noSubUnify
        | otherwise -> unificationError
      _ -> unificationError
    -- Combine two substitution maps. Assignments present in only one map are
    -- kept as they are; an index assigned in both maps must be assigned the
    -- same type in each, otherwise unification fails.
    reconcile ::
      Map (Index "tyvar") (ValT Renamed) ->
      Map (Index "tyvar") (ValT Renamed) ->
      Either TypeAppError (Map (Index "tyvar") (ValT Renamed))
    reconcile =
      Merge.mergeA
        Merge.preserveMissing
        Merge.preserveMissing
        (Merge.zipWithAMatched agree)
    agree ::
      Index "tyvar" ->
      ValT Renamed ->
      ValT Renamed ->
      Either TypeAppError (ValT Renamed)
    agree _ old new
      | old == new = pure old
      | otherwise = unificationError