{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE UndecidableInstances #-}

-- |
-- SPDX-License-Identifier: BSD-3-Clause
--
-- Description: Fast yet purely functional implementation of
-- unification, using a map as a lazy substitution, i.e. a
-- manually-maintained "functional shared memory".
--
-- See Dijkstra, Middelkoop, & Swierstra, "Efficient Functional
-- Unification and Substitution", Utrecht University tech report
-- UU-CS-2008-027 (section 5) for the basic idea, and Peyton Jones et
-- al., "Practical type inference for arbitrary-rank types"
-- (pp. 74--75) for a correct implementation of unification via
-- references.
module Swarm.Effect.Unify.Fast where

import Control.Algebra
import Control.Applicative (Alternative)
import Control.Carrier.Accum.Strict (AccumC, runAccum)
import Control.Carrier.Reader (ReaderC, runReader)
import Control.Carrier.State.Strict (StateC, evalState)
import Control.Carrier.Throw.Either (ThrowC, runThrow)
import Control.Category ((>>>))
import Control.Effect.Accum (Accum, add)
import Control.Effect.Reader (Reader, ask, local)
import Control.Effect.State (State, get, gets, modify)
import Control.Effect.Throw (Throw, throwError)
import Control.Monad (zipWithM)
import Control.Monad.Free
import Control.Monad.Trans (MonadIO)
import Data.Function (on)
import Data.Functor.Identity
import Data.Map qualified as M
import Data.Map.Merge.Lazy qualified as M
import Data.Monoid (First (..))
import Data.Set (Set)
import Data.Set qualified as S
import Swarm.Effect.Unify
import Swarm.Effect.Unify.Common
import Swarm.Language.Types hiding (Type)
import Swarm.Util.Effect (withThrow)
import Prelude hiding (lookup)

------------------------------------------------------------
-- Substitutions

-- | Substitution composition.  @s1 \@\@ s2@ acts like applying @s2@
--   first and @s1@ afterwards, so composing substitutions agrees with
--   ordinary function composition when substitutions are read as
--   functions on terms.  The operation is associative and 'idS' is
--   its unit.
--
--   Substitutions are kept lazily here: the values stored in @s2@ are
--   /not/ rewritten with @s1@, so the mapping may still mention names
--   that are themselves keys.  That keeps composition cheap (a single
--   map union) and pushes the extra work into substitution
--   application instead.  When both arguments bind the same name, the
--   binding from @s2@ is the one retained, because @s2@ is applied
--   first.
(@@) :: (Ord n, Substitutes n a a) => Subst n a -> Subst n a -> Subst n a
(@@) (Subst s1) (Subst s2) = Subst (M.union s2 s1)

-- | Things a substitution can be applied to.  An instance
--   @Substitutes n b a@ explains how to push a substitution of type
--   @Subst n b@ through a value of type @a@: every free name of type
--   @n@ occurring in the @a@ is replaced by its associated @b@,
--   producing a new @a@.
--
--   Application performs a lazy occurs check as it goes, which is why
--   it runs in a context able to throw a 'UnificationError'.
class Substitutes n b a where
  subst ::
    Has (Throw UnificationError) sig m =>
    Subst n b ->
    a ->
    m a

-- | Substitution of unification variables inside a 'UType', i.e. a
--   term of the free monad over 'TypeF' whose leaves ('Pure') are
--   'IntVar's.
--
--   The substitution is lazy (see '@@'), so the type a variable maps
--   to may itself mention mapped variables; substitution therefore
--   keeps following mappings until it reaches unmapped variables.
--   Along each path it remembers which variables it is currently
--   expanding.  If one of them is reached again, the substitution
--   describes an infinite type, and an 'Infinite' error is thrown.
instance Substitutes IntVar UType UType where
  subst s u = case runSubst (go u) of
    -- If the substitution completed without encountering a repeated
    -- variable, just return the result.
    (First Nothing, u') -> return u'
    -- Otherwise, throw an error, but re-run substitution starting at
    -- the repeated variable to generate an expanded cyclic equality
    -- constraint of the form x = ... x ... .
    (First (Just x), _) ->
      throwError $ Infinite x (snd (runSubst (go (Pure x))))
   where
    -- Run a traversal purely, starting with no variables on the
    -- current path and no repeated variable recorded yet.  'First'
    -- keeps only the earliest repeated variable that was noted.
    runSubst :: ReaderC (Set IntVar) (AccumC (First IntVar) Identity) a -> (First IntVar, a)
    runSubst = run . runAccum (First Nothing) . runReader S.empty

    -- Substitute through a term, tracking (via Reader) the variables
    -- being expanded along the current path.  A variable met again on
    -- the same path is left unchanged, and is recorded (via Accum) as
    -- the witness of a cycle.
    go ::
      (Has (Reader (Set IntVar)) sig m, Has (Accum (First IntVar)) sig m) =>
      UType ->
      m UType
    go (Pure x) = case lookup x s of
      -- Unmapped variable: nothing further to substitute.
      Nothing -> pure (Pure x)
      Just t -> do
        seen <- ask
        if x `S.member` seen
          then add (First (Just x)) >> pure (Pure x)
          else local (S.insert x) (go t)
    go (Free t) = Free <$> goF t

    -- Push 'go' through one layer of type structure.
    goF ::
      (Has (Reader (Set IntVar)) sig m, Has (Accum (First IntVar)) sig m) =>
      TypeF UType ->
      m (TypeF UType)
    goF (TyConF c ts) = TyConF c <$> mapM go ts
    goF t@(TyVarF {}) = pure t
    goF (TyRcdF fields) = TyRcdF <$> mapM go fields
    goF (TyRecF x t) = TyRecF x <$> go t
    goF t@(TyRecVarF _) = pure t

------------------------------------------------------------
-- Carrier type

-- | Carrier type for unification.  From the outermost layer inwards
--   it provides:
--
--   * a set of pairs of types (NOTE(review): presumably the pairs
--     already under unification, so that unifying recursive types
--     terminates — TODO confirm against 'unify');
--   * the current substitution for unification variables;
--   * a counter for generating fresh unification variables; and
--   * the ability to throw unification errors.
newtype UnificationC m a = UnificationC
  { unUnificationC ::
      StateC
        (Set (UType, UType))
        ( StateC
            (Subst IntVar UType)
            ( StateC
                FreshVarCounter
                (ThrowC UnificationError m)
            )
        )
        a
  }
  deriving newtype (Functor, Applicative, Alternative, Monad, MonadIO)

-- | Supply of fresh unification variables.  The wrapped 'Int' is the
--   index given to the next freshly created 'IntVar'; after each use
--   it is advanced with 'succ'.
newtype FreshVarCounter = FreshVarCounter {getFreshVarCounter :: Int}
  deriving newtype (Eq, Ord, Enum)

-- | Handle a 'Unification' effect by running the 'UnificationC'
--   carrier.
--
--   Each layer of carrier state begins in its initial configuration:
--   an empty set of type pairs, the identity substitution 'idS', and a
--   fresh-variable counter of @0@.  All of this state is discarded at
--   the end.  A thrown 'UnificationError' is returned as a 'Left';
--   otherwise the computation's result is returned as a 'Right'.
--
--   An ambient @Reader 'TDCtx'@ effect is also demanded, so that
--   unification picks up whatever type aliases happen to be in scope.
runUnification ::
  (Algebra sig m, Has (Reader TDCtx) sig m) =>
  UnificationC m a ->
  m (Either UnificationError a)
runUnification (UnificationC carrier) = discharge carrier
 where
  -- Peel off the carrier's layers from the outermost inwards.
  discharge =
    evalState S.empty
      >>> evalState idS
      >>> evalState (FreshVarCounter 0)
      >>> runThrow

------------------------------------------------------------
-- Unification

-- The idea here (using an explicit substitution as a sort of
-- "functional shared memory", instead of directly using IORefs), is
-- based on Dijkstra et al. Unfortunately, their implementation of
-- unification is subtly wrong; fortunately, a single integration test
-- in the Swarm test suite failed, leading to discovering the bug.
-- The basic issue is that when unifying an equation between two
-- variables @x = y@, we must look up *both* to see whether they are
-- already mapped by the substitution (and if so, replace them by
-- their referent and keep recursing).  Dijkstra et al. only look up
-- @x@ and simply map @x |-> y@ if x is not in the substitution, but
-- this can lead to cycles where e.g. x is mapped to y, and later we
-- unify @y = x@ resulting in both @x |-> y@ and @y |-> x@ in the
-- substitution, which at best leads to a spurious infinite type
-- error, and at worst leads to infinite recursion in the unify function.
--
-- Peyton Jones et al. show how to do it correctly: when unifying x = y and
-- x is not mapped in the substitution, we must also look up y.

-- | Implementation of the 'Unification' effect in terms of the
--   'UnificationC' carrier.
instance
  (Algebra sig m, Has (Reader TDCtx) sig m) =>
  Algebra (Unification :+: sig) (UnificationC m)
  where
  alg :: forall (ctx :: * -> *) (n :: * -> *) a.
Functor ctx =>
Handler ctx n (UnificationC m)
-> (:+:) Unification sig n a -> ctx () -> UnificationC m (ctx a)
alg Handler ctx n (UnificationC m)
hdl (:+:) Unification sig n a
sig ctx ()
ctx = StateC
  (Set (UType, UType))
  (StateC
     (Subst IntVar UType)
     (StateC FreshVarCounter (ThrowC UnificationError m)))
  (ctx a)
-> UnificationC m (ctx a)
forall (m :: * -> *) a.
StateC
  (Set (UType, UType))
  (StateC
     (Subst IntVar UType)
     (StateC FreshVarCounter (ThrowC UnificationError m)))
  a
-> UnificationC m a
UnificationC (StateC
   (Set (UType, UType))
   (StateC
      (Subst IntVar UType)
      (StateC FreshVarCounter (ThrowC UnificationError m)))
   (ctx a)
 -> UnificationC m (ctx a))
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     (ctx a)
-> UnificationC m (ctx a)
forall a b. (a -> b) -> a -> b
$ case (:+:) Unification sig n a
sig of
    L (Unify UType
t1 UType
t2) -> (a -> ctx () -> ctx a
forall a b. a -> ctx b -> ctx a
forall (f :: * -> *) a b. Functor f => a -> f b -> f a
<$ ctx ()
ctx) (a -> ctx a)
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     a
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     (ctx a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> ThrowC
  UnificationError
  (StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m))))
  UType
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     (Either UnificationError UType)
forall e (m :: * -> *) a. ThrowC e m a -> m (Either e a)
runThrow (UType
-> UType
-> ThrowC
     UnificationError
     (StateC
        (Set (UType, UType))
        (StateC
           (Subst IntVar UType)
           (StateC FreshVarCounter (ThrowC UnificationError m))))
     UType
forall (sig :: (* -> *) -> * -> *) (m :: * -> *).
(Has (Throw UnificationError) sig m, Has (Reader TDCtx) sig m,
 Has (State (Subst IntVar UType)) sig m,
 Has (State (Set (UType, UType))) sig m) =>
UType -> UType -> m UType
unify UType
t1 UType
t2)
    L (ApplyBindings UType
t) -> do
      Subst IntVar UType
