-- | Facilities for creating, inspecting, and simplifying reshape and
-- coercion operations.
module Futhark.IR.Prop.Reshape
  ( -- * Construction
    shapeCoerce,
    reshapeAll,
    reshapeCoerce,

    -- * Execution
    reshapeOuter,
    reshapeInner,
    newshapeInner,
    applySplice,

    -- * Simplification
    flipReshapeRearrange,
    flipRearrangeReshape,
    simplifyNewShape,

    -- * Shape calculations
    reshapeIndex,
    flattenIndex,
    unflattenIndex,
    sliceSizes,

    -- * Analysis
    ReshapeKind (..),
    reshapeKind,
    newShape,
  )
where

import Control.Monad (guard, mplus)
import Data.Foldable
import Data.Maybe
import Futhark.IR.Prop.Rearrange (isMapTranspose, rearrangeInverse, rearrangeShape)
import Futhark.IR.Syntax
import Futhark.Util (focusNth, mapAccumLM, takeLast)
import Futhark.Util.IntegralExp
import Prelude hiding (product, quot, sum)

-- | Construct a 'NewShape' that completely reshapes the initial shape.
--
-- The result consists of one single 'DimSplice' that starts at dimension 0
-- and replaces all of the original dimensions with @target@.  Of the original
-- shape, only its rank is consulted.
reshapeAll :: (ArrayShape old) => old -> ShapeBase new -> NewShape new
reshapeAll old target = NewShape [wholeSplice] target
  where
    wholeSplice = DimSplice 0 (shapeRank old) target

-- | Construct a 'NewShape' that coerces the shape.
--
-- Each dimension of the given shape gets its own 'DimSplice' at the same
-- position.  That splice replaces exactly one dimension with exactly one new
-- size, so the rank is never changed.
reshapeCoerce :: ShapeBase new -> NewShape new
reshapeCoerce shape =
  NewShape
    [DimSplice pos 1 (Shape [d]) | (pos, d) <- zip [0 ..] (shapeDims shape)]
    shape

-- | Construct a 'Reshape' that is a 'ReshapeCoerce'.
--
-- The array @arr@ is reshaped to @newdims@ using a 'NewShape' built by
-- 'reshapeCoerce', which contains one single-dimension splice per dimension.
shapeCoerce :: [SubExp] -> VName -> Exp rep
shapeCoerce newdims arr =
  let coercion = reshapeCoerce $ Shape newdims
   in BasicOp (Reshape arr coercion)

-- | @reshapeOuter newshape n oldshape@ returns the shape obtained by
-- replacing the outer @n@ dimensions of @oldshape@ with @newshape@.
--
-- The dimensions of @oldshape@ beyond the first @n@ are kept unchanged,
-- after those of @newshape@.
reshapeOuter :: Shape -> Int -> Shape -> Shape
reshapeOuter newshape n oldshape =
  newshape <> Shape (drop n (shapeDims oldshape))

-- | @reshapeInner newshape n oldshape@ produces a shape that keeps the outer
-- @n@ dimensions of @oldshape@ and replaces the remaining inner @m-n@
-- dimensions (where @m@ is the rank of @oldshape@) with @newshape@.
reshapeInner :: Shape -> Int -> Shape -> Shape
reshapeInner newshape n oldshape =
  Shape (take n (shapeDims oldshape)) <> newshape

-- | @newshapeInner outershape newshape@ shifts the starting position of every
-- splice in @newshape@ by the rank of @outershape@, so that the splices act
-- on the inner dimensions of a larger array.  The target shape of
-- @newshape@ is likewise extended with @outershape@ as its outermost part.
newshapeInner :: Shape -> NewShape SubExp -> NewShape SubExp
newshapeInner outershape (NewShape ss oldshape) =
  NewShape shifted (outershape <> oldshape)
  where
    offset = shapeRank outershape
    shifted = [DimSplice (offset + i) k s | DimSplice i k s <- ss]

-- | @reshapeIndex to_dims from_dims is@ transforms the index list @is@,
-- which indexes an array of shape @from_dims@, into an index list @is'@
-- that indexes an array of shape @to_dims@.
--
-- It works by going through the flat index: @is@ is flattened with respect
-- to @from_dims@ and then unflattened with respect to @to_dims@.  @is@ must
-- have the same length as @from_dims@, and @is'@ will have the same length
-- as @to_dims@.
reshapeIndex ::
  (IntegralExp num) =>
  [num] ->
  [num] ->
  [num] ->
  [num]
