module Language.Futhark.TypeChecker.TypesTests (tests) where

import Data.Bifunctor
import Data.List (isInfixOf)
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.FreshNames
import Futhark.Util.Pretty (docText, prettyTextOneLine)
import Language.Futhark
import Language.Futhark.Semantic
import Language.Futhark.SyntaxTests ()
import Language.Futhark.TypeChecker (initialEnv)
import Language.Futhark.TypeChecker.Monad
import Language.Futhark.TypeChecker.Names (resolveTypeExp)
import Language.Futhark.TypeChecker.Terms
import Language.Futhark.TypeChecker.Types
import Test.Tasty
import Test.Tasty.HUnit

-- | Construct a test case that elaborates the type expression @te@
-- (name resolution with 'resolveTypeExp' followed by 'checkTypeExp')
-- in a small hand-made environment and compares the outcome with
-- @expected@:
--
-- * @Right (svars, t)@ — elaboration must succeed and produce exactly
--   the size variables @svars@ and the return type @t@.
--
-- * @Left msg@ — elaboration must fail with a type error whose rendered
--   text contains @msg@ as a substring.
--
-- The test is named after the pretty-printed type expression.
evalTest :: TypeExp (ExpBase NoInfo Name) Name -> Either String ([VName], ResRetType) -> TestTree
evalTest te expected =
  testCase (prettyString te) $
    case (fmap (extract . fst) (run (checkTypeExp checkSizeExp =<< resolveTypeExp te)), expected) of
      (Left got_e, Left expected_e) ->
        let got_e_s = renderTypeErr got_e
         in -- Report both the expected substring and the actual error,
            -- so a mismatch can be diagnosed from the test output alone.
            (expected_e `isInfixOf` got_e_s)
              @? ("Expected error containing " <> show expected_e <> ", got: " <> got_e_s)
      (Left got_e, Right _) ->
        assertFailure $ "Failed: " <> renderTypeErr got_e
      (Right actual_t, Right expected_t) ->
        actual_t @?= expected_t
      (Right actual_t, Left _) ->
        assertFailure $ "Expected error, got: " <> show actual_t
  where
    -- Render a type error as plain text, used both for substring
    -- matching and for failure messages.
    renderTypeErr = T.unpack . docText . prettyTypeError
    -- Keep only the size variables and the elaborated return type; the
    -- elaborated type expression and its liftedness are not compared.
    extract (_, svars, t, _) = (svars, t)
    -- Run a type checker computation in 'env' with an empty import
    -- table, discarding warnings.  The fresh-name source is seeded with
    -- 100; the expected results (d_100, n_100, ...) depend on this.
    run = snd . runTypeM env mempty (mkInitialImport "") (newNameSource 100)
    -- We hack up an environment with some predefined type
    -- abbreviations for testing.  This is all pretty sensitive to the
    -- specific unique names, so we have to be careful!
    env =
      initialEnv
        { envTypeTable =
            M.fromList
              [ ( "square_1000",
                  TypeAbbr
                    Unlifted
                    [TypeParamDim "n_1001" mempty]
                    "[n_1001][n_1001]i32"
                ),
                ( "fun_1100",
                  TypeAbbr
                    Lifted
                    [ TypeParamType Lifted "a_1101" mempty,
                      TypeParamType Lifted "b_1102" mempty
                    ]
                    "a_1101 -> b_1102"
                ),
                ( "pair_1200",
                  TypeAbbr
                    SizeLifted
                    []
                    "?[n_1201][m_1202].([n_1201]i64, [m_1202]i64)"
                )
              ]
              <> envTypeTable initialEnv,
          envNameMap =
            M.fromList
              [ ((Type, "square"), "square_1000"),
                ((Type, "fun"), "fun_1100"),
                ((Type, "pair"), "pair_1200")
              ]
              <> envNameMap initialEnv
        }

