module Futhark.AD.DerivativesTests (tests) where
import Data.Map qualified as M
import Data.Text qualified as T
import Futhark.AD.Derivatives
import Futhark.Analysis.PrimExp
import Futhark.IR.Syntax.Core (nameFromText)
import Futhark.Util.Pretty (prettyString)
import Test.Tasty
import Test.Tasty.HUnit
-- | Type-level sanity checks for "Futhark.AD.Derivatives".
--
-- For every primitive function (minus a known-missing set), every binary
-- operator and every unary operator, the partial derivatives produced by
-- 'pdBuiltin', 'pdBinOp' and 'pdUnOp' must have the expected types.  Only
-- the types of the derivative expressions are compared, never their values.
tests :: TestTree
tests =
  testGroup
    "Futhark.AD.DerivativesTests"
    [ testGroup "Primitive functions" $
        map primFunTest $
          filter ((`notElem` missing_primfuns) . fst) $
            M.toList primFuns,
      testGroup "BinOps" $ map binOpTest allBinOps,
      testGroup "UnOps" $ map unOpTest allUnOps
    ]
  where
    -- A constant expression of the given type.  The concrete value does
    -- not matter: the checks below only inspect 'primExpType'.
    blank = ValueExp . blankPrimValue

    -- 'pdBuiltin' must succeed and yield one partial derivative per
    -- argument, each having the function's return type.
    primFunTest (f, (ts, ret, _)) =
      testCase (T.unpack f) $
        case pdBuiltin (nameFromText f) (map blank ts) of
          Nothing -> assertFailure "pdBuiltin gives Nothing"
          Just v -> map primExpType v @?= replicate (length ts) ret

    -- Primitive functions excluded from the test group above.
    -- NOTE(review): presumably 'pdBuiltin' has no derivative rule for
    -- these yet; confirm and remove entries once they are supported.
    missing_primfuns =
      [ "gamma16",
        "gamma32",
        "gamma64",
        "lgamma16",
        "lgamma32",
        "lgamma64"
      ]

    -- Both partial derivatives of a binary operator must have the type
    -- reported by 'binOpType'.
    binOpTest bop =
      testCase (prettyString bop) $
        let t = binOpType bop
            (dx, dy) = pdBinOp bop (blank t) (blank t)
         in (primExpType dx, primExpType dy) @?= (t, t)

    -- The derivative of a unary operator must have the type reported by
    -- 'unOpType'.
    unOpTest uop =
      testCase (prettyString uop) $
        let t = unOpType uop
         in primExpType (pdUnOp uop $ blank t) @?= t