-- | A simple, algebraic representation of index functions, in which every
-- index-function operation corresponds to a constructor of 'IxFun'.
module Futhark.IR.Mem.IxFun.Alg
  ( IxFun (..),
    iota,
    offsetIndex,
    permute,
    reshape,
    coerce,
    slice,
    flatSlice,
    expand,
    shape,
    index,
    disjoint,
  )
where

import Data.List qualified as L
import Data.Set qualified as S
import Futhark.IR.Pretty ()
import Futhark.IR.Prop
import Futhark.IR.Syntax
  ( DimIndex (..),
    FlatDimIndex (..),
    FlatSlice (..),
    Slice (..),
    flatSliceDims,
    sliceDims,
    unitSlice,
  )
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (div, mod, span)

-- | The size of each dimension of an index space.
type Shape num = [num]

-- | One index per dimension of an index space.
type Indices num = [num]

-- | A reordering of dimensions, as consumed by 'rearrangeShape'.
type Permutation = [Int]

-- | An index function, represented as a chain of operations on top of a
-- 'Direct' base.  'index' interprets it by recursing from the outermost
-- operation inwards, rewriting the indices at each step.
data IxFun num
  = -- | The base case: an array of the given shape with no transformation.
    Direct (Shape num)
  | -- | Rearrange the dimensions of the inner index function.
    Permute (IxFun num) Permutation
  | -- | Apply a slice to the inner index function.
    Index (IxFun num) (Slice num)
  | -- | Apply a flat slice; its dimensions replace the first dimension of
    -- the inner index function.
    FlatIndex (IxFun num) (FlatSlice num)
  | -- | Reshape the inner index function to the given (new) shape.
    Reshape (IxFun num) (Shape num)
  | -- | Report the given shape while passing indices through unchanged.
    Coerce (IxFun num) (Shape num)
  | -- | Shift indexing of the first dimension by the given offset.
    OffsetIndex (IxFun num) num
  | -- | @Expand o p f@ maps indices @is@ to @o + p * index f is@.
    Expand num num (IxFun num)
  deriving (Eq, Show)

-- | Renders the base of the chain first and each wrapping operation as a
-- suffix; 'Expand' is rendered in prefix form as @expand(o, p, f)@.
instance (Pretty num) => Pretty (IxFun num) where
  pretty ixfun = case ixfun of
    Direct dims -> "Direct" <> parens (commasep $ map pretty dims)
    Permute inner perm -> pretty inner <> pretty perm
    Index inner sl -> pretty inner <> pretty sl
    FlatIndex inner fsl -> pretty inner <> pretty fsl
    Reshape inner dims -> postfix inner "->reshape" (pretty dims)
    Coerce inner dims -> postfix inner "->coerce" (pretty dims)
    OffsetIndex inner off -> postfix inner "->offset_index" (pretty off)
    Expand o p inner ->
      "expand(" <> pretty o <> "," <+> pretty p <> "," <+> pretty inner <> ")"
    where
      -- Inner function, then the operation name, then its parenthesised argument.
      postfix inner name arg = pretty inner <> name <> parens arg

-- | The untransformed index function of an array with the given shape.
iota :: Shape num -> IxFun num
iota = Direct

-- | Shift indexing of the first dimension of an index function by an offset.
offsetIndex :: IxFun num -> num -> IxFun num
offsetIndex = OffsetIndex

-- | Rearrange the dimensions of an index function by a permutation.
permute :: IxFun num -> Permutation -> IxFun num
permute = Permute

-- | Apply a slice to an index function.
slice :: IxFun num -> Slice num -> IxFun num
slice = Index

-- | Apply a flat slice to the first dimension of an index function.
flatSlice :: IxFun num -> FlatSlice num -> IxFun num
flatSlice = FlatIndex

-- | @expand o p f@ maps every offset @x@ produced by @f@ to @o + p * x@.
expand :: num -> num -> IxFun num -> IxFun num
expand = Expand

-- | Reshape an index function to a new shape.  When indexing, the new
-- indices are translated to the old shape via 'reshapeIndex'.
reshape :: IxFun num -> Shape num -> IxFun num
reshape = Reshape

-- | Give an index function a new shape without changing how indices are
-- mapped: 'index' passes indices straight through to the inner function.
-- NOTE(review): presumably the new shape must dynamically equal the old
-- one; confirm against callers.
--
-- Previously this built a 'Reshape', which left the 'Coerce' constructor
-- unused and routed indices through 'reshapeIndex'.
coerce :: IxFun num -> Shape num -> IxFun num
coerce = Coerce