reshapeIndex to_dims from_dims =
  unflattenIndex to_dims . flattenIndex from_dims

-- | @unflattenIndex dims i@ computes a list of indices into an array with
-- dimensions @dims@, given the flat index @i@.  The resulting list has the
-- same length as @dims@.
--
-- The strides used are the slice sizes of @dims@ without the first one,
-- which is the total size.
unflattenIndex ::
  (IntegralExp num) =>
  [num] ->
  num ->
  [num]
unflattenIndex dims = unflattenIndexFromSlices (drop 1 (sliceSizes dims))

-- | Decompose the flat index @i@ using the given list of slice sizes
-- (strides), outermost first.
--
-- Each component is the quotient of the remaining index by the current
-- stride.  The remainder, computed as @i - q * stride@, is passed on to the
-- next stride.  The result has one element per stride.
unflattenIndexFromSlices ::
  (IntegralExp num) =>
  [num] ->
  num ->
  [num]
unflattenIndexFromSlices [] _ = []
unflattenIndexFromSlices (stride : strides) i =
  q : unflattenIndexFromSlices strides (i - q * stride)
  where
    q = i `quot` stride

-- | @flattenIndex dims is@ computes the flat index of @is@ into an array with
-- dimensions @dims@.
--
-- Each index is multiplied by the size of the slice beneath it, and the
-- products are summed.  The lengths of @dims@ and @is@ must be the same;
-- otherwise this calls 'error'.
flattenIndex ::
  (IntegralExp num) =>
  [num] ->
  [num] ->
  num
flattenIndex dims is =
  if length strides == length is
    then sum (zipWith (*) is strides)
    else error "flattenIndex: length mismatch"
  where
    -- Drop the first slice size, which is the size of the whole array.
    strides = drop 1 (sliceSizes dims)

-- | Given a length @n@ list of dimensions @dims@, @sliceSizes dims@ computes
-- a length @n+1@ list of the size of each possible array slice.
--
-- Element @k@ is the product of the dimensions from position @k@ onwards.
-- The first element is therefore the product of @dims@, and the last
-- element is 1.
sliceSizes ::
  (IntegralExp num) =>
  [num] ->
  [num]
sliceSizes dims = case dims of
  [] -> [1]
  -- NOTE(review): each suffix product is recomputed with 'product' instead of
  -- being accumulated (e.g. via 'scanr').  The HLINT pragma below suggests
  -- this is deliberate, presumably to preserve the exact form of the
  -- resulting expressions for symbolic @num@.  Confirm before changing.
  _ : rest -> product dims : sliceSizes rest

{- HLINT ignore sliceSizes -}

-- | Interchange a reshape and rearrange. Essentially, rewrite composition
--
-- @
-- let v1 = reshape(v0, v1_shape)
-- let v2 = rearrange(v1, perm)
-- @
--
-- into
--
-- @
-- let v1' = rearrange(v0, perm')
-- let v2' = reshape(v1', v1_shape')
-- @
--
-- The function is given the shape of @v0@, @v1@, and the @perm@, and returns
-- @perm'@. This is a meaningful operation when @v2@ is itself reshaped, as the
-- reshape-reshape can be fused. This can significantly simplify long chains of
-- reshapes and rearranges.
--
-- Only a @perm@ that 'isMapTranspose' decomposes into leading \"map\"
-- dimensions followed by exactly one @a@ dimension and exactly one @b@
-- dimension is handled.  The map dimensions must be identical in @v0@ and
-- @v1@.  Two cases are then tried in order:
--
-- * the @a@ dimension (the first non-map dimension) is the same in @v0@ and
--   @v1@;
--
-- * the @b@ dimension (the innermost dimension) is the same in @v0@ and
--   @v1@.
--
-- If neither case applies, the result is 'Nothing'.
flipReshapeRearrange ::
  (Eq d) =>
  [d] ->
  [d] ->
  [Int] ->
  Maybe [Int]