s <- forall s (sig :: (* -> *) -> * -> *) (m :: * -> *).
Has (State s) sig m =>
m s
get @(Subst IntVar UType)
      (a -> ctx () -> ctx a
forall a b. a -> ctx b -> ctx a
forall (f :: * -> *) a b. Functor f => a -> f b -> f a
<$ ctx ()
ctx) (a -> ctx a)
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     a
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     (ctx a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Subst IntVar UType
-> a
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     a
forall n b a (sig :: (* -> *) -> * -> *) (m :: * -> *).
(Substitutes n b a, Has (Throw UnificationError) sig m) =>
Subst n b -> a -> m a
forall (sig :: (* -> *) -> * -> *) (m :: * -> *).
Has (Throw UnificationError) sig m =>
Subst IntVar UType -> a -> m a
subst Subst IntVar UType
s a
UType
t
    L Unification n a
FreshIntVar -> do
      IntVar
v <- Int -> IntVar
IntVar (Int -> IntVar)
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     Int
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     IntVar
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> (FreshVarCounter -> Int)
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     Int
forall s (sig :: (* -> *) -> * -> *) (m :: * -> *) a.
Has (State s) sig m =>
(s -> a) -> m a
gets FreshVarCounter -> Int
getFreshVarCounter
      forall s (sig :: (* -> *) -> * -> *) (m :: * -> *).
Has (State s) sig m =>
(s -> s) -> m ()
modify @FreshVarCounter FreshVarCounter -> FreshVarCounter
forall a. Enum a => a -> a
succ
      ctx a
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     (ctx a)
forall a.
a
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     a
forall (m :: * -> *) a. Monad m => a -> m a
return (ctx a
 -> StateC
      (Set (UType, UType))
      (StateC
         (Subst IntVar UType)
         (StateC FreshVarCounter (ThrowC UnificationError m)))
      (ctx a))
-> ctx a
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     (ctx a)
forall a b. (a -> b) -> a -> b
$ a
IntVar
v a -> ctx () -> ctx a
forall a b. a -> ctx b -> ctx a
forall (f :: * -> *) a b. Functor f => a -> f b -> f a
<$ ctx ()
ctx
    L (FreeUVars UType
t) -> do
      Subst IntVar UType
s <- forall s (sig :: (* -> *) -> * -> *) (m :: * -> *).
Has (State s) sig m =>
m s
get @(Subst IntVar UType)
      (a -> ctx () -> ctx a
forall a b. a -> ctx b -> ctx a
forall (f :: * -> *) a b. Functor f => a -> f b -> f a
<$ ctx ()
ctx) (a -> ctx a) -> (UType -> a) -> UType -> ctx a
forall b c a. (b -> c) -> (a -> b) -> a -> c
. UType -> a
UType -> Set IntVar
fuvs (UType -> ctx a)
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     UType
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     (ctx a)
forall (f :: * -> *) a b. Functor f => (a -> b) -> f a -> f b
<$> Subst IntVar UType
-> UType
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     UType
forall n b a (sig :: (* -> *) -> * -> *) (m :: * -> *).
(Substitutes n b a, Has (Throw UnificationError) sig m) =>
Subst n b -> a -> m a
forall (sig :: (* -> *) -> * -> *) (m :: * -> *).
Has (Throw UnificationError) sig m =>
Subst IntVar UType -> UType -> m UType
subst Subst IntVar UType
s UType
t
    R sig n a
other -> Handler
  ctx
  n
  (StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m))))
