-- This file is part of the 'term-rewriting' library. It is licensed
-- under an MIT license. See the accompanying 'LICENSE' file for details.
--
-- Authors: Bertram Felgenhauer

{-# LANGUAGE FlexibleContexts #-}

module Data.Rewriting.Substitution.Unify (
    unify,
    unifyRef,
) where

import Data.Rewriting.Substitution.Type
import Data.Rewriting.Substitution.Ops (apply)
import qualified Data.Rewriting.Term.Ops as Term
import qualified Data.Rewriting.Term.Type as Term
import Data.Rewriting.Term.Type (Term (..))

import qualified Data.Map as M
import qualified Control.Monad.Union as UM
import qualified Data.Union as U
import Control.Monad.State
import Control.Monad.ST
import Control.Monad
import Control.Applicative
import Control.Arrow
import Data.Array.ST
import Data.Array
import Data.Maybe
import Data.Word

-- The setup is as follows:
--
-- We have a disjoint set forest, in which every node represents some
-- subterm of our unification problem. Each node is annotated by a
-- description of the term which may refer to other nodes. So we actually
-- have a graph, and an efficient implementation for joining nodes in
-- the graph, courtesy of the union-find data structure. We also maintain
-- a map of variables encountered so far to their allocated node.

-- The unification monad: a state monad holding the map from variables to
-- their allocated nodes, layered over a union-find monad whose node
-- annotations are of type 'Annot'.
type UnifyM f v a = StateT (M.Map v U.Node) (UM.UnionM (Annot f v)) a

-- Node annotations. Every node of the disjoint set forest carries exactly
-- one of the following descriptions:
data Annot f v
    = VarA v
      -- a variable (this node is the only one representing that variable)
    | FunA f [U.Node]
      -- an *expanded* function application; its arguments are nodes
    | FunP f [Term f v]
      -- a *pending* function application; its arguments are plain terms
      -- that have not yet been represented in the disjoint set forest

-- Extract function symbol and arity from a (non-variable) annotation.
-- Applying this to a 'VarA' annotation is a programming error; 'solve'
-- only calls it after ruling out variables on both sides of an equation.
funari :: Annot f v -> (f, Int)
funari (FunA f ns) = (f, length ns)
funari (FunP f ts) = (f, length ts)
funari (VarA _) =
    error "Data.Rewriting.Substitution.Unify.funari: variable annotation"

-- Solve a system of equations between terms that are represented by nodes.
-- Returns 'False' as soon as two function applications with different
-- symbols or arities would have to be equated. No occurs check is done
-- here; the resulting graph may be cyclic.
solve :: (Eq f, Ord v) => [(U.Node, U.Node)] -> UnifyM f v Bool
solve [] = return True
solve ((t0, u0) : xs) = do
    -- continue with the class representatives and their annotations
    (t, t') <- UM.lookup t0
    (u, u') <- UM.lookup u0
    -- if t == u then the nodes are already equivalent.
    if t == u then solve xs else case (t', u') of
        (VarA _, _) -> do
            -- assign term to variable
            void $ UM.merge (\_ _ -> (u', ())) t u
            solve xs
        (_, VarA _) -> do
            -- assign term to variable
            void $ UM.merge (\_ _ -> (t', ())) t u
            solve xs
        _ | funari t' == funari u' -> do
            -- matching function applications: expand ...
            ts <- args <$> expand t t'
            us <- args <$> expand u u'
            -- keep the (now expanded) annotation of the first node
            void $ UM.merge (\ann _ -> (ann, ())) t u
            -- ... and equate the argument lists.
            solve (zip ts us ++ xs)
        _ ->
            -- mismatch, fail.
            return False
  where
    -- 'expand' turns any non-variable annotation into an expanded one.
    -- Extracting the arguments with a pure function (rather than a
    -- failable pattern bind) avoids needing 'MonadFail' for 'UnionM'.
    args (FunA _ ns) = ns
    args _ =
        error "Data.Rewriting.Substitution.Unify.solve: expected expanded node"

-- Expand a node: If the node is currently a pending function application,
-- turn it into an expanded one by allocating nodes for its arguments and
-- re-annotating the node. Any other annotation is returned unchanged.
-- The second argument must equal the current annotation of the node; the
-- result is the node's (possibly updated) annotation.
expand :: (Ord v) => U.Node -> Annot f v -> UnifyM f v (Annot f v)
expand n (FunP f ts) = do
    ann <- FunA f <$> mapM mkNode ts
    UM.annotate n ann
    return ann
expand _ ann = return ann

-- Create a new node representing a given term.
-- Variable nodes are shared: the first node allocated for a variable is
-- recorded in the state map and returned again on later requests.
-- Function applications will be pending initially.
mkNode :: (Ord v) => Term f v -> UnifyM f v U.Node
mkNode (Var v) = gets (M.lookup v) >>= maybe fresh return
  where
    -- allocate a variable node and remember it for later sharing
    fresh = do
        n <- UM.new (VarA v)
        modify (M.insert v n)
        return n
mkNode (Fun f ts) = UM.new (FunP f ts)

-- | Unify two terms. If unification succeeds, return a most general unifier
-- of the given terms. We have the following property:
--
-- > unify t u == Just s   ==>   apply s t == apply s u
--
-- /O(n log(n))/, where /n/ is the apparent size of the arguments. Note that
-- the apparent size of the result may be exponential due to shared subterms.
unify :: (Eq f, Ord v) => Term f v -> Term f v -> Maybe (Subst f v)
unify t u = do
    let -- solve unification problem; 'root' is the node built for 't'
        act = do
            t' <- mkNode t
            u' <- mkNode u
            success <- solve [(t', u')]
            return (t', success)
        (union, ((root, success), vmap)) = UM.run' $ runStateT act M.empty
        -- find the successors in the resulting graph: the argument nodes of
        -- expanded applications, and the nodes of the variables occurring
        -- in pending ones
        succs n = case snd (U.lookup union n) of
            VarA _ -> []
            FunA _ ns -> ns
            FunP _ ts -> mapMaybe (`M.lookup` vmap) (concatMap Term.vars ts)
    -- fail on a symbol clash (reported by 'solve') or a cyclic solution
    guard $ success && acyclic (U.size union) succs root
    let -- build resulting substitution
        subst = fromMap $ fmap lookupNode vmap
        -- 'terms' maps representatives to their reconstructed terms
        terms = fmap mkTerm (UM.label union)
        -- look up a node in 'terms' via its representative
        lookupNode = (terms !) . U.fromNode . fst . U.lookup union
        -- translate annotation back to term; pending applications still
        -- hold original terms, so the substitution is applied to them
        mkTerm (VarA v) = Var v
        mkTerm (FunA f ns) = Fun f (fmap lookupNode ns)
        mkTerm (FunP f ts) = subst `apply` Fun f ts
    return subst

-- Check whether the subgraph reachable from the given root is acyclic.
-- This is done by a depth first search, where nodes are initially colored
-- white (0), then grey (1) while their children are being visited and
-- finally black (2) after the children have been processed completely.
--
-- The subgraph is cyclic iff we encounter a grey node at some point.
--
-- The first argument is the number of nodes; node indices (as given by
-- 'U.fromNode') must lie in the range @[0, size - 1]@.
--
-- O(n) plus the cost of 'succs'; 'succs' is called at most once per node.
acyclic :: Int -> (U.Node -> [U.Node]) -> U.Node -> Bool
acyclic size succs root = runST $ do
    color <- newColorArray (0, size - 1) 0
    let dfs n = do
            let i = U.fromNode n
            c <- readArray color i
            case c of
                0 -> do
                    -- white: mark grey, visit the children (stopping at the
                    -- first cycle found), then mark black
                    writeArray color i 1
                    let finish = writeArray color i 2 >> return True
                    foldr andM finish (map dfs (succs n))
                1 -> return False -- grey: back edge, hence a cycle
                _ -> return True  -- black: already fully explored
    dfs root
  where
    -- pins the color array to an unboxed 'Word8' array living in 'ST'
    newColorArray :: (Int, Int) -> Word8 -> ST s (STUArray s Int Word8)
    newColorArray = newArray

-- Monadic, logical and with short-cut evaluation: the second action is run
-- only if the first one yields 'True'; otherwise the result is 'False'.
andM :: Monad m => m Bool -> m Bool -> m Bool
andM a b = a >>= \a' -> if a' then b else return False

------------------------------------------------------------------------------
-- Reference implementation

-- | Unify two terms. This is a simple implementation for testing purposes,
-- and may be removed in future versions of this library.
unifyRef :: (Eq f, Ord v) => Term f v -> Term f v -> Maybe (Subst f v)
unifyRef t0 u0 = fromMap <$> go [(t0, u0)] M.empty where
    -- process the pending equations, accumulating the bindings found so far
    go [] subst = Just subst
    go ((t, u) : xs) subst = case (t, u) of
        (Var v, t') -> add v t' xs subst
        (t', Var v) -> add v t' xs subst
        (Fun f ts, Fun f' ts')
            | f /= f' || length ts /= length ts' -> Nothing
            | otherwise -> go (zip ts ts' ++ xs) subst
    -- bind variable v to term t: trivial equations are dropped, and the
    -- occurs check rejects bindings of v to a term containing v; otherwise
    -- the new binding is applied to the remaining equations and to the
    -- existing bindings
    add v t xs subst
        | Var v == t = go xs subst
        | occurs v t = Nothing
        | otherwise =
            let app = apply (fromMap (M.singleton v t))
            in  go (fmap (app *** app) xs) (M.insert v t (fmap app subst))
    occurs v t = v `elem` Term.vars t