{-# OPTIONS_GHC -fno-warn-orphans #-}

module Futhark.IR.Prop.ReshapeTests
  ( tests,
  )
where

import Data.List qualified as L
import Futhark.IR.Prop.Constants
import Futhark.IR.Prop.Reshape
import Futhark.IR.Syntax
import Futhark.IR.SyntaxTests ()
import Test.Tasty
import Test.Tasty.HUnit

-- | Build a 'Shape' from plain integer sizes, turning every size into an
-- 'Int32' integer constant via 'intConst'.
intShape :: [Int] -> Shape
intShape dims = Shape [intConst Int32 (toInteger d) | d <- dims]

-- | Unit tests for 'reshapeOuter'.
--
-- Each case is @(newOuter, k, orig, expected)@ and asserts that
-- @reshapeOuter (intShape newOuter) k (intShape orig)@ equals
-- @intShape expected@.  Judging from the expected values below, the outermost
-- @k@ dimensions of @orig@ are replaced by @newOuter@.
reshapeOuterTests :: [TestTree]
reshapeOuterTests = map mkCase cases
  where
    cases :: [([Int], Int, [Int], [Int])]
    cases =
      [ ([1], 1, [4, 3], [1, 3]),
        ([1], 2, [4, 3], [1]),
        ([2, 2], 1, [4, 3], [2, 2, 3]),
        ([2, 2], 2, [4, 3], [2, 2])
      ]

    mkCase :: ([Int], Int, [Int], [Int]) -> TestTree
    mkCase (newOuter, k, orig, expected) =
      testCase label $
        reshapeOuter (intShape newOuter) k (intShape orig) @?= intShape expected
      where
        -- The test name spells out the inputs and the expected result.
        label =
          unwords
            ["reshapeOuter", show newOuter, show k, show orig, "==", show expected]

-- | Unit tests for 'reshapeInner'.
--
-- Each case is @(newInner, k, orig, expected)@ and asserts that
-- @reshapeInner (intShape newInner) k (intShape orig)@ equals
-- @intShape expected@.  Judging from the expected values below, the first @k@
-- dimensions of @orig@ are kept and the remaining ones are replaced by
-- @newInner@.
reshapeInnerTests :: [TestTree]
reshapeInnerTests = map mkCase cases
  where
    cases :: [([Int], Int, [Int], [Int])]
    cases =
      [ ([1], 1, [4, 3], [4, 1]),
        ([1], 0, [4, 3], [1]),
        ([2, 2], 1, [4, 3], [4, 2, 2]),
        ([2, 2], 0, [4, 3], [2, 2])
      ]

    mkCase :: ([Int], Int, [Int], [Int]) -> TestTree
    mkCase (newInner, k, orig, expected) =
      testCase label $
        reshapeInner (intShape newInner) k (intShape orig) @?= intShape expected
      where
        -- The test name spells out the inputs and the expected result.
        label =
          unwords
            ["reshapeInner", show newInner, show k, show orig, "==", show expected]

-- | @dimFlatten i k w@ is the splice @DimSplice i k (Shape [w])@: it starts at
-- index @i@, spans @k@ dimensions, and has the single dimension @w@ as its
-- replacement (presumably collapsing those @k@ dimensions into @w@).
dimFlatten :: Int -> Int -> d -> DimSplice d
dimFlatten start count size = DimSplice start count $ Shape $ pure size

-- | @dimUnflatten i ws@ is the splice @DimSplice i 1 (Shape ws)@: it starts at
-- index @i@, spans exactly one dimension, and has the dimensions @ws@ as its
-- replacement.
dimUnflatten :: Int -> [d] -> DimSplice d
dimUnflatten at = DimSplice at 1 . Shape

-- | @dimCoerce i w@ is the splice @DimSplice i 1 (Shape [w])@: the single
-- dimension at index @i@ is replaced by the single dimension @w@.
dimCoerce :: Int -> d -> DimSplice d
dimCoerce at = DimSplice at 1 . Shape . pure

-- | @dimSplice i n s@ is the general splice @DimSplice i n (Shape s)@: it
-- starts at index @i@, spans @n@ dimensions, and has the dimensions @s@ as
-- its replacement.
dimSplice :: Int -> Int -> [d] -> DimSplice d
dimSplice at count = DimSplice at count . Shape

-- | Unit tests for 'flipReshapeRearrange'.
--
-- Dimensions are written as symbolic 'String' names.  Each case is
-- @(v0Shape, v1Shape, perm, expected)@ and asserts that
-- @flipReshapeRearrange v0Shape v1Shape perm == expected@.
flipReshapeRearrangeTests :: [TestTree]
flipReshapeRearrangeTests = map mkCase cases
  where
    mkCase :: ([String], [String], [Int], Maybe [Int]) -> TestTree
    mkCase (v0Shape, v1Shape, perm, expected) =
      testCase label $
        flipReshapeRearrange v0Shape v1Shape perm @?= expected
      where
        label =
          unwords
            ["flipReshapeRearrange", show v0Shape, show v1Shape, show perm]

    cases :: [([String], [String], [Int], Maybe [Int])]
    cases =
      [ (["A", "B", "C"], ["A", "BC"], [1, 0], Just [1, 2, 0]),
        (["A", "B", "C", "D"], ["A", "BCD"], [1, 0], Just [1, 2, 3, 0]),
        (["A"], ["B", "C"], [1, 0], Nothing),
        (["A", "B", "C"], ["AB", "C"], [1, 0], Just [2, 0, 1]),
        (["A", "B", "C", "D"], ["ABC", "D"], [1, 0], Just [3, 0, 1, 2])
      ]

-- | Unit tests for 'flipRearrangeReshape'.
--
-- Each case is @(perm, newshape, expected)@ and asserts that
-- @flipRearrangeReshape perm newshape == expected@.  The test name shows the
-- permutation and the reshape rendered on one line with
-- 'prettyStringOneLine'.
flipRearrangeReshapeTests :: [TestTree]
flipRearrangeReshapeTests = map mkCase cases
  where
    mkCase ::
      ([Int], NewShape String, Maybe (NewShape String, [Int])) -> TestTree
    mkCase (perm, newshape, expected) =
      testCase label $
        flipRearrangeReshape perm newshape @?= expected
      where
        label =
          unwords
            ["flipRearrangeReshape", show perm, prettyStringOneLine newshape]

    cases :: [([Int], NewShape String, Maybe (NewShape String, [Int]))]
    cases =
      [ ( [1, 0],
          NewShape [dimUnflatten 1 ["B", "C"]] (Shape ["A", "B", "C"]),
          Just
            ( NewShape [dimUnflatten 0 ["B", "C"]] (Shape ["B", "C", "A"]),
              [2, 0, 1]
            )
        ),
        ( [1, 0],
          NewShape [dimFlatten 0 2 "AB"] (Shape ["AB"]),
          Nothing
        )
      ]

-- | Unit tests for 'simplifyNewShape'.
--
-- Every case starts from a shape of symbolic 'String' dimensions and a list
-- of splices.  It checks the splice list that 'simplifyNewShape' produces, or
-- 'Nothing' where the test expects no simplified result.
simplifyTests :: TestTree
simplifyTests =
  testGroup
    "simplifyNewShape"
    [ testCase "Inverse flatten and unflatten - simple case" $
        simplifySplices
          ["A", "B"]
          [dimFlatten 0 2 "AB", dimUnflatten 0 ["A", "B"]]
          @?= Just [],
      testCase "Non-inverse flatten and unflatten - simple case" $
        simplifySplices
          ["A", "B"]
          [dimFlatten 0 2 "AB", dimUnflatten 0 ["C", "D"]]
          @?= Just [dimSplice 0 2 ["C", "D"]],
      testCase "Inverse flatten and unflatten - separated by coercion" $
        simplifySplices
          ["A", "B"]
          [ dimFlatten 0 2 "AB",
            dimCoerce 0 "CD",
            dimUnflatten 0 ["C", "D"]
          ]
          @?= Just [dimSplice 0 2 ["C", "D"]],
      testCase "Two unflattens - simple case" $
        simplifySplices
          ["ABC"]
          [dimUnflatten 0 ["A", "BC"], dimUnflatten 1 ["B", "C"]]
          @?= Just [dimUnflatten 0 ["A", "B", "C"]],
      -- Both splices merge two dimensions into one (D,E -> DE, then
      -- C,DE -> CDE), so this case exercises two *flattens*.
      testCase "Two flattens with unchanged prefix" $
        simplifySplices
          ["A", "B", "C", "D", "E"]
          [ DimSplice 3 2 $ Shape ["DE"],
            DimSplice 2 2 $ Shape ["CDE"]
          ]
          @?= Just [dimFlatten 2 3 "CDE"],
      testCase "Identity coerce" $
        simplifySplices
          ["A", "B", "C"]
          [dimCoerce 1 "B", dimCoerce 2 "C"]
          @?= Just [],
      testCase "Identity coerce (multiple dimensions)" $
        simplifySplices
          ["A", "B", "C"]
          [DimSplice 1 2 (Shape ["B", "C"])]
          @?= Just [],
      testCase "Identity coerce (with non-identity stuff afterwards)" $
        simplifySplices
          ["B", "CD"]
          [dimCoerce 0 "B", dimUnflatten 1 ["C", "D"]]
          @?= Just [dimUnflatten 1 ["C", "D"]],
      testCase "Get rid of a coerce before an unflatten" $
        simplifySplices
          ["CD"]
          [dimCoerce 0 "AB", dimUnflatten 0 ["A", "B"]]
          @?= Just [dimUnflatten 0 ["A", "B"]],
      testCase "Get rid of a coerce after a flatten" $
        simplifySplices
          ["A", "B", "C"]
          [dimFlatten 0 2 "ABC", dimCoerce 0 "K"]
          @?= Just [dimFlatten 0 2 "K"],
      testCase "Flatten and unflatten (invariant suffix)" $
        simplifySplices
          ["A", "B", "C"]
          [dimFlatten 0 3 "ABC", dimUnflatten 0 ["D", "E", "C"]]
          @?= Just [dimSplice 0 2 ["D", "E"]],
      testCase "Flatten and unflatten (invariant prefix)" $
        simplifySplices
          ["A", "B", "C"]
          [dimFlatten 0 3 "ABC", dimUnflatten 0 ["A", "D", "E"]]
          @?= Just [dimSplice 1 2 ["D", "E"]],
      testCase "Invariant part of splice" $
        simplifySplices
          ["A", "B", "C", "D"]
          [DimSplice 1 3 $ Shape ["BC", "D"]]
          @?= Just [DimSplice 1 2 $ Shape ["BC"]],
      testCase "Necessary coercion" $
        simplifySplices
          ["A", "B"]
          [dimCoerce 0 "C", dimCoerce 1 "D"]
          @?= Nothing,
      testCase "Another necessary coercion" $
        simplifySplices
          ["A", "B", "C"]
          [dimCoerce 0 "A'", dimCoerce 1 "A'", dimCoerce 2 "A'"]
          @?= Nothing,
      testCase "Long with redundancies" $
        simplifySplices
          ["A", "B", "C", "D"]
          [ DimSplice 1 3 $ Shape ["BC", "D"],
            dimCoerce 1 "BC",
            dimCoerce 2 "D",
            dimFlatten 1 2 "BCD",
            dimFlatten 0 2 "ABCD"
          ]
          @?= Just [dimFlatten 0 4 "ABCD"]
    ]
  where
    -- | @simplifySplices origDims splices@ first folds @splices@ left to right
    -- over @Shape origDims@ with 'applySplice' to compute the resulting
    -- shape.  It then runs 'simplifyNewShape' on the corresponding 'NewShape'
    -- and keeps only the simplified splice list.
    simplifySplices :: [String] -> [DimSplice String] -> Maybe [DimSplice String]
    simplifySplices origDims splices =
      let origShape = Shape origDims
          resShape = L.foldl' applySplice origShape splices
       in dimSplices <$> simplifyNewShape origShape (NewShape splices resShape)

-- | All reshape-related tests, collected under a single "ReshapeTests" group.
-- The order is: 'reshapeOuterTests', 'reshapeInnerTests',
-- 'flipReshapeRearrangeTests', 'flipRearrangeReshapeTests', 'simplifyTests'.
tests :: TestTree
tests = testGroup "ReshapeTests" allTests
  where
    allTests =
      concat
        [ reshapeOuterTests,
          reshapeInnerTests,
          flipReshapeRearrangeTests,
          flipRearrangeReshapeTests
        ]
        <> [simplifyTests]