-> (:+:)
     (State (Set (UType, UType)))
     (State (Subst IntVar UType)
      :+: (State FreshVarCounter :+: (Throw UnificationError :+: sig)))
     n
     a
-> ctx ()
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     (ctx a)
forall (ctx :: * -> *) (n :: * -> *) a.
Functor ctx =>
Handler
  ctx
  n
  (StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m))))
-> (:+:)
     (State (Set (UType, UType)))
     (State (Subst IntVar UType)
      :+: (State FreshVarCounter :+: (Throw UnificationError :+: sig)))
     n
     a
-> ctx ()
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     (ctx a)
forall (sig :: (* -> *) -> * -> *) (m :: * -> *) (ctx :: * -> *)
       (n :: * -> *) a.
(Algebra sig m, Functor ctx) =>
Handler ctx n m -> sig n a -> ctx () -> m (ctx a)
alg (UnificationC m (ctx x)
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     (ctx x)
forall (m :: * -> *) a.
UnificationC m a
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     a
unUnificationC (UnificationC m (ctx x)
 -> StateC
      (Set (UType, UType))
      (StateC
         (Subst IntVar UType)
         (StateC FreshVarCounter (ThrowC UnificationError m)))
      (ctx x))
-> (ctx (n x) -> UnificationC m (ctx x))
-> ctx (n x)
-> StateC
     (Set (UType, UType))
     (StateC
        (Subst IntVar UType)
        (StateC FreshVarCounter (ThrowC UnificationError m)))
     (ctx x)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ctx (n x) -> UnificationC m (ctx x)
Handler ctx n (UnificationC m)
hdl) ((:+:)
  (State (Subst IntVar UType))
  (State FreshVarCounter :+: (Throw UnificationError :+: sig))
  n
  a
-> (:+:)
     (State (Set (UType, UType)))
     (State (Subst IntVar UType)
      :+: (State FreshVarCounter :+: (Throw UnificationError :+: sig)))
     n
     a
forall (f :: (* -> *) -> * -> *) (g :: (* -> *) -> * -> *)
       (m :: * -> *) k.
g m k -> (:+:) f g m k
R ((:+:) (State FreshVarCounter) (Throw UnificationError :+: sig) n a
-> (:+:)
     (State (Subst IntVar UType))
     (State FreshVarCounter :+: (Throw UnificationError :+: sig))
     n
     a
forall (f :: (* -> *) -> * -> *) (g :: (* -> *) -> * -> *)
       (m :: * -> *) k.
g m k -> (:+:) f g m k
R ((:+:) (Throw UnificationError) sig n a
-> (:+:)
     (State FreshVarCounter) (Throw UnificationError :+: sig) n a
forall (f :: (* -> *) -> * -> *) (g :: (* -> *) -> * -> *)
       (m :: * -> *) k.
g m k -> (:+:) f g m k
R (sig n a -> (:+:) (Throw UnificationError) sig n a
forall (f :: (* -> *) -> * -> *) (g :: (* -> *) -> * -> *)
       (m :: * -> *) k.
g m k -> (:+:) f g m k
R sig n a
other)))) ctx ()
ctx

