module Futhark.Optimise.ArrayLayout.LayoutTests (tests) where

import Data.Map.Strict qualified as M
import Futhark.Analysis.AccessPattern
import Futhark.Analysis.PrimExp
import Futhark.FreshNames
import Futhark.IR.GPU (GPU)
import Futhark.IR.GPUTests ()
import Futhark.Optimise.ArrayLayout.Layout
import Language.Futhark.Core
import Test.Tasty
import Test.Tasty.HUnit

-- | All unit tests for "Futhark.Optimise.ArrayLayout.Layout".
tests :: TestTree
tests =
  testGroup
    "Layout"
    [commonPermutationEliminatorsTests]

-- | Groups every test that exercises 'commonPermutationEliminators' and
-- the layout-table construction built on top of it.
commonPermutationEliminatorsTests :: TestTree
commonPermutationEliminatorsTests =
  testGroup
    "commonPermutationEliminators"
    [ permutationTests,
      nestTests,
      dimAccessTests,
      constIndexElimTests
    ]

-- | Runs 'commonPermutationEliminators' on a bare permutation with an
-- empty body nest. Each case pairs a permutation with the expected
-- 'Bool' result; the test name is the permutation followed by the result.
permutationTests :: TestTree
permutationTests =
  testGroup "Permutations" $
    -- This isn't the ideal way to test this: realistically we should
    -- provide access patterns that could produce the given permutations.
    -- Fortunately, the original access is currently used for only one check.
    [ testCase (unwords [show perm, "->", show res]) $
        commonPermutationEliminators perm [] @?= res
      | (perm, res) <-
          [ ([0], True),
            ([1, 0], False),
            ([0, 1], True),
            ([0, 0], True),
            ([1, 1], True),
            ([1, 2, 0], False),
            ([2, 0, 1], False),
            ([0, 1, 2], True),
            ([1, 0, 2], True),
            ([2, 1, 0], True),
            ([2, 2, 0], True),
            ([2, 1, 1], True),
            ([1, 0, 1], True),
            ([0, 0, 0], True),
            ([0, 1, 2, 3, 4], True),
            ([1, 0, 2, 3, 4], True),
            ([2, 3, 0, 1, 4], True),
            ([3, 4, 2, 0, 1], True),
            ([2, 3, 4, 0, 1], False),
            ([1, 2, 3, 4, 0], False),
            ([3, 4, 0, 1, 2], False)
          ]
    ]

-- | Runs 'commonPermutationEliminators' with the fixed permutation
-- @[1, 0]@ under different enclosing body nests. Each case supplies a
-- label (used in the test name), the nest, and the expected 'Bool'.
nestTests :: TestTree
nestTests =
  testGroup
    "Nests"
    [ testCase (unwords [args, "->", show res]) $
        commonPermutationEliminators [1, 0] nest @?= res
      | (args, nest, res) <-
          [ ("[]", [], False),
            ("[CondBodyName]", [CondBodyName] <*> names, False),
            ("[SegOpName]", [SegOpName . SegmentedMap] <*> names, True),
            ("[LoopBodyName]", [LoopBodyName] <*> names, False),
            ( "[SegOpName, CondBodyName]",
              [SegOpName . SegmentedMap, CondBodyName] <*> names,
              True
            ),
            ( "[CondBodyName, LoopBodyName]",
              [CondBodyName, LoopBodyName] <*> names,
              False
            )
          ]
    ]
  where
    -- Two fresh names; applying each constructor list with '<*>' yields
    -- every constructor paired with each of these names.
    names :: [VName]
    names = generateNames 2

-- | Placeholder group for 'DimAccess'-complexity checks; currently empty.
dimAccessTests :: TestTree
dimAccessTests =
  -- TODO: Write tests for the part of commonPermutationEliminators that
  -- checks the complexity of the DimAccesses.
  testGroup "DimAccesses" []

