{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.Syntax.CoreTests (tests) where

import Control.Applicative
import Data.Loc (Loc (..), Pos (..))
import Futhark.IR.Pretty (prettyString)
import Futhark.IR.Syntax.Core
import Language.Futhark.CoreTests ()
import Language.Futhark.PrimitiveTests ()
import Test.QuickCheck
import Test.Tasty
import Test.Tasty.HUnit
import Prelude

-- | 'NoUniqueness' has a single value, so the generator always
-- produces it.
instance Arbitrary NoUniqueness where
  arbitrary = pure NoUniqueness

-- | Generates either a 'Prim' type or an 'Array' type.  The element
-- type, shape and uniqueness of an array are each drawn from their own
-- 'Arbitrary' instance.
instance (Arbitrary shape, Arbitrary u) => Arbitrary (TypeBase shape u) where
  arbitrary =
    oneof
      [ Prim <$> arbitrary,
        Array <$> arbitrary <*> arbitrary <*> arbitrary
      ]

-- | Generates an 'Ident' from an arbitrary name and an arbitrary type.
instance Arbitrary Ident where
  arbitrary = Ident <$> arbitrary <*> arbitrary

-- | Generates a 'Rank' whose value is picked from 1 to 9 inclusive.
instance Arbitrary Rank where
  arbitrary = Rank <$> elements [1 .. 9]

-- | Generates a 'Shape' with at least one dimension ('listOf1'), where
-- every dimension is a constant 32-bit integer picked from 1 to 9
-- inclusive.
instance Arbitrary Shape where
  arbitrary = Shape . map intconst <$> listOf1 (elements [1 .. 9])
    where
      -- Wrap a 32-bit integer as a constant sub-expression.
      intconst = Constant . IntValue . Int32Value

-- | Unit tests for 'subShapeOf' on 'ExtShape's.  Each entry states
-- whether the left shape is expected to be a subshape of the right one.
subShapeTests :: [TestTree]
subShapeTests =
  [ shape [free 1, free 2] `isSubShapeOf` shape [free 1, free 2],
    shape [free 1, free 3] `isNotSubShapeOf` shape [free 1, free 2],
    shape [free 1] `isNotSubShapeOf` shape [free 1, free 2],
    shape [free 1, free 2] `isSubShapeOf` shape [free 1, Ext 3],
    shape [Ext 1, Ext 2] `isNotSubShapeOf` shape [Ext 1, Ext 1],
    shape [Ext 1, Ext 1] `isSubShapeOf` shape [Ext 1, Ext 2]
  ]
  where
    -- Build an 'ExtShape' from its list of dimensions.
    shape :: [ExtSize] -> ExtShape
    shape = Shape

    -- A 'Free' dimension holding the given number as a constant
    -- 32-bit integer.
    free :: Int -> ExtSize
    free = Free . Constant . IntValue . Int32Value . fromIntegral

    -- Test case expecting @shape1 `subShapeOf` shape2@ to be 'True'.
    isSubShapeOf :: ExtShape -> ExtShape -> TestTree
    isSubShapeOf shape1 shape2 =
      subShapeTest shape1 shape2 True

    -- Test case expecting @shape1 `subShapeOf` shape2@ to be 'False'.
    isNotSubShapeOf :: ExtShape -> ExtShape -> TestTree
    isNotSubShapeOf shape1 shape2 =
      subShapeTest shape1 shape2 False

    -- Build a named test case comparing the result of 'subShapeOf'
    -- against the expected outcome.  The test name is assembled from
    -- the pretty-printed shapes and the expected value.
    subShapeTest :: ExtShape -> ExtShape -> Bool -> TestTree
    subShapeTest shape1 shape2 expected =
      testCase
        ( "subshapeOf "
            ++ prettyString shape1
            ++ " "
            ++ prettyString shape2
            ++ " == "
            ++ show expected
        )
        $ shape1 `subShapeOf` shape2 @?= expected

-- | Unit tests for combining 'Provenance' values, both with '<>' and
-- with 'stackProvenance'.
provenanceTests :: [TestTree]
provenanceTests =
  [ testGroup
      "<>"
      [ -- The expected location starts where 'line0' starts and ends
        -- where 'line1' ends, even though 'line1' is the left operand.
        testCase "simple" $
          (Provenance [] line1 <> Provenance [] line0) @?= Provenance [] lines01,
        -- An empty location on either side must leave the other
        -- operand unchanged.
        testCase "mempty left" $
          (Provenance [] mempty <> Provenance [] line0) @?= Provenance [] line0,
        testCase "mempty right" $
          (Provenance [] line1 <> Provenance [] mempty) @?= Provenance [] line1
      ],
    testGroup
      "stackProvenance"
      [ -- 'line0_sub' lies inside 'line0'; the result is expected to be
        -- the provenance of 'line0_sub'.
        testCase "encloses" $
          (Provenance [] line0 `stackProvenance` Provenance [] line0_sub)
            @?= Provenance [] line0_sub
      ]
  ]
  where
    -- Fixture locations; all use the empty string as the file name.
    line0, line0_sub, line1, lines01 :: Loc
    line0 = Loc (Pos "" 0 1 0) (Pos "" 0 10 10)
    line0_sub = Loc (Pos "" 0 2 1) (Pos "" 0 9 9)
    line1 = Loc (Pos "" 1 1 0) (Pos "" 1 10 20)
    -- Same start as 'line0', same end as 'line1'.
    lines01 = Loc (Pos "" 0 1 0) (Pos "" 1 10 20)

-- | All tests of this module, grouped under a single tree.
tests :: TestTree
tests =
  testGroup
    "Internal CoreTests"
    [ testGroup "subShape" subShapeTests,
      testGroup "Provenance" provenanceTests
    ]