-- | Unify two types, returning a unified type equal to both.  Note
--   that for efficiency we /don't/ do an occurs check here, but
--   instead lazily during substitution.
--
--   We keep track of a set of pairs of types we have seen; if we ever
--   see a pair a second time we simply assume they are equal without
--   recursing further.  This constitutes a finite (coinductive)
--   algorithm for doing unification on recursive types.
--
--   For example, suppose we wanted to unify @rec s. Unit + Unit + s@
--   and @rec t. Unit + t@.  These types are actually equal since
--   their infinite unfoldings are both @Unit + Unit + Unit + ...@.  In
--   practice we would proceed through the following recursive calls
--   to unify:
--
--   @
--     (rec s. Unit + Unit + s)                 =:= (rec t. Unit + t)
--         { unfold the LHS }
--     (Unit + Unit + (rec s. Unit + Unit + s)) =:= (rec t. Unit + t)
--         { unfold the RHS }
--     (Unit + Unit + (rec s. Unit + Unit + s)) =:= (Unit + (rec t. Unit + t))
--         { unifyF matches the + and makes two calls to unify }
--     Unit =:= Unit   { trivial }
--     (Unit + (rec s. Unit + Unit + s))        =:= (rec t. Unit + t)
--         { unfold the RHS }
--     (Unit + (rec s. Unit + Unit + s))        =:= (Unit + (rec t. Unit + t))
--         { unifyF on + }
--     (rec s. Unit + Unit + s)                 =:= (rec t. Unit + t)
--         { back to the starting pair, return success }
--   @
unify ::
  ( Has (Throw UnificationError) sig m
  , Has (Reader TDCtx) sig m
  , Has (State (Subst IntVar UType)) sig m
  , Has (State (Set (UType, UType))) sig m
  ) =>
  UType ->
  UType ->
  m UType