flipReshapeRearrange v0_shape v1_shape perm = do
  (num_map_dims, num_a_dims, num_b_dims) <- isMapTranspose perm
  guard $ num_a_dims == 1
  guard $ num_b_dims == 1
  let map_dims = take num_map_dims v0_shape
      num_b_dims_expanded = length v0_shape - num_map_dims - num_a_dims
      num_a_dims_expanded = length v0_shape - num_map_dims - num_b_dims
      -- The dimensions that follow the map dimensions.
      nonMapDims = drop num_map_dims
      caseA = do
        -- The @a@ dimension comes right after the map dimensions in both
        -- shapes, so that is where it must be compared.  Comparing the
        -- outermost dimensions instead would merely re-check a map dimension
        -- whenever there is at least one.
        guard $
          take num_a_dims (nonMapDims v0_shape)
            == take num_a_dims (nonMapDims v1_shape)
        -- Move the @a@ dimension of @v0@ behind the dimensions that are
        -- subsequently reshaped into @b@.
        Just $
          [0 .. num_map_dims - 1]
            ++ map (+ num_map_dims) ([1 .. num_b_dims_expanded] ++ [0])
      caseB = do
        -- The @b@ dimension is innermost in both shapes.
        guard $ takeLast num_b_dims v0_shape == takeLast num_b_dims v1_shape
        -- Move the @b@ dimension of @v0@ in front of the dimensions that are
        -- subsequently reshaped into @a@.
        Just $
          [0 .. num_map_dims - 1]
            ++ map
              (+ num_map_dims)
              (num_a_dims_expanded : [0 .. num_a_dims_expanded - 1])

  guard $ map_dims == take num_map_dims v1_shape

  caseA `mplus` caseB

-- | Interchange a reshape and rearrange. Essentially, rewrite composition
--
-- @
-- let v1 = rearrange(v0, perm)
-- let v2 = reshape(v1, v1_shape)
-- @
--
-- into
--
-- @
-- let v1' = reshape(v0, v1_shape')
-- let v2' = rearrange(v1', perm')
-- @
--
-- The function is given @perm@ and @v1_shape@, and returns @v1_shape'@ and
-- @perm'@. This is a meaningful operation when @v2@ is itself rearranged
-- (or @v0@ the result of a reshape), as this enables fusion.
--
-- Only splices that replace exactly one dimension are supported.  Any other
-- splice, or a failure of 'focusNth' on a splice position, makes the result
-- 'Nothing'.
flipRearrangeReshape :: [Int] -> NewShape d -> Maybe (NewShape d, [Int])
flipRearrangeReshape orig_perm (NewShape ss shape) = do
  (final_perm, ss') <- mapAccumLM step orig_perm ss
  let dims' = rearrangeShape (rearrangeInverse final_perm) (shapeDims shape)
  pure (NewShape ss' (Shape dims'), final_perm)
  where
    -- Thread the permutation through the splices.  The splice at position
    -- @i@ is retargeted to position @j@, the permutation entry at @i@.  The
    -- permutation is then widened: @j@ expands into @length s@ consecutive
    -- entries, and every entry above @j@ is shifted up accordingly.
    step perm (DimSplice i n s)
      | n == 1 = do
          (before, j, after) <- focusNth i perm
          let width = length s
              shift l
                | l > j = l + width - 1
                | otherwise = l
          pure
            ( map shift before ++ [j .. j + width - 1] ++ map shift after,
              DimSplice j 1 s
            )
    step _ _ = Nothing

-- | Which kind of reshape is this?  The classification is computed by
-- 'reshapeKind'.  Constructor order matters for the derived 'Ord' instance.
data ReshapeKind
  = ReshapeCoerce -- ^ New shape is dynamically same as original.
  | ReshapeArbitrary -- ^ Any kind of reshaping.
  deriving (Show, Eq, Ord)

-- | Determine whether this might be a coercion.
--
-- A 'NewShape' is classified as 'ReshapeCoerce' when every one of its
-- splices replaces exactly one dimension with exactly one dimension; this
-- includes the case of no splices at all.  Otherwise it is
-- 'ReshapeArbitrary'.  The target shape is not inspected.
reshapeKind :: NewShape SubExp -> ReshapeKind
reshapeKind (NewShape ss _) =
  if any nonUnit ss then ReshapeArbitrary else ReshapeCoerce
  where
    nonUnit (DimSplice _ 1 (Shape [_])) = False
    nonUnit _ = True

-- | Apply the splice to a shape.
--
-- @DimSplice i k s@ keeps the first @i@ dimensions, replaces the next @k@
-- dimensions with the dimensions of @s@, and keeps everything after them.
applySplice :: ShapeBase d -> DimSplice d -> ShapeBase d
applySplice base (DimSplice i k replacement) = prefix <> replacement <> suffix
  where
    prefix = takeDims i base
    suffix = stripDims (i + k) base