-- | Checks 'layoutTableFromIndexTable' on hand-built GPU index tables
-- (built via 'singleAccess') together with a table mapping index
-- variables to their 'PrimExp's.
constIndexElimTests :: TestTree
constIndexElimTests =
  testGroup
    "constIndexElimTests"
    [ testCase "gpu eliminates indexes with constant in any dim" $ do
        let primExpTable =
              M.fromList
                [ ("gtid_4", int64Leaf "n_4"),
                  ("i_5", int64Leaf "n_4")
                ]
        layoutTableFromIndexTable primExpTable accessTableGPU @?= mempty,
      testCase "gpu ignores when not last" $ do
        let primExpTable =
              M.fromList
                [ ("gtid_4", int64Leaf "gtid_4"),
                  ("gtid_5", int64Leaf "gtid_5"),
                  ("i_6", int64Leaf "i_6")
                ]
        layoutTableFromIndexTable primExpTable accessTableGPUrev
          @?= M.fromList
            [ ( SegmentedMap "mapres_1",
                -- Keys mirror the ones constructed by 'singleAccess'.
                M.fromList
                  [ ( ("A_2", [], [0, 1, 2, 3]),
                      M.fromList [("a_3", [2, 3, 0, 1])]
                    )
                  ]
              )
            ]
    ]
  where
    -- A 64-bit integer leaf expression for the given variable.
    int64Leaf :: VName -> Maybe (PrimExp VName)
    int64Leaf v = Just (LeafExp v (IntType Int64))

    -- Three dimensions: a thread-ID access, an entry with no dependencies
    -- and no variable ('mempty' / 'Nothing'), and a loop-variable access.
    accessTableGPU :: IndexTable GPU
    accessTableGPU =
      singleAccess
        [ singleParAccess 0 "gtid_4",
          DimAccess mempty Nothing,
          singleSeqAccess 1 "i_5"
        ]

    -- Four dimensions whose levels are not in increasing order.
    accessTableGPUrev :: IndexTable GPU
    accessTableGPUrev =
      singleAccess
        [ singleParAccess 1 "gtid_4",
          singleParAccess 2 "gtid_5",
          singleSeqAccess 0 "i_5",
          singleSeqAccess 2 "gtid_4"
        ]

-- | Builds an index table containing exactly one segmented map
-- (@mapres_1@), one array entry keyed by @("A_2", [], [0, 1, 2, 3])@, and
-- one access keyed by @a_3@ whose per-dimension accesses are @dims@.
singleAccess :: [DimAccess rep] -> IndexTable rep
singleAccess dims =
  M.singleton sgOp $
    M.singleton ("A_2", [], [0, 1, 2, 3]) $
      M.singleton "a_3" dims
  where
    sgOp :: SegOpName
    sgOp = SegmentedMap {vnameFromSegOp = "mapres_1"}

-- | A dimension access that depends only on @name@, recorded at @level@
-- as a 'ThreadID', with @Just name@ as its second field (presumably the
-- originating index variable — confirm against 'DimAccess').
singleParAccess :: Int -> VName -> DimAccess rep
singleParAccess level name =
  DimAccess
    (M.singleton name (Dependency level ThreadID))
    (Just name)

-- | A dimension access that depends only on @name@, recorded at @level@
-- as a 'LoopVar', with @Just name@ as its second field (presumably the
-- originating index variable — confirm against 'DimAccess').
singleSeqAccess :: Int -> VName -> DimAccess rep
singleSeqAccess level name =
  DimAccess
    (M.singleton name (Dependency level LoopVar))
    (Just name)

-- | Generates @count@ names with 'newName'. The first is derived from
-- @i_0@ using 'blankNameSource'; each later one is derived from its
-- predecessor using the name source returned by the previous step.
-- Returns @[]@ when @count <= 0@ (the earlier fold-based version
-- returned one name in that case).
generateNames :: Int -> [VName]
generateNames count = take count (go blankNameSource "i_0")
  where
    -- Lazily unfolds an infinite stream of names, threading the name
    -- source; 'take' forces only as many as requested.
    go :: VNameSource -> VName -> [VName]
    go source prev =
      let (name, source') = newName source prev
       in name : go source' name