-- | The shape an index function presents, i.e. the number of dimensions
-- (and their sizes) of the index space it is indexed with.
shape ::
  (IntegralExp num) =>
  IxFun num ->
  Shape num
shape (Direct dims) = dims
shape (Permute ixfun perm) = rearrangeShape perm $ shape ixfun
shape (Index _ how) = sliceDims how
shape (FlatIndex ixfun how) =
  -- The flat slice's dimensions replace the first dimension of the inner
  -- function.  'drop 1' keeps this total where 'tail' would crash on rank 0.
  flatSliceDims how <> drop 1 (shape ixfun)
shape (Reshape _ dims) = dims
shape (Coerce _ dims) = dims
shape (OffsetIndex ixfun _) = shape ixfun
shape (Expand _ _ ixfun) = shape ixfun

-- | Compute the flat offset that an index function maps the given indices to.
-- Each operation rewrites the indices before handing them to the inner
-- index function; the 'Direct' base finally combines them with the strides
-- derived from 'sliceSizes'.
--
-- Calls 'error' when an 'OffsetIndex' wraps an index function of rank zero.
index ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Indices num ->
  num
index ixfun idxs = case ixfun of
  Direct dims ->
    sum (zipWith (*) idxs (drop 1 (sliceSizes dims)))
  Permute inner perm ->
    index inner (rearrangeShape (rearrangeInverse perm) idxs)
  Index inner (Slice dimidxs) ->
    index inner (sliceIndices dimidxs idxs)
  FlatIndex inner (FlatSlice offset fdims) ->
    -- The leading indices collapse into a single index of the inner
    -- function's first dimension; the remaining indices pass through.
    let flat = sum (offset : [i * s | (i, FlatDimIndex _ s) <- zip idxs fdims])
     in index inner (flat : drop (length fdims) idxs)
  Reshape inner newshape ->
    index inner (reshapeIndex (shape inner) newshape idxs)
  Coerce inner _ ->
    index inner idxs
  OffsetIndex inner off ->
    case shape inner of
      [] -> error "index: OffsetIndex: underlying index function has rank zero"
      d : ds ->
        -- Express the offset as a slice starting at 'off' in the first
        -- dimension and covering the other dimensions in full.
        let sl = Slice $ DimSlice off (d - off) 1 : map (unitSlice 0) ds
         in index (Index inner sl) idxs
  Expand o p inner ->
    o + p * index inner idxs
  where
    -- Translate indices into a slice's result to indices of the sliced
    -- function: fixed dimensions contribute their constant, while sliced
    -- dimensions consume one index each and scale it by the stride.
    sliceIndices (DimFix j : rest) ks = j : sliceIndices rest ks
    sliceIndices (DimSlice j _ s : rest) (k : ks) = j + k * s : sliceIndices rest ks
    sliceIndices _ _ = []

-- | Enumerate every index tuple of the given shape, with the last dimension
-- varying fastest.  A rank-zero shape yields the single empty tuple.
allPoints :: (IntegralExp num, Enum num) => [num] -> [[num]]
allPoints dims =
  let total = product dims
      -- Stride of each dimension: the product of all dimensions after it.
      strides = drop 1 $ L.reverse $ scanl (*) 1 $ L.reverse dims
   in map (unflatInd strides) [0 .. total - 1]
  where
    -- Decompose a flat position into per-dimension indices, first
    -- dimension first.  'mapAccumL' threads the remainder through the
    -- strides, avoiding the lazy 'foldl' and quadratic '++' of an
    -- accumulate-and-append loop.
    unflatInd strides x = snd $ L.mapAccumL step x strides
    step acc stride = (acc `mod` stride, acc `div` stride)

-- | Decide by exhaustive enumeration whether two index functions map their
-- index spaces to disjoint sets of flat offsets.  Every point of both shapes
-- is enumerated, so the cost grows with the product of the dimensions.
disjoint :: (IntegralExp num, Ord num, Enum num) => IxFun num -> IxFun num -> Bool
disjoint ixf1 ixf2 = image ixf1 `S.disjoint` image ixf2
  where
    -- The set of offsets produced by indexing at every point of the shape.
    image ixf = S.map (index ixf) $ S.fromList $ allPoints $ shape ixf