-- | @dimSpan i n s@ gets @n@ dimensions starting from @i@ from @s@: the first
-- @i@ dimensions are dropped and the next @n@ are kept.
dimSpan :: Int -> Int -> ShapeBase d -> ShapeBase d
dimSpan i n s = takeDims n (dropDims i s)

-- | Keep the splice @x@ where it is, and continue trying to simplify the
-- splice @y@ against the remaining splices @ss@.  The shape used from then
-- on is the one obtained by applying @x@ to @shape@.  On success, @x@ is
-- prepended to the result of 'move'; otherwise the result is 'Nothing'.
next ::
  (Eq d) =>
  ShapeBase d ->
  DimSplice d ->
  DimSplice d ->
  [DimSplice d] ->
  Maybe [DimSplice d]
next shape x y ss = do
  rest <- move (applySplice shape x, y) ss
  pure (x : rest)

move ::
  (Eq d) =>
  (ShapeBase d, DimSplice d) ->
  [DimSplice d] ->
  Maybe [DimSplice d]
--
-- A coercion that does not do anything.
move :: forall d.
Eq d =>
(ShapeBase d, DimSplice d) -> [DimSplice d] -> Maybe [DimSplice d]
move (ShapeBase d
shape_bef, DimSplice Int
i1 Int
n1 ShapeBase d
shape) [DimSplice d]
ss
  | Int -> Int -> ShapeBase d -> ShapeBase d
forall d. Int -> Int -> ShapeBase d -> ShapeBase d
dimSpan Int
i1 Int
n1 ShapeBase d
shape_bef ShapeBase d -> ShapeBase d -> Bool
forall a. Eq a => a -> a -> Bool
== ShapeBase d
shape =
      [DimSplice d] -> Maybe [DimSplice d]
forall a. a -> Maybe a
Just [DimSplice d]
ss
--
-- See if we can find some redundancy.
move (ShapeBase d
shape, DimSplice Int
i1 Int
n1 ShapeBase d
s1) [DimSplice d]
ss
  -- Check for redundant prefix.
  | [(d, d)]
match <-
      ((d, d) -> Bool) -> [(d, d)] -> [(d, d)]
forall a. (a -> Bool) -> [a] -> [a]
takeWhile ((d -> d -> Bool) -> (d, d) -> Bool
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry d -> d -> Bool
forall a. Eq a => a -> a -> Bool
(==)) ([(d, d)] -> [(d, d)]) -> [(d, d)] -> [(d, d)]
forall a b. (a -> b) -> a -> b
$
        [d] -> [d] -> [(d, d)]
forall a b. [a] -> [b] -> [(a, b)]
zip (ShapeBase d -> [d]
forall d. ShapeBase d -> [d]
shapeDims (Int -> Int -> ShapeBase d -> ShapeBase d
forall d. Int -> Int -> ShapeBase d -> ShapeBase d
dimSpan Int
i1 Int
n1 ShapeBase d
shape)) (ShapeBase d -> [d]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase d
s1),
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [(d, d)] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(d, d)]
match,
    [(d, d)] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(d, d)]
match Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
n1 =
      let k :: Int
k = [(d, d)] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(d, d)]
match
       in [DimSplice d] -> Maybe [DimSplice d]
forall a. a -> Maybe a
Just ([DimSplice d] -> Maybe [DimSplice d])
-> [DimSplice d] -> Maybe [DimSplice d]
forall a b. (a -> b) -> a -> b
$ Int -> Int -> ShapeBase d -> DimSplice d
forall d. Int -> Int -> ShapeBase d -> DimSplice d
DimSplice (Int
i1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
k) (Int
n1 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
k) (Int -> ShapeBase d -> ShapeBase d
forall d. Int -> ShapeBase d -> ShapeBase d
dropDims Int
k ShapeBase d
s1) DimSplice d -> [DimSplice d] -> [DimSplice d]
forall a. a -> [a] -> [a]
: [DimSplice d]
ss
  -- Check for redundant suffix.
  | [(d, d)]
match <-
      ((d, d) -> Bool) -> [(d, d)] -> [(d, d)]