unify ty1 ty2 = do
  seen <- get @(Set (UType, UType))
  -- If this exact pair has already been encountered, coinductively
  -- assume the two types are equal instead of recursing again; this is
  -- what makes unification of recursive types terminate.
  if S.member (ty1, ty2) seen
    then return ty1
    else do
      -- Record the pair *before* recursing, so that any path leading
      -- back to it hits the check above.
      modify (S.insert (ty1, ty2))
      case (ty1, ty2) of
        -- The same unification variable on both sides.
        (Pure x, Pure y) | x == y -> pure (Pure x)
        -- A unification variable on the left: if it is already bound in
        -- the substitution, unify its binding instead; otherwise bind it.
        (Pure x, y) -> do
          mxv <- lookupS x
          case mxv of
            Nothing -> unifyVar x y
            Just xv -> unify xv y
        -- A unification variable on the right: swap the arguments so the
        -- case above handles it.
        (x, Pure y) -> unify (Pure y) x
        -- Unfold a recursive type by one level on whichever side has it.
        (UTyRec x ty, _) -> unify (unfoldRec x ty) ty2
        (_, UTyRec x ty) -> unify ty1 (unfoldRec x ty)
        -- Expand a user-defined type with 'expandTydef'; an expansion
        -- failure is reported as 'UndefinedUserType'.
        (UTyUser x1 tys, _) -> do
          ty1' <-
            withThrow
              (\(UnexpandedUserType _) -> UndefinedUserType (UTyUser x1 tys))
              (expandTydef x1 tys)
          unify ty1' ty2
        -- A user-defined type on the right: swap so the case above
        -- expands it.
        (_, UTyUser {}) -> unify ty2 ty1
        -- Two remaining structural types: compare them one layer at a
        -- time.
        (Free t1, Free t2) -> Free <$> unifyF t1 t2

