{-# LANGUAGE FlexibleInstances #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}
{-# OPTIONS_GHC -fno-warn-unused-imports #-}
{-# OPTIONS_GHC -fno-warn-unused-matches #-}
{-# OPTIONS_GHC -fno-warn-unused-top-binds #-}

module Futhark.Analysis.AlgSimplifyTests
  ( tests,
  )
where

import Control.Monad
import Data.Function ((&))
import Data.List (subsequences)
import Data.Map qualified as M
import Data.Maybe (fromMaybe, mapMaybe)
import Futhark.Analysis.AlgSimplify hiding (add, sub)
import Futhark.Analysis.PrimExp
import Futhark.IR.Syntax.Core
import Test.Tasty
import Test.Tasty.HUnit
import Test.Tasty.QuickCheck

-- | QuickCheck properties for 'simplify', run on random integer expressions
-- produced by the 'Arbitrary' instance of 'TestableExp'.
--
-- * Simplification is idempotent: simplifying twice gives the same result
--   as simplifying once.
-- * Simplification preserves the value computed by 'evalPrimExp'.
tests :: TestTree
tests =
  testGroup
    "AlgSimplifyTests"
    [ testProperty "simplify is idempotent" $
        \(TestableExp e) -> simplify e == simplify (simplify e),
      testProperty "simplify doesn't change exp evaluation result" $
        \(TestableExp e) -> evalClosed e == evalClosed (simplify e)
    ]
  where
    -- Evaluate with a variable lookup that always fails. Any expression that
    -- contains a variable therefore evaluates to 'Nothing' on both sides.
    evalClosed :: PrimExp VName -> Maybe PrimValue
    evalClosed = evalPrimExp (const Nothing)

-- | Unwrap a 'TestableExp' and evaluate it with the reference evaluator
-- 'evalExp'.
eval :: TestableExp -> Int64
eval testable = case testable of
  TestableExp inner -> evalExp inner

-- | Reference evaluator for the integer fragment used by these tests.
--
-- Supported expressions are 64-bit integer literals combined with
-- 'OverflowUndef' addition, subtraction and multiplication on 'Int64'.
-- The arithmetic is done in Haskell's 'Int64', which wraps around on
-- overflow.
--
-- Any other expression (variables, other types or other operators) is
-- outside the fragment. Evaluation then aborts with an error that shows the
-- unsupported subexpression.
evalExp :: PrimExp VName -> Int64
evalExp (ValueExp (IntValue (Int64Value i))) = i
evalExp (BinOpExp (Add Int64 OverflowUndef) x y) = evalExp x + evalExp y
evalExp (BinOpExp (Sub Int64 OverflowUndef) x y) = evalExp x - evalExp y
evalExp (BinOpExp (Mul Int64 OverflowUndef) x y) = evalExp x * evalExp y
evalExp e =
  error $ "evalExp: expression outside the supported fragment: " <> show e

-- | Build the sum of two expressions, using 64-bit integer addition with
-- undefined overflow.
add :: PrimExp VName -> PrimExp VName -> PrimExp VName
add x y = BinOpExp (Add Int64 OverflowUndef) x y

-- | Build the difference @x - y@, using 64-bit integer subtraction with
-- undefined overflow.
sub :: PrimExp VName -> PrimExp VName -> PrimExp VName
sub x y = BinOpExp (Sub Int64 OverflowUndef) x y

-- | Build the product of two expressions, using 64-bit integer
-- multiplication with undefined overflow.
mul :: PrimExp VName -> PrimExp VName -> PrimExp VName
mul x y = BinOpExp (Mul Int64 OverflowUndef) x y

-- | Negate an expression by encoding it as @0 - e@, using 64-bit subtraction
-- with undefined overflow.
neg :: PrimExp VName -> PrimExp VName
neg e = BinOpExp (Sub Int64 OverflowUndef) zero e
  where
    zero = val 0

-- | A 64-bit integer variable. Its name is the decimal rendering of @i@, and
-- @i@ is also used as its tag.
l :: Int -> PrimExp VName
l i = LeafExp var (IntType Int64)
  where
    var = VName (nameFromString (show i)) i

-- | Lift an 'Int64' literal into a constant expression.
val :: Int64 -> PrimExp VName
val n = ValueExp (IntValue (Int64Value n))

-- | Generate a random integer expression that contains no variables.
--
-- At size 1 or below only a literal is produced. Otherwise the generator
-- picks uniformly between addition, subtraction, multiplication, negation
-- and a literal. Subexpressions are generated at half the current size, so
-- the tree depth is bounded.
generateExp :: Gen (PrimExp VName)
generateExp = do
  size <- getSize
  let literal = val <$> arbitrary
      halved = scale (`div` 2)
  if size > 1
    then
      oneof
        [ halved (generateBinOp add),
          halved (generateBinOp sub),
          halved (generateBinOp mul),
          halved generateNeg,
          literal
        ]
    else literal

-- | Combine two independently generated subexpressions with the given binary
-- constructor.
--
-- Both operands are generated at the current size. Callers shrink the size
-- beforehand (see 'generateExp') so that the recursion terminates.
generateBinOp :: (PrimExp VName -> PrimExp VName -> PrimExp VName) -> Gen (PrimExp VName)
generateBinOp op = op <$> generateExp <*> generateExp

-- | Generate the negation (as built by 'neg') of a random expression.
generateNeg :: Gen (PrimExp VName)
generateNeg = neg <$> generateExp

-- | Wrapper around an expression, used as the input type of the QuickCheck
-- properties so that it can carry its own 'Arbitrary' instance.
newtype TestableExp
  = TestableExp (PrimExp VName)
  deriving Show

-- | Random variable-free integer expressions, generated by 'generateExp'.
instance Arbitrary TestableExp where
  arbitrary = TestableExp <$> generateExp

  -- Without shrinking, a failing property reports the entire random tree.
  -- The operands of a generated binary operator, and any smaller literal,
  -- are again expressions of the generated kind. So we shrink a binary node
  -- to its operands and a literal towards smaller values.
  shrink (TestableExp (BinOpExp _ x y)) = [TestableExp x, TestableExp y]
  shrink (TestableExp (ValueExp (IntValue (Int64Value i)))) =
    TestableExp . val <$> shrink i
  shrink _ = []