forall a. (a -> Bool) -> [a] -> [a]
takeWhile ((d -> d -> Bool) -> (d, d) -> Bool
forall a b c. (a -> b -> c) -> (a, b) -> c
uncurry d -> d -> Bool
forall a. Eq a => a -> a -> Bool
(==)) ([(d, d)] -> [(d, d)]) -> [(d, d)] -> [(d, d)]
forall a b. (a -> b) -> a -> b
$
        [d] -> [d] -> [(d, d)]
forall a b. [a] -> [b] -> [(a, b)]
zip
          ([d] -> [d]
forall a. [a] -> [a]
reverse (ShapeBase d -> [d]
forall d. ShapeBase d -> [d]
shapeDims (Int -> Int -> ShapeBase d -> ShapeBase d
forall d. Int -> Int -> ShapeBase d -> ShapeBase d
dimSpan Int
i1 Int
n1 ShapeBase d
shape)))
          ([d] -> [d]
forall a. [a] -> [a]
reverse (ShapeBase d -> [d]
forall d. ShapeBase d -> [d]
shapeDims ShapeBase d
s1)),
    Bool -> Bool
not (Bool -> Bool) -> Bool -> Bool
forall a b. (a -> b) -> a -> b
$ [(d, d)] -> Bool
forall a. [a] -> Bool
forall (t :: * -> *) a. Foldable t => t a -> Bool
null [(d, d)]
match,
    [(d, d)] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(d, d)]
match Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
/= Int
n1 =
      let k :: Int
k = [(d, d)] -> Int
forall a. [a] -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length [(d, d)]
match
       in [DimSplice d] -> Maybe [DimSplice d]
forall a. a -> Maybe a
Just ([DimSplice d] -> Maybe [DimSplice d])
-> [DimSplice d] -> Maybe [DimSplice d]
forall a b. (a -> b) -> a -> b
$ Int -> Int -> ShapeBase d -> DimSplice d
forall d. Int -> Int -> ShapeBase d -> DimSplice d
DimSplice Int
i1 (Int
n1 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
k) (Int -> ShapeBase d -> ShapeBase d
forall d. Int -> ShapeBase d -> ShapeBase d
takeDims (ShapeBase d -> Int
forall a. ShapeBase a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeBase d
s1 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
k) ShapeBase d
s1) DimSplice d -> [DimSplice d] -> [DimSplice d]
forall a. a -> [a] -> [a]
: [DimSplice d]
ss
--
-- Base case.
move (ShapeBase d, DimSplice d)
_ [] = Maybe [DimSplice d]
forall a. Maybe a
Nothing
--
-- A coercion can be fused with anything.
move (ShapeBase d
_, DimSplice Int
i1 Int
1 (Shape [d
_])) (DimSplice Int
i2 Int
n2 ShapeBase d
s2 : [DimSplice d]
ss)
  | Int
i1 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
i2 =
      [DimSplice d] -> Maybe [DimSplice d]
forall a. a -> Maybe a
Just ([DimSplice d] -> Maybe [DimSplice d])
-> [DimSplice d] -> Maybe [DimSplice d]
forall a b. (a -> b) -> a -> b
$ Int -> Int -> ShapeBase d -> DimSplice d
forall d. Int -> Int -> ShapeBase d -> DimSplice d
DimSplice Int
i2 Int
n2 ShapeBase d
s2 DimSplice d -> [DimSplice d] -> [DimSplice d]
forall a. a -> [a] -> [a]
: [DimSplice d]
ss
--
-- A flatten with an inverse unflatten turns into nothing.
move (ShapeBase d
shape_bef, DimSplice Int
i1 Int
n1 ShapeBase d
_s1) (DimSplice Int
i2 Int
_n2 ShapeBase d
s2 : [DimSplice d]
ss)
  | Int
i1 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
i2,
    Int -> Int -> ShapeBase d -> ShapeBase d
forall d. Int -> Int -> ShapeBase d -> ShapeBase d
dimSpan Int
i1 Int
n1 ShapeBase d
shape_bef ShapeBase d -> ShapeBase d -> Bool
forall a. Eq a => a -> a -> Bool
== ShapeBase d
s2 =
      [DimSplice d] -> Maybe [DimSplice d]
forall a. a -> Maybe a
Just [DimSplice d]
ss
--
-- An unflatten where one of the dimensions is then further unflattened.
move (ShapeBase d
_, DimSplice Int
i1 Int
n1 ShapeBase d
s1) (DimSplice Int
i2 Int
n2 ShapeBase d
s2 : [DimSplice d]
ss)
  | Int
