module Futhark.Optimise.ArrayLayout.LayoutTests (tests) where
import Data.Map.Strict qualified as M
import Futhark.Analysis.AccessPattern
import Futhark.Analysis.PrimExp
import Futhark.FreshNames
import Futhark.IR.GPU (GPU)
import Futhark.IR.GPUTests ()
import Futhark.Optimise.ArrayLayout.Layout
import Language.Futhark.Core
import Test.Tasty
import Test.Tasty.HUnit
-- | Root of the array-layout unit tests, collected under the
-- @"Layout"@ heading.
tests :: TestTree
tests = testGroup "Layout" subgroups
  where
    subgroups = [commonPermutationEliminatorsTests]
-- | Groups the permutation, nest, dimension-access and constant-index
-- test suites under the @"commonPermutationEliminators"@ heading.
commonPermutationEliminatorsTests :: TestTree
commonPermutationEliminatorsTests =
  testGroup "commonPermutationEliminators" $
    permutationTests
      : nestTests
      : dimAccessTests
      : [constIndexElimTests]
-- | Table-driven checks of 'commonPermutationEliminators' on bare
-- permutations, always with an empty nest (@[]@) of enclosing bodies.
--
-- Each entry pairs a permutation with the expected 'Bool' result; the
-- test is named @"<perm> -> <expected>"@.
permutationTests :: TestTree
permutationTests =
  testGroup "Permutations" $ map toCase cases
  where
    toCase (perm, expected) =
      testCase (unwords [show perm, "->", show expected]) $
        commonPermutationEliminators perm [] @?= expected

    cases :: [([Int], Bool)]
    cases =
      [ ([0], True),
        ([1, 0], False),
        ([0, 1], True),
        ([0, 0], True),
        ([1, 1], True),
        ([1, 2, 0], False),
        ([2, 0, 1], False),
        ([0, 1, 2], True),
        ([1, 0, 2], True),
        ([2, 1, 0], True),
        ([2, 2, 0], True),
        ([2, 1, 1], True),
        ([1, 0, 1], True),
        ([0, 0, 0], True),
        ([0, 1, 2, 3, 4], True),
        ([1, 0, 2, 3, 4], True),
        ([2, 3, 0, 1, 4], True),
        ([3, 4, 2, 0, 1], True),
        ([2, 3, 4, 0, 1], False),
        ([1, 2, 3, 4, 0], False),
        ([3, 4, 0, 1, 2], False)
      ]
-- | Checks how the nest of enclosing bodies affects
-- 'commonPermutationEliminators' for the fixed permutation @[1, 0]@.
--
-- Each nest is built with the list 'Applicative': every listed
-- 'BodyType' constructor is applied to each of two generated names, so
-- a two-constructor entry yields a four-element nest.
nestTests :: TestTree
nestTests =
  testGroup "Nests" $ map toCase cases
  where
    names :: [VName]
    names = generateNames 2

    segOp :: VName -> BodyType
    segOp = SegOpName . SegmentedMap

    toCase (label, nest, expected) =
      testCase (unwords [label, "->", show expected]) $
        commonPermutationEliminators [1, 0] nest @?= expected

    cases :: [(String, [BodyType], Bool)]
    cases =
      [ ("[]", [], False),
        ("[CondBodyName]", [CondBodyName] <*> names, False),
        ("[SegOpName]", [segOp] <*> names, True),
        ("[LoopBodyName]", [LoopBodyName] <*> names, False),
        ( "[SegOpName, CondBodyName]",
          [segOp, CondBodyName] <*> names,
          True
        ),
        ( "[CondBodyName, LoopBodyName]",
          [CondBodyName, LoopBodyName] <*> names,
          False
        )
      ]
