module Futhark.IR.Mem.IxFun.Alg
( IxFun (..),
iota,
offsetIndex,
permute,
reshape,
coerce,
slice,
flatSlice,
expand,
shape,
index,
disjoint,
)
where
import Data.List qualified as L
import Data.Set qualified as S
import Futhark.IR.Pretty ()
import Futhark.IR.Prop
import Futhark.IR.Syntax
( DimIndex (..),
FlatDimIndex (..),
FlatSlice (..),
Slice (..),
flatSliceDims,
sliceDims,
unitSlice,
)
import Futhark.Util.IntegralExp
import Futhark.Util.Pretty
import Prelude hiding (div, mod, span)
-- | The size of each dimension of an index function's domain.
type Shape num = [num]
-- | A multidimensional index, with one component per dimension.
type Indices num = [num]
-- | A reordering of dimension positions, as consumed by 'rearrangeShape'.
type Permutation = [Int]
-- | A symbolic index function: a mapping from a multidimensional index
-- to a single flat offset, represented as a chain of transformations
-- applied on top of a 'Direct' base.  The precise meaning of each
-- constructor is given by 'shape' and 'index'.
data IxFun num
  = -- | The base case: an array of the given shape, indexed using the
    -- sizes produced by 'sliceSizes'.
    Direct (Shape num)
  | -- | Reorder the dimensions of the inner index function.
    Permute (IxFun num) Permutation
  | -- | Apply a multidimensional slice to the inner index function.
    Index (IxFun num) (Slice num)
  | -- | Replace the outermost dimension of the inner index function
    -- with the dimensions of a flat slice.
    FlatIndex (IxFun num) (FlatSlice num)
  | -- | View the inner index function under a new shape; indices are
    -- translated with 'reshapeIndex'.
    Reshape (IxFun num) (Shape num)
  | -- | Give the inner index function a new shape without translating
    -- indices.
    Coerce (IxFun num) (Shape num)
  | -- | Shift indices into the outermost dimension by the given amount;
    -- the shape is unchanged.
    OffsetIndex (IxFun num) num
  | -- | @Expand o p f@ maps an offset @x@ of @f@ to @o + p * x@; the
    -- shape is unchanged.
    Expand num num (IxFun num)
  deriving (Eq, Show)
-- | Renders an index function as its chain of transformations, with the
-- innermost function printed first and each transformation appended
-- after it ('Expand' is the exception and wraps its argument).
instance (Pretty num) => Pretty (IxFun num) where
  pretty ixfun = case ixfun of
    Direct dims -> "Direct" <> parens (commasep (pretty <$> dims))
    Permute inner perm -> pretty inner <> pretty perm
    Index inner sl -> pretty inner <> pretty sl
    FlatIndex inner fsl -> pretty inner <> pretty fsl
    Reshape inner dims -> withSuffix inner "->reshape" dims
    Coerce inner dims -> withSuffix inner "->coerce" dims
    OffsetIndex inner off -> withSuffix inner "->offset_index" off
    Expand o p inner ->
      "expand(" <> pretty o <> "," <+> pretty p <> "," <+> pretty inner <> ")"
    where
      -- Render @inner->op(arg)@.
      withSuffix :: (Pretty a, Pretty b) => a -> Doc ann -> b -> Doc ann
      withSuffix inner op arg = pretty inner <> op <> parens (pretty arg)
-- | The untransformed index function of an array with the given shape.
iota :: Shape num -> IxFun num
iota dims = Direct dims
-- | Shift indices into the outermost dimension by the given amount.
offsetIndex :: IxFun num -> num -> IxFun num
offsetIndex ixfun off = OffsetIndex ixfun off
-- | Reorder the dimensions of an index function.
permute :: IxFun num -> Permutation -> IxFun num
permute ixfun perm = Permute ixfun perm
-- | Apply a multidimensional slice to an index function.
slice :: IxFun num -> Slice num -> IxFun num
slice ixfun sl = Index ixfun sl
-- | Apply a flat slice to the outermost dimension of an index function.
flatSlice :: IxFun num -> FlatSlice num -> IxFun num
flatSlice ixfun fsl = FlatIndex ixfun fsl
-- | @expand o p ixfun@ maps each offset @x@ of @ixfun@ to @o + p * x@.
expand :: num -> num -> IxFun num -> IxFun num
expand o p ixfun = Expand o p ixfun
-- | View an index function under a new shape, translating indices with
-- 'reshapeIndex'.
reshape :: IxFun num -> Shape num -> IxFun num
reshape ixfun newshape = Reshape ixfun newshape
-- | Give an index function a new shape without translating indices.
--
-- This must build a 'Coerce' node rather than a 'Reshape' node: 'index'
-- passes indices through a 'Coerce' unchanged, whereas a 'Reshape'
-- re-maps them with 'reshapeIndex'.  The two therefore denote different
-- offsets whenever the shapes differ.
coerce :: IxFun num -> Shape num -> IxFun num
coerce = Coerce
-- | The shape of an index function's domain, i.e. the dimensions of the
-- indices accepted by 'index'.
shape ::
  (IntegralExp num) =>
  IxFun num ->
  Shape num
