module Futhark.IR.Mem.IntervalTests
  ( tests,
  )
where

import Futhark.Analysis.AlgSimplify
import Futhark.Analysis.PrimExp.Convert
import Futhark.IR.Mem.Interval
import Futhark.IR.Syntax
import Futhark.IR.Syntax.Core ()
import Test.Tasty
import Test.Tasty.HUnit

-- | Every interval unit test in this module, gathered under a single
-- @IntervalTests@ group.
tests :: TestTree
tests =
  testGroup
    "IntervalTests"
    testDistributeOffset

-- | Build a 'VName' from a base name and a numeric tag.
--
-- The tests give same-named variables different tags (e.g. @name "g" 1@ and
-- @name "g" 2@) to obtain separate variables. NOTE(review): this relies on
-- 'VName' distinguishing names by tag.
name :: String -> Int -> VName
name = VName . nameFromString

-- | Example-based tests for 'distributeOffset'.
--
-- Each case passes an offset, written as a sum of products ('SofP'), together
-- with a list of 'Interval's. It then checks the exact list of intervals
-- returned. 'distributeOffset' runs directly in 'IO' through its 'MonadFail'
-- constraint, so a failure inside it also fails the test.
testDistributeOffset :: [TestTree]
testDistributeOffset =
  [ testCase "Stride is (nb-b)" $ do
      let n = int64Var "n" 1
          b = int64Var "b" 2
      res <-
        distributeOffset
          [term [n * b - b]]
          [ Interval 0 1 (n * b - b),
            Interval 0 b b,
            Interval 0 b 1
          ]
      -- '@?=' reports both the expected and the actual list on failure.
      res
        @?= [ Interval 1 1 (n * b - b),
              Interval 0 b b,
              Interval 0 b 1
            ],
    testCase "Stride is 1024r" $ do
      let r = int64Var "r" 1
      res <-
        distributeOffset
          [term [1024, r]]
          [ Interval 0 1 (1024 * r),
            Interval 0 32 32,
            Interval 0 32 1
          ]
      res
        @?= [ Interval 1 1 (1024 * r),
              Interval 0 32 32,
              Interval 0 32 1
            ],
    testCase "Stride is 32, offsets are multples of 32" $ do
      let n = int64Var "n" 0
          g1 = int64Var "g" 1
          g2 = int64Var "g" 2
      res <-
        distributeOffset
          [term [1024], term [1024, g1], term [32, g2]]
          [ Interval 0 1 (1024 * n),
            Interval 0 1 32,
            Interval 0 32 1
          ]
      res
        @?= [ Interval 0 1 (1024 * n),
              Interval (32 + 32 * g1 + g2) 1 32,
              Interval 0 32 1
            ]
  ]
  where
    -- A free 64-bit integer variable, used as a symbolic leaf in the tests.
    -- The explicit Int64 phantom type avoids defaulting the type parameter.
    int64Var :: String -> Int -> TPrimExp Int64 VName
    int64Var s i = TPrimExp $ LeafExp (name s i) $ IntType Int64

    -- One product term of the offset sum. The 'False' flag is passed exactly
    -- as in the original tests; presumably it means "not negated" — confirm
    -- against Futhark.Analysis.AlgSimplify.
    term :: [TPrimExp Int64 VName] -> Prod
    term = Prod False . map untyped