module Futhark.Optimise.ArrayLayout.AnalyseTests (tests) where

import Data.Map.Strict qualified as M
import Futhark.Analysis.AccessPattern
import Futhark.IR.GPU
import Futhark.IR.GPUTests ()
import Futhark.IR.SyntaxTests ()
import Test.Tasty
import Test.Tasty.HUnit

-- | Root of this module's test tree: the "Analyse" group, which currently
-- contains only 'analyseStmTests'.
tests :: TestTree
tests = testGroup "Analyse" [analyseStmTests]

-- | The "analyseStm" group, bundling the 'analyseIndex' tests and the
-- 'analyseDimAccesses' tests.
analyseStmTests :: TestTree
analyseStmTests =
  testGroup
    "analyseStm"
    [analyseIndexTests, analyseDimAccessesTests]

-- | Unit tests for 'analyseIndex'.
--
-- Every case builds a context by hand, runs 'analyseIndex' on an index into
-- the array @xss_5144@ that is bound to the pattern name @b_5211@, and
-- compares the resulting index table against the expected table (or against
-- 'mempty' when nothing should be recorded).
analyseIndexTests :: TestTree
analyseIndexTests =
  testGroup
    "analyseIndex"
    $ do
      -- Name of the array that is indexed in all of the cases below.
      let arr_name = "xss_5144"
      -- ============================= TestCase0 =============================
      -- Most simple case where we want to manifest an array, hence, we record
      -- the Index in the IndexTable.
      let testCase0 = testCase "2D manifest" $ do
            let ctx =
                  mempty
                    { parents =
                        [ SegOpName (SegmentedMap "defunc_0_map_res_5204"),
                          LoopBodyName "defunc_0_f_res_5208"
                        ],
                      assignments =
                        M.fromList
                          [ ("gtid_5205", VariableInfo mempty 0 mempty ThreadID),
                            ("i_5209", VariableInfo mempty 1 mempty LoopVar)
                          ]
                    }
            let patternNames = ["b_5211"]
            let dimFixes =
                  [ DimFix (Var "gtid_5205"),
                    DimFix (Var "i_5209")
                  ]
            let indexTable =
                  M.fromList
                    [ ( SegmentedMap "defunc_0_map_res_5204",
                        M.fromList
                          [ ( (arr_name, [], [0 .. 1]),
                              M.fromList
                                [ ( "b_5211",
                                    [ DimAccess (M.fromList [("gtid_5205", Dependency 0 ThreadID)]) (Just "gtid_5205"),
                                      DimAccess (M.fromList [("i_5209", Dependency 1 LoopVar)]) (Just "i_5209")
                                    ]
                                  )
                                ]
                            )
                          ]
                      )
                    ]
            let (_, indexTable') = analyseIndex ctx patternNames arr_name dimFixes
            indexTable' @?= indexTable

      -- ============================= TestCase1 =============================
      -- We don't want to manifest an array with only one dimension, so we don't
      -- record anything in the IndexTable.
      let testCase1 = testCase "1D manifest" $ do
            let ctx =
                  mempty
                    { parents =
                        [ SegOpName (SegmentedMap "defunc_0_map_res_5204"),
                          LoopBodyName "defunc_0_f_res_5208"
                        ]
                    }
            let patternNames = ["b_5211"]
            let dimFixes = [DimFix "i_5209"]

            let (_, indexTable') = analyseIndex ctx patternNames arr_name dimFixes
            indexTable' @?= mempty

      -- ============================= TestCase2 =============================
      -- We don't want to record anything to the IndexTable when the array is
      -- not accessed inside a SegMap.
      -- TODO: Create a similar one for MC with loops
      let testCase2 = testCase "Not inside SegMap" $ do
            -- An empty context: no parents and no known assignments.
            let ctx = mempty
            let patternNames = ["b_5211"]
            let dimFixes =
                  [ DimFix "gtid_5205",
                    DimFix "i_5209"
                  ]
            let (_, indexTable') = analyseIndex ctx patternNames arr_name dimFixes
            indexTable' @?= mempty

      -- ============================= TestCase3 =============================
      -- If an array is allocated inside a loop or SegMap, we want to record that
      -- information in the ArrayName of the IndexTable.
      let testCase3 = testCase "Allocated inside SegMap" $ do
            let parents' =
                  [ SegOpName (SegmentedMap "defunc_0_map_res_5204"),
                    LoopBodyName "defunc_0_f_res_5208"
                  ]
            let ctx =
                  mempty
                    { parents = parents',
                      assignments =
                        M.fromList
                          [ ("gtid_5205", VariableInfo mempty 0 mempty ThreadID),
                            ("i_5209", VariableInfo mempty 1 mempty LoopVar),
                            (arr_name, VariableInfo mempty 0 parents' Variable)
                          ]
                    }
            let patternNames = ["b_5211"]
            let dimFixes =
                  [ DimFix "gtid_5205",
                    DimFix "i_5209"
                  ]
            -- The expected array key carries parents' instead of an empty list.
            let indexTable =
                  M.fromList
                    [ ( SegmentedMap "defunc_0_map_res_5204",
                        M.fromList
                          [ ( (arr_name, parents', [0 .. 1]),
                              M.fromList
                                [ ( "b_5211",
                                    [ DimAccess (M.fromList [("gtid_5205", Dependency 0 ThreadID)]) (Just "gtid_5205"),
                                      DimAccess (M.fromList [("i_5209", Dependency 1 LoopVar)]) (Just "i_5209")
                                    ]
                                  )
                                ]
                            )
                          ]
                      )
                    ]
            let (_, indexTable') = analyseIndex ctx patternNames arr_name dimFixes
            indexTable' @?= indexTable

      -- ============================= TestCase4 =============================
      -- If the vars in the index are temporaries, we want to reduce them to
      -- the thread IDs and/or loop counters they are functions of.
      let testCase4 = testCase "Reduce dependencies" $ do
            let ctx =
                  mempty
                    { parents =
                        [ SegOpName (SegmentedMap "defunc_0_map_res_5204"),
                          LoopBodyName "defunc_0_f_res_5208"
                        ],
                      assignments =
                        M.fromList
                          [ ("gtid_5205", VariableInfo mempty 0 mempty ThreadID),
                            ("i_5209", VariableInfo mempty 1 mempty LoopVar),
                            ("tmp0_5210", VariableInfo (namesFromList ["gtid_5205"]) 2 mempty Variable),
                            ("tmp1_5211", VariableInfo (namesFromList ["i_5209"]) 3 mempty Variable),
                            ("k_5212", VariableInfo mempty 1 mempty ConstType)
                          ]
                    }
            let patternNames = ["b_5211"]
            let dimFixes =
                  [ DimFix "tmp0_5210",
                    DimFix "tmp1_5211",
                    DimFix "k_5212"
                  ]
            -- Expected: the temporaries map to gtid_5205 and i_5209, while the
            -- constant k_5212 has an empty dependency map.
            let indexTable =
                  M.fromList
                    [ ( SegmentedMap "defunc_0_map_res_5204",
                        M.fromList
                          [ ( (arr_name, [], [0 .. 2]),
                              M.fromList
                                [ ( "b_5211",
                                    [ DimAccess (M.fromList [("gtid_5205", Dependency 0 ThreadID)]) (Just "tmp0_5210"),
                                      DimAccess (M.fromList [("i_5209", Dependency 1 LoopVar)]) (Just "tmp1_5211"),
                                      DimAccess mempty (Just "k_5212")
                                    ]
                                  )
                                ]
                            )
                          ]
                      )
                    ]
            let (_, indexTable') = analyseIndex ctx patternNames arr_name dimFixes
            indexTable' @?= indexTable

      [testCase0, testCase1, testCase2, testCase3, testCase4]

-- | Tests for 'analyseDimAccesses' on a whole program.
--
-- The single "Fold" case runs the analysis (at the 'GPU' representation) on
-- 'prog0' and checks that the resulting index table contains exactly the
-- access @xss_5144[gtid_5205, i_5209]@ bound to @b_5211@ under the segmap
-- @defunc_0_map_res_5204@.
analyseDimAccessesTests :: TestTree
analyseDimAccessesTests = testGroup
  "analyseDimAccesses"
  $ do
    let testCase0 = testCase "Fold" $ do
          let indexTable =
                M.fromList
                  [ ( SegmentedMap "defunc_0_map_res_5204",
                      M.fromList
                        [ ( ("xss_5144", [], [0, 1]),
                            M.fromList
                              [ ( "b_5211",
                                  [ DimAccess (M.fromList [("gtid_5205", Dependency 0 ThreadID)]) (Just "gtid_5205"),
                                    DimAccess (M.fromList [("i_5209", Dependency 1 LoopVar)]) (Just "i_5209")
                                  ]
                                )
                              ]
                          )
                        ]
                    )
                  ]
          let indexTable' = (analyseDimAccesses @GPU) prog0
          indexTable' @?= indexTable

    [testCase0]
  where
    -- The program under test, written as a string literal: a segmap over
    -- gtid_5205 whose body loops over i_5209 and accumulates
    -- xss_5144[gtid_5205, i_5209].
    prog0 :: Prog GPU
    prog0 =
      "\
      \entry(\"main\",\
      \      {xss: [][]i64},\
      \      {[]i64})\
      \  entry_main (n_5142 : i64,\
      \              m_5143 : i64,\
      \              xss_5144 : [n_5142][m_5143]i64)\
      \  : {[n_5142]i64#([2], [0])} = {\
      \  let {segmap_group_size_5202 : i64} =\
      \    get_size(segmap_group_size_5190, thread_block_size)\
      \  let {segmap_usable_groups_5203 : i64} =\
      \    sdiv_up64(n_5142, segmap_group_size_5202)\
      \  let {defunc_0_map_res_5204 : [n_5142]i64} =\
      \    segmap(thread; ; grid=segmap_usable_groups_5203; blocksize=segmap_group_size_5202)\
      \    (gtid_5205 < n_5142) (~phys_tid_5206) : {i64} {\
      \      let {defunc_0_f_res_5208 : i64} =\
      \        loop {acc_5210 : i64} = {0i64}\
      \        for i_5209:i64 < m_5143 do {\
      \          let {b_5211 : i64} =\
      \            xss_5144[gtid_5205, i_5209]\
      \          let {defunc_0_f_res_5212 : i64} =\
      \            add64(acc_5210, b_5211)\
      \          in {defunc_0_f_res_5212}\
      \        }\
      \      return {returns defunc_0_f_res_5208}\
      \    }\
      \  in {defunc_0_map_res_5204}\
      \}"