-- | Placeholder group for dimension-access tests; it currently holds no
-- cases.
dimAccessTests :: TestTree
dimAccessTests = testGroup "DimAccesses" mempty
-- | Tests for 'layoutTableFromIndexTable' on small GPU index tables.
--
-- * The first case uses 'accessTableGPU': one array accessed through a
--   thread-ID dimension, a dimension with no dependencies
--   (@DimAccess mempty Nothing@), and a loop-variable dimension. It
--   expects an empty layout table.
-- * The second case uses 'accessTableGPUrev'. It expects a single
--   permutation @[2, 3, 0, 1]@ for the one array in that table.
constIndexElimTests :: TestTree
constIndexElimTests =
  testGroup
    "constIndexElimTests"
    [ testCase "gpu eliminates indexes with constant in any dim" $ do
        let primExpTable =
              M.fromList
                [ ("gtid_4", Just (LeafExp "n_4" (IntType Int64))),
                  ("i_5", Just (LeafExp "n_4" (IntType Int64)))
                ]
        layoutTableFromIndexTable primExpTable accessTableGPU @?= mempty,
      testCase "gpu ignores when not last" $ do
        -- NOTE(review): 'accessTableGPUrev' refers to @i_5@, but this
        -- table only has @i_6@. If VName comparison looks only at the
        -- numeric tag, @i_5@ would resolve to the @gtid_5@ entry. Confirm
        -- this is intended.
        let primExpTable =
              M.fromList
                [ ("gtid_4", Just (LeafExp "gtid_4" (IntType Int64))),
                  ("gtid_5", Just (LeafExp "gtid_5" (IntType Int64))),
                  ("i_6", Just (LeafExp "i_6" (IntType Int64)))
                ]
        layoutTableFromIndexTable primExpTable accessTableGPUrev
          @?= M.fromList
            [ ( SegmentedMap "mapres_1",
                -- Keys mirror exactly those built by 'singleAccess'
                -- ("A_2" for the array, "a_3" for the index). The
                -- comparison therefore does not depend on how VName
                -- equality treats base names.
                M.fromList
                  [ ( ("A_2", [], [0, 1, 2, 3]),
                      M.fromList [("a_3", [2, 3, 0, 1])]
                    )
                  ]
              )
            ]
    ]
  where
    -- Parallel access on gtid_4, a dimension with no dependencies, then a
    -- sequential access on i_5.
    accessTableGPU :: IndexTable GPU
    accessTableGPU =
      singleAccess
        [ singleParAccess 0 "gtid_4",
          DimAccess mempty Nothing,
          singleSeqAccess 1 "i_5"
        ]

    -- Two parallel accesses followed by two sequential (loop-variable)
    -- accesses.
    accessTableGPUrev :: IndexTable GPU
    accessTableGPUrev =
      singleAccess
        [ singleParAccess 1 "gtid_4",
          singleParAccess 2 "gtid_5",
          singleSeqAccess 0 "i_5",
          singleSeqAccess 2 "gtid_4"
        ]
-- | Builds an index table with exactly one entry: segmented map
-- @mapres_1@, array key @("A_2", [], [0, 1, 2, 3])@, and index name
-- @a_3@, mapped to the given per-dimension accesses.
singleAccess :: [DimAccess rep] -> IndexTable rep
singleAccess dims =
  M.singleton segOp $
    M.singleton arrayKey $
      M.singleton "a_3" dims
  where
    segOp = SegmentedMap {vnameFromSegOp = "mapres_1"}
    arrayKey = ("A_2", [], [0, 1, 2, 3])
-- | A dimension access whose only dependency is @name@ at @level@, tagged
-- as a thread ID ('ThreadID').
singleParAccess :: Int -> VName -> DimAccess rep
singleParAccess = singleVarAccess ThreadID

-- | Like 'singleParAccess', but the dependency is tagged as a loop
-- variable ('LoopVar').
singleSeqAccess :: Int -> VName -> DimAccess rep
singleSeqAccess = singleVarAccess LoopVar

-- | Shared builder for 'singleParAccess' and 'singleSeqAccess'. It
-- produces a 'DimAccess' with a single 'Dependency' on @name@ at
-- @level@ with variable type @vt@. The second field is @Just name@.
singleVarAccess :: VarType -> Int -> VName -> DimAccess rep
singleVarAccess vt level name =
  DimAccess (M.singleton name (Dependency level vt)) (Just name)
-- | Generates @count@ names by repeatedly calling 'newName'.
--
-- The first name is produced from @"i_0"@ and 'blankNameSource'. Each
-- later name is produced from the previous name, threading the updated
-- name source through every call. Returns @[]@ when @count <= 0@; the
-- previous version returned one name in that case.
generateNames :: Int -> [VName]
generateNames count = take count $ go blankNameSource "i_0"
  where
    -- Infinite, lazily consumed stream of successive fresh names.
    go src prev =
      let (name, src') = newName src prev
       in name : go src' name