{-|
  Copyright  :  (C) 2021-2024 QBayLogic B.V.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>
-}

{-# LANGUAGE CPP #-}
{-# LANGUAGE OverloadedStrings #-}

module Clash.Core.EqSolver where

import Data.List.Extra (zipEqual)
import Data.Maybe (catMaybes, mapMaybe)

import Clash.Core.Name (Name(nameUniq))
import Clash.Core.Term
import Clash.Core.TyCon
import Clash.Core.Type
import Clash.Core.Var
import Clash.Core.VarEnv (VarSet, elemVarSet, emptyVarSet, mkVarSet)
import Clash.Unique (fromGhcUnique)
import Clash.Core.DataCon (dcUniq)
import GHC.Builtin.Names (unsafeReflDataConKey, eqPrimTyConKey, typeNatAddTyFamNameKey)

-- | Outcome of trying to solve a single type equation.
data TypeEqSolution
  = Solution (TyVar, Type)
  -- ^ A solution was found: the type variable equals the given type (for
  -- example a type-level natural, or a type constructor such as @Int@).
  | AbsurdSolution
  -- ^ The equation can never hold: either two types that are definitely
  -- different are asserted to be equal (e.g. @Int ~ Char@ or @3 ~ 5@), or
  -- solving it would require a negative natural.
  | NoSolution
  -- ^ Given type wasn't an equation of a supported shape, or it was
  -- unsolvable.
  deriving (Show, Eq)

-- | Keep only the @(variable, type)@ pairs carried by 'Solution's, dropping
-- every 'AbsurdSolution' and 'NoSolution'. The order of the input is kept.
catSolutions :: [TypeEqSolution] -> [(TyVar, Type)]
catSolutions sols = mapMaybe fromSolution sols
 where
  fromSolution sol = case sol of
    Solution s -> Just s
    _ -> Nothing

-- | Solve given equations and return all non-absurd solutions.
--
-- Every equation is handed to both 'solveAdd' and 'solveEq'. Only the
-- 'Solution's they produce are kept (in equation order, 'solveAdd' results
-- before 'solveEq' results); 'AbsurdSolution' and 'NoSolution' outcomes are
-- dropped.
solveNonAbsurds :: TyConMap -> VarSet -> [(Type, Type)] -> [(TyVar, Type)]
solveNonAbsurds tcm solveSet eqs = concatMap solutionsFor eqs
 where
  -- All solutions found for a single equation by the two solvers.
  solutionsFor eq =
    catSolutions (solveAdd solveSet eq : solveEq tcm solveSet eq)

-- | Solve simple equalities such as:
--
--   * a ~ 3
--   * 3 ~ a
--   * a ~ Int
--   * SomeType a b ~ SomeType 3 5
--   * SomeType 3 5 ~ SomeType a b
--   * SomeType a 5 ~ SomeType 3 b
--
-- Only type variables contained in the given 'VarSet' are solved for; the
-- other side must then be a type constructor or a type-level literal.
-- Equations between two different type constructors, or between two
-- different type-level literals, yield 'AbsurdSolution'. Equations that
-- cannot be decided here (e.g. because a type family application is stuck)
-- yield no results at all.
solveEq :: TyConMap -> VarSet -> (Type, Type) -> [TypeEqSolution]
solveEq tcm solveSet (coreView tcm -> left, coreView tcm -> right) =
  case (left, right) of
    (VarTy tyVar, _) | isSolvable tyVar right ->
      -- a ~ 3, a ~ Int
      [Solution (tyVar, right)]
    (_, VarTy tyVar) | isSolvable tyVar left ->
      -- 3 ~ a, Int ~ a
      [Solution (tyVar, left)]
    (ConstTy {}, ConstTy {}) ->
      -- Int /= Char
      [AbsurdSolution | left /= right]
    (LitTy {}, LitTy {}) ->
      -- 3 /= 5
      [AbsurdSolution | left /= right]
    _
      -- The call to 'coreView' at the start of 'solveEq' should have reduced
      -- all solvable type families. If we encounter one here that means the
      -- type family is stuck (and that we shouldn't compare it to anything!).
      | any (isTypeFamilyApplication tcm) [left, right] ->
          []
      | otherwise ->
          case (tyView left, tyView right) of
            (TyConApp leftNm leftTys, TyConApp rightNm rightTys)
              -- SomeType a b ~ SomeType 3 5 (or other way around)
              | leftNm == rightNm ->
                  concatMap (solveEq tcm solveSet) (zipEqual leftTys rightTys)
              | otherwise ->
                  [AbsurdSolution]
            _ ->
              []
 where
  -- We may only solve for variables in the solve set, and only when the
  -- other side is a concrete atom.
  isSolvable tyVar other = elemVarSet tyVar solveSet && isAtom other

  -- Type constructors (e.g. @Int@) and type-level literals (e.g. @3@). The
  -- original implementation only accepted 'ConstTy' here, which meant the
  -- documented @a ~ 3@ case (a 'LitTy') was never solved.
  isAtom ConstTy {} = True
  isAtom LitTy {} = True
  isAtom _ = False

-- | Solve equations of the shapes recognised by 'normalizeAdd' (such as
-- @5 ~ n + 2@) for the remaining operand @n@.
--
-- If @n@ is a type variable contained in the given 'VarSet', the result is
-- a 'Solution' binding it to the difference of the two literals, or an
-- 'AbsurdSolution' when any of the literals or their difference is
-- negative. Any other equation yields 'NoSolution'.
solveAdd
  :: VarSet
  -> (Type, Type)
  -> TypeEqSolution
solveAdd solveSet ab
  | Just (total, known, VarTy tyVar) <- normalizeAdd ab
  , elemVarSet tyVar solveSet
  = let diff = total - known
     in if all (>= 0) [total, known, diff]
          then Solution (tyVar, LitTy (NumTy diff))
          else AbsurdSolution
  | otherwise
  = NoSolution

-- | Given the left and right side of an equation, normalize it such that
-- equations of the following forms:
--
--     * 5     ~ n + 2
--     * 5     ~ 2 + n
--     * n + 2 ~ 5
--     * 2 + n ~ 5
--
-- are returned as @(5, 2, n)@: the numeric literal on one side, the numeric
-- literal operand of the type-level addition on the other side, and the
-- remaining operand of that addition. Any other equation yields 'Nothing'.
-- The remaining operand is not required to be a type variable ('solveAdd'
-- checks for that itself).
normalizeAdd
  :: (Type, Type)
  -> Maybe (Integer, Integer, Type)
normalizeAdd (a, b) =
  splitLit a b >>= \(total, sumTy) ->
    case tyView sumTy of
      TyConApp tc [l, r]
        | nameUniq tc == fromGhcUnique typeNatAddTyFamNameKey ->
            fmap (\(known, rest) -> (total, known, rest)) (splitLit l r)
      _ ->
        Nothing
 where
  -- Pull a numeric literal out of a pair of types, preferring the second
  -- one, and return it together with the other type.
  splitLit other (LitTy (NumTy k)) = Just (k, other)
  splitLit (LitTy (NumTy k)) other = Just (k, other)
  splitLit _ _ = Nothing

-- | Tests for nonsensical patterns due to types being "absurd". See
-- 'isAbsurdEq' for more info.
--
-- A pattern is absurd when any of the equations returned by 'patEqs' for it
-- is absurd, where only the pattern's existential type variables are
-- solved for. Patterns on the data constructor identified by
-- 'unsafeReflDataConKey' are never considered absurd.
isAbsurdPat
  :: TyConMap
  -> Pat
  -> Bool
isAbsurdPat tcm pat
  -- unsafeCoerce is not absurd in the way intended by /isAbsurdPat/
  | DataPat dc _ _ <- pat
  , dcUniq dc == fromGhcUnique unsafeReflDataConKey
  = False
  | otherwise
  = any (isAbsurdEq tcm existentials) (patEqs tcm pat)
 where
  existentials = case pat of
    DataPat _ tvs _ -> mkVarSet tvs
    _ -> emptyVarSet

-- | Determines if an "equation" obtained through 'patEqs' or 'typeEq' is
-- absurd. That is, it tests if two types that are definitely not equal are
-- asserted to be equal OR if the computation of the types yields some absurd
-- (intermediate) result such as -1.
--
-- Both sides are normalised with 'coreView' first; the equation is then
-- checked with 'solveAdd' and, if that does not report an absurdity, with
-- 'solveEq'.
isAbsurdEq
  :: TyConMap
  -> VarSet -- ^ existential tvs
  -> (Type, Type)
  -> Bool
isAbsurdEq tcm exts (left0, right0) =
  solveAdd exts eq == AbsurdSolution
    || AbsurdSolution `elem` solveEq tcm exts eq
 where
  eq = (coreView tcm left0, coreView tcm right0)

-- | Get constraint equations: for every term-level binder of the pattern
-- whose type is an equality recognised by 'typeEq', the left- and
-- right-hand side of that equality. Binders of other types are skipped.
patEqs
  :: TyConMap
  -> Pat
  -> [(Type, Type)]
patEqs tcm pat =
  catMaybes [typeEq tcm (varType i) | i <- termIds]
 where
  (_, termIds) = patIds pat

-- | If type is an equation, return LHS and RHS.
--
-- The type is first normalised with 'coreView'. It counts as an equation
-- when it is an application of the type constructor identified by
-- 'eqPrimTyConKey' to exactly four arguments; the last two are returned
-- (each normalised with 'coreView') and the first two (presumably the kinds
-- of the two sides) are ignored.
typeEq
  :: TyConMap
  -> Type
  -> Maybe (Type, Type)
typeEq tcm ty
  | TyConApp tc [_, _, lhs, rhs] <- tyView (coreView tcm ty)
  , nameUniq tc == fromGhcUnique eqPrimTyConKey
  = Just (coreView tcm lhs, coreView tcm rhs)
  | otherwise
  = Nothing