-- | Unify a unification variable which /is not/ bound by the current
--   substitution with another term.  If the other term is also a
--   variable, we must look it up as well to see if it is bound.
unifyVar ::
  ( Has (Throw UnificationError) sig m
  , Has (Reader TDCtx) sig m
  , Has (State (Subst IntVar UType)) sig m
  , Has (State (Set (UType, UType))) sig m
  ) =>
  IntVar ->
  UType ->
  m UType
unifyVar x (Pure y)
  -- Unifying a variable with itself trivially succeeds. Without this
  -- guard we would insert the cyclic binding (x |-> x).
  | x == y = pure (Pure y)
  | otherwise = do
      myv <- lookupS y
      case myv of
        -- x = y but the variable y is not bound: just add (x |-> y) to
        -- the current Subst.
        --
        -- Note, as an optimization we just call e.g. insert x (Pure y)
        -- instead of building a singleton Subst with @(|->)@ and then
        -- composing, since composition doesn't need to apply the newly
        -- created binding to all the other values bound in the Subst.
        Nothing ->
          modify @(Subst IntVar UType) (insert x (Pure y)) >> pure (Pure y)
        -- x = y and y is bound to v: recurse on x = v.
        Just yv -> unify (Pure x) yv
-- x = t for a non-variable t: just add (x |-> t) to the Subst.
unifyVar x t = modify @(Subst IntVar UType) (insert x t) >> pure t

-- | Perform unification on two non-variable terms: check that they
--   have the same top-level constructor and recurse on their
--   contents.
unifyF ::
  ( Has (Throw UnificationError) sig m
  , Has (Reader TDCtx) sig m
  , Has (State (Subst IntVar UType)) sig m
  , Has (State (Set (UType, UType))) sig m
  ) =>
  TypeF UType ->
  TypeF UType ->
  m (TypeF UType)
unifyF t1 t2 = case (t1, t2) of
  -- Recursive types are always expanded in 'unify', so these first four
  -- cases should never happen.
  (TyRecF {}, _) -> throwError $ UnexpandedRecTy t1
  (_, TyRecF {}) -> throwError $ UnexpandedRecTy t2
  (TyRecVarF {}, _) -> throwError $ UnexpandedRecTy t1
  (_, TyRecVarF {}) -> throwError $ UnexpandedRecTy t2
  -- Type constructors must be identical and applied to the same number
  -- of arguments. The arguments are then unified pairwise. The arity
  -- check matters because 'zipWithM' would otherwise silently drop the
  -- extra arguments of the longer list.
  (TyConF c1 ts1, TyConF c2 ts2)
    | c1 == c2 && length ts1 == length ts2 ->
        TyConF c1 <$> zipWithM unify ts1 ts2
    | otherwise -> unifyErr
  (TyConF {}, _) -> unifyErr
  -- Note that *type variables* are not the same as *unification variables*.
  -- Type variables must match exactly.
  (TyVarF _ v1, TyVarF _ v2)
    | v1 == v2 -> pure t1
    | otherwise -> unifyErr
  (TyVarF {}, _) -> unifyErr
  -- Records must have exactly the same set of fields (map keys). The
  -- corresponding field types are then unified pairwise, in ascending
  -- key order.
  (TyRcdF m1, TyRcdF m2)
    | ((==) `on` M.keysSet) m1 m2 ->
        TyRcdF
          <$> M.mergeA
            M.dropMissing
            M.dropMissing
            (M.zipWithAMatched (const unify))
            m1
            m2
    | otherwise -> unifyErr
  (TyRcdF {}, _) -> unifyErr
 where
  unifyErr = throwError $ UnifyErr t1 t2