i2 Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
>= Int
i1,
    Int
i2 Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
< Int
i1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ ShapeBase d -> Int
forall a. ShapeBase a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeBase d
s1,
    Int
n1 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
    Int
n2 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1 =
      [DimSplice d] -> Maybe [DimSplice d]
forall a. a -> Maybe a
Just ([DimSplice d] -> Maybe [DimSplice d])
-> [DimSplice d] -> Maybe [DimSplice d]
forall a b. (a -> b) -> a -> b
$ Int -> Int -> ShapeBase d -> DimSplice d
forall d. Int -> Int -> ShapeBase d -> DimSplice d
DimSplice Int
i1 Int
1 (ShapeBase d
s1_bef ShapeBase d -> ShapeBase d -> ShapeBase d
forall a. Semigroup a => a -> a -> a
<> ShapeBase d
s2 ShapeBase d -> ShapeBase d -> ShapeBase d
forall a. Semigroup a => a -> a -> a
<> ShapeBase d
s1_aft) DimSplice d -> [DimSplice d] -> [DimSplice d]
forall a. a -> [a] -> [a]
: [DimSplice d]
ss
  where
    s1_bef :: ShapeBase d
s1_bef = Int -> ShapeBase d -> ShapeBase d
forall d. Int -> ShapeBase d -> ShapeBase d
takeDims (Int
i2 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
i1) ShapeBase d
s1
    s1_aft :: ShapeBase d
s1_aft = Int -> ShapeBase d -> ShapeBase d
forall d. Int -> ShapeBase d -> ShapeBase d
dropDims (Int
i2 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
i1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1) ShapeBase d
s1

--
-- Flatten followed by a flattening of overlapping dimensions.
move (ShapeBase d
_, DimSplice Int
i1 Int
n1 ShapeBase d
s1) (DimSplice Int
i2 Int
n2 ShapeBase d
s2 : [DimSplice d]
ss)
  | ShapeBase d -> Int
forall a. ShapeBase a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeBase d
s1 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
    ShapeBase d -> Int
forall a. ShapeBase a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeBase d
s2 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
1,
    Int
i1 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
i2 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
1,
    Int
n2 Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
1 =
      [DimSplice d] -> Maybe [DimSplice d]
forall a. a -> Maybe a
Just ([DimSplice d] -> Maybe [DimSplice d])
-> [DimSplice d] -> Maybe [DimSplice d]
forall a b. (a -> b) -> a -> b
$ Int -> Int -> ShapeBase d -> DimSplice d
forall d. Int -> Int -> ShapeBase d -> DimSplice d
DimSplice Int
i2 (Int
n2 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
n1 Int -> Int -> Int
forall a. Num a => a -> a -> a
- ShapeBase d -> Int
forall a. ShapeBase a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeBase d
s1) ShapeBase d
s2 DimSplice d -> [DimSplice d] -> [DimSplice d]
forall a. a -> [a] -> [a]
: [DimSplice d]
ss
--
-- Flatten into an unflatten.
move (ShapeBase d
_, DimSplice Int
i1 Int
n1 (Shape [d
_])) (DimSplice Int
i2 Int
1 ShapeBase d
s2 : [DimSplice d]
ss)
  | Int
i1 Int -> Int -> Bool
forall a. Eq a => a -> a -> Bool
== Int
i2 =
      [DimSplice d] -> Maybe [DimSplice d]
forall a. a -> Maybe a
Just ([DimSplice d] -> Maybe [DimSplice d])
-> [DimSplice d] -> Maybe [DimSplice d]
forall a b. (a -> b) -> a -> b
$ Int -> Int -> ShapeBase d -> DimSplice d
forall d. Int -> Int -> ShapeBase d -> DimSplice d
DimSplice Int
i1 Int
n1 ShapeBase d
s2 DimSplice d -> [DimSplice d] -> [DimSplice d]
forall a. a -> [a] -> [a]
: [DimSplice d]
ss
--
-- These cases are for updating dimensions as we move across intervening
-- operations.
move (ShapeBase d
shape, DimSplice Int
i1 Int
n1 ShapeBase d
s1) (DimSplice Int
i2 Int
n2 ShapeBase d
s2 : [DimSplice d]
ss)
  | Int
i1 Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
i2 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
n2 =
      ShapeBase d