-- | Elaboration tests for type expressions, driven by 'evalTest'.
-- Positive cases must elaborate to exactly the listed size variables
-- and return type; negative cases must be rejected with an error whose
-- text contains the listed message.
evalTests :: TestTree
evalTests =
  testGroup
    "Type expression elaboration"
    [ testGroup
        "Positive tests"
        [evalTest texp (Right expectedRes) | (texp, expectedRes) <- pos],
      testGroup
        "Negative tests"
        [evalTest texp (Left expectedErr) | (texp, expectedErr) <- neg]
    ]
  where
    -- Each entry: (type expression, (size variables, elaborated type)).
    pos =
      [ ("[]i32", ([], "?[d_100].[d_100]i32")),
        ("[][]i32", ([], "?[d_100][d_101].[d_100][d_101]i32")),
        ("bool -> []i32", ([], "bool -> ?[d_100].[d_100]i32")),
        ( "bool -> []f32 -> []i32",
          (["d_100"], "bool -> [d_100]f32 -> ?[d_101].[d_101]i32")
        ),
        ("([]i32,[]i32)", ([], "?[d_100][d_101].([d_100]i32, [d_101]i32)")),
        ("{a:[]i32,b:[]i32}", ([], "?[d_100][d_101].{a:[d_100]i32, b:[d_101]i32}")),
        ("?[n].[n][n]bool", ([], "?[n_100].[n_100][n_100]bool")),
        ( "([]i32 -> []i32) -> bool -> []i32",
          (["d_100"], "([d_100]i32 -> ?[d_101].[d_101]i32) -> bool -> ?[d_102].[d_102]i32")
        ),
        ( "((k: i64) -> [k]i32 -> [k]i32) -> []i32 -> bool",
          (["d_101"], "((k_100: i64) -> [k_100]i32 -> [k_100]i32) -> [d_101]i32 -> bool")
        ),
        ("square [10]", ([], "[10][10]i32")),
        ("square []", ([], "?[d_100].[d_100][d_100]i32")),
        ("bool -> square []", ([], "bool -> ?[d_100].[d_100][d_100]i32")),
        ("(k: i64) -> square [k]", ([], "(k_100: i64) -> [k_100][k_100]i32")),
        ("fun i32 bool", ([], "i32 -> bool")),
        ("fun ([]i32) bool", ([], "?[d_100].[d_100]i32 -> bool")),
        ("fun bool ([]i32)", ([], "?[d_100].bool -> [d_100]i32")),
        ("bool -> fun ([]i32) bool", ([], "bool -> ?[d_100].[d_100]i32 -> bool")),
        ("bool -> fun bool ([]i32)", ([], "bool -> ?[d_100].bool -> [d_100]i32")),
        ("pair", ([], "?[n_100][m_101].([n_100]i64, [m_101]i64)")),
        ( "(pair,pair)",
          ([], "?[n_100][m_101][n_102][m_103].(([n_100]i64, [m_101]i64), ([n_102]i64, [m_103]i64))")
        )
      ]

    -- All negative cases are expected to fail with the same message
    -- about the existential size "n".
    neg =
      [ (texp, "Existential size \"n\"")
        | texp <-
            [ "?[n].bool",
              "?[n].bool -> [n]bool",
              "?[n].[n]bool -> [n]bool",
              "?[n].[n]bool -> bool"
            ]
      ]

-- | Construct a test case that applies the substitution map @m@ to the
-- return type @t@ with 'applySubst' and checks that the result equals
-- @expected@.  The substitution function handed to 'applySubst' simply
-- looks each variable up in @m@.
--
-- The test is named after the substitution (keys converted with
-- 'toName') and the input type, both rendered on a single line so the
-- test name never wraps.
substTest :: M.Map VName (Subst StructRetType) -> StructRetType -> StructRetType -> TestTree
substTest m t expected =
  testCase (pretty_m <> ": " <> T.unpack (prettyTextOneLine t)) $
    applySubst (`M.lookup` m) t @?= expected
  where
    -- Use the same one-line renderer as for the type, rather than the
    -- layout-sensitive 'prettyText', to keep the test name on one line.
    pretty_m = T.unpack $ prettyTextOneLine $ map (first toName) $ M.toList m

-- Some of these tests may be a bit fragile, in that they depend on
-- internal renumbering, which can be arbitrary.
-- | Unit tests for 'applySubst' on return types.  Every case substitutes
-- the type variable @t_0@: first by the plain type @i64@, then by the
-- existentially sized array type @?[n_1].[n_1]bool@.
substTests :: TestTree
substTests =
  testGroup "Type substitution" $
    map (uncurry (substTest m0)) m0Cases
      <> map (uncurry (substTest m1)) m1Cases
  where
    -- A substitution mapping t_0 to the given type, with no type
    -- parameters.
    substT0 rhs = M.singleton "t_0" (Subst [] rhs)

    m0 = substT0 "i64"
    m1 = substT0 "?[n_1].[n_1]bool"

    -- (input type, expected type after applying m0)
    m0Cases =
      [ ("t_0", "i64"),
        ("[1]t_0", "[1]i64"),
        ("?[n_10].[n_10]t_0", "?[n_10].[n_10]i64")
      ]

    -- (input type, expected type after applying m1)
    m1Cases =
      [ ("t_0", "?[n_1].[n_1]bool"),
        ("f32 -> t_0", "f32 -> ?[n_1].[n_1]bool"),
        ("f32 -> f64 -> t_0", "f32 -> f64 -> ?[n_1].[n_1]bool"),
        ("f32 -> t_0 -> bool", "?[n_1].f32 -> [n_1]bool -> bool"),
        ("f32 -> t_0 -> t_0", "?[n_1].f32 -> [n_1]bool -> ?[n_2].[n_2]bool")
      ]

-- | Entry point of this module: the type elaboration tests followed by
-- the type substitution tests, grouped under one heading.
tests :: TestTree
tests = testGroup groupTitle members
  where
    groupTitle = "Basic type operations"
    members = [evalTests, substTests]