shape ixfun = case ixfun of
  Direct dims -> dims
  Reshape _ dims -> dims
  Coerce _ dims -> dims
  Index _ sl -> sliceDims sl
  Permute inner perm -> rearrangeShape perm $ shape inner
  -- The flat slice's dimensions replace the outermost inner dimension.
  FlatIndex inner fsl -> flatSliceDims fsl ++ tail (shape inner)
  OffsetIndex inner _ -> shape inner
  Expand _ _ inner -> shape inner
-- | Map a multidimensional index to the flat offset it denotes under the
-- given index function.
--
-- Calls 'error' when an 'OffsetIndex' wraps an index function of rank
-- zero.
index ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Indices num ->
  num
index ixfun is = case ixfun of
  Direct dims ->
    -- Dot product of the index with @drop 1 (sliceSizes dims)@
    -- (presumably the row-major strides; depends on 'sliceSizes').
    sum [i * stride | (i, stride) <- zip is (drop 1 (sliceSizes dims))]
  Permute inner perm ->
    index inner $ rearrangeShape (rearrangeInverse perm) is
  Index inner (Slice dimIdxs) ->
    index inner $ throughSlice dimIdxs is
  FlatIndex inner (FlatSlice offset fdims) ->
    -- The leading indices collapse into one index for the inner
    -- function's outermost dimension.
    let flat = sum (offset : zipWith scaleBy is fdims)
     in index inner (flat : drop (length fdims) is)
  Reshape inner newshape ->
    index inner $ reshapeIndex (shape inner) newshape is
  Coerce inner _ ->
    index inner is
  OffsetIndex inner off -> case shape inner of
    [] -> error "index: OffsetIndex: underlying index function has rank zero"
    d : ds ->
      -- Equivalent to slicing the outermost dimension from @off@.
      let sl = Slice $ DimSlice off (d - off) 1 : map (unitSlice 0) ds
       in index (Index inner sl) is
  Expand o p inner ->
    o + p * index inner is
  where
    -- Translate indices of a sliced array into indices of the array being
    -- sliced.  A fixed dimension contributes its constant and consumes no
    -- index; a ranged dimension consumes one index.
    throughSlice (DimFix j : rest) ks = j : throughSlice rest ks
    throughSlice (DimSlice start _ step : rest) (k : ks) =
      start + k * step : throughSlice rest ks
    throughSlice _ _ = []

    -- Contribution of one index to a flat slice offset.
    scaleBy i (FlatDimIndex _ s) = i * s
-- | Enumerate every index into an array of the given shape, in row-major
-- order, with the last dimension varying fastest.
--
-- For ordinary integral types, a rank-zero shape yields the single empty
-- index, and a shape containing a zero dimension yields no indices.
allPoints :: (IntegralExp num, Enum num) => [num] -> [[num]]
allPoints dims = map unflatten [0 .. numPoints - 1]
  where
    numPoints = product dims
    -- Row-major strides: each entry is the product of the dimensions to
    -- its right, so the last stride is 1.
    rowStrides = drop 1 . L.reverse . scanl (*) 1 $ L.reverse dims
    -- Peel off one coordinate per stride, threading the remainder along.
    unflatten flat = snd $ L.mapAccumL peel flat rowStrides
    peel rest stride = (rest `mod` stride, rest `div` stride)
-- | Decide, by exhaustive enumeration, whether two index functions map
-- their domains onto disjoint sets of flat offsets.
--
-- Every point of both shapes is visited, so this is practical only for
-- small, concrete shapes.
disjoint :: (IntegralExp num, Ord num, Enum num) => IxFun num -> IxFun num -> Bool
disjoint ixf1 ixf2 = reachable ixf1 `S.disjoint` reachable ixf2
  where
    -- Every flat offset that some point of the domain is mapped to.
    reachable ixf = S.map (index ixf) . S.fromList . allPoints $ shape ixf