-> DimSplice d
-> DimSplice d
-> [DimSplice d]
-> Maybe [DimSplice d]
forall d.
Eq d =>
ShapeBase d
-> DimSplice d
-> DimSplice d
-> [DimSplice d]
-> Maybe [DimSplice d]
next ShapeBase d
shape (Int -> Int -> ShapeBase d -> DimSplice d
forall d. Int -> Int -> ShapeBase d -> DimSplice d
DimSplice Int
i2 Int
n2 ShapeBase d
s2) (Int -> Int -> ShapeBase d -> DimSplice d
forall d. Int -> Int -> ShapeBase d -> DimSplice d
DimSplice (Int
i1 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
n2 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ ShapeBase d -> Int
forall a. ShapeBase a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeBase d
s2) Int
n1 ShapeBase d
s1) [DimSplice d]
ss
  | Int
i2 Int -> Int -> Bool
forall a. Ord a => a -> a -> Bool
> Int
i1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ Int
n1 =
      ShapeBase d
-> DimSplice d
-> DimSplice d
-> [DimSplice d]
-> Maybe [DimSplice d]
forall d.
Eq d =>
ShapeBase d
-> DimSplice d
-> DimSplice d
-> [DimSplice d]
-> Maybe [DimSplice d]
next ShapeBase d
shape (Int -> Int -> ShapeBase d -> DimSplice d
forall d. Int -> Int -> ShapeBase d -> DimSplice d
DimSplice (Int
i2 Int -> Int -> Int
forall a. Num a => a -> a -> a
- Int
n1 Int -> Int -> Int
forall a. Num a => a -> a -> a
+ ShapeBase d -> Int
forall a. ShapeBase a -> Int
forall (t :: * -> *) a. Foldable t => t a -> Int
length ShapeBase d
s1) Int
n2 ShapeBase d
s2) (Int -> Int -> ShapeBase d -> DimSplice d
forall d. Int -> Int -> ShapeBase d -> DimSplice d
DimSplice Int
i1 Int
n2 ShapeBase d
s1) [DimSplice d]
ss
  | Bool
otherwise = Maybe [DimSplice d]
forall a. Maybe a
Nothing

-- Search, left to right, for a splice that 'move' can combine with, or shift
-- across, the splices that follow it.
--
-- Each splice is handed to 'move' together with the shape it is applied to.
-- That shape is threaded along with 'applySplice' as the list is walked. The
-- first position where 'move' succeeds determines the result. If no position
-- succeeds, the result is 'Nothing'.
--
-- This takes quadratic time. Splice lists are usually small, so that is not a
-- problem in practice. 'simplifyNewShape' applies this function repeatedly
-- until it returns 'Nothing'.
improveOne :: (Eq d) => ShapeBase d -> [DimSplice d] -> Maybe [DimSplice d]
improveOne shape splices =
  case splices of
    [] -> Nothing
    s : rest ->
      let -- Try to rewrite starting at the current splice.
          here = move (shape, s) rest
          -- Otherwise keep this splice and look further along, using the
          -- shape that results from applying it.
          later = (s :) <$> improveOne (applySplice shape s) rest
       in here `mplus` later

-- | Try to simplify the given 'NewShape'. Returns 'Nothing' if no improvement
-- is possible.
--
-- The first argument is the shape that the 'NewShape' is applied to. Once one
-- improvement step succeeds, further steps are applied until none remains. The
-- target shape stored in the 'NewShape' is kept as is.
simplifyNewShape :: (Eq d) => ShapeBase d -> NewShape d -> Maybe (NewShape d)
simplifyNewShape shape_bef (NewShape ss shape) = do
  -- Bail out with 'Nothing' when not even a single step applies.
  ss_step <- improveOne shape_bef ss
  pure $ NewShape (converge ss_step) shape
  where
    -- Keep applying 'improveOne' until it reports that nothing changed.
    converge ss' = maybe ss' converge $ improveOne shape_bef ss'

-- NOINLINE pragmas: GHC will not inline the named top-level functions at
-- their call sites.
-- NOTE(review): the motivation (e.g. limiting code size or compile time) is
-- not visible here; confirm before removing.
{-# NOINLINE flipReshapeRearrange #-}

{-# NOINLINE flipRearrangeReshape #-}

{-# NOINLINE reshapeKind #-}

{-# NOINLINE simplifyNewShape #-}

{-# NOINLINE newshapeInner #-}