-- | Perform index function operations in both algebraic and LMAD
-- representations.
module Futhark.IR.Mem.IxFunWrapper
  ( IxFun,
    iota,
    permute,
    reshape,
    coerce,
    slice,
    flatSlice,
    expand,
  )
where

import Control.Monad (join)
import Futhark.IR.Mem.IxFun.Alg qualified as IA
import Futhark.IR.Mem.LMAD qualified as I
import Futhark.IR.Syntax (FlatSlice, Slice)
import Futhark.Util.IntegralExp

-- | A shape: a list with one entry per dimension.
type Shape num = [num]

-- | A permutation, given as a list of 'Int's (presumably dimension
-- indices -- confirm against 'I.permute' / 'IA.permute').
type Permutation = [Int]

-- | An index function tracked in both representations: an optional LMAD
-- paired with the algebraic index function. The LMAD component becomes
-- 'Nothing' when 'I.reshape' reports failure inside 'reshape'. Every
-- operation in this module that takes an existing 'IxFun' passes a
-- 'Nothing' LMAD through unchanged, while the algebraic component is
-- always updated.
type IxFun num = (Maybe (I.LMAD num), IA.IxFun num)

-- | The identity index function for an array of the given shape, built in
-- both representations. The LMAD component is always present. It is
-- constructed by passing @0@ as the first argument of 'I.iota' (presumably
-- the base offset -- TODO confirm against "Futhark.IR.Mem.LMAD").
iota ::
  (IntegralExp num) =>
  Shape num ->
  IxFun num
iota shape = (Just lmad, IA.iota shape)
  where
    lmad = I.iota 0 shape

-- | Permute the dimensions of both representations. A missing LMAD stays
-- missing; a present one is passed through 'I.permute'.
permute ::
  IxFun num ->
  Permutation ->
  IxFun num
permute (lmad, alg) perm =
  (fmap (`I.permute` perm) lmad, IA.permute alg perm)

-- | Reshape both representations to a new shape.
--
-- 'I.reshape' returns a 'Maybe'. The LMAD component is therefore 'Nothing'
-- afterwards if it was already 'Nothing' or if 'I.reshape' reports failure.
-- The algebraic component is always reshaped.
reshape ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Shape num ->
  IxFun num
reshape (lmad, alg) newShape = (lmad', IA.reshape alg newShape)
  where
    -- Mapping the fallible 'I.reshape' over a 'Maybe' yields a nested
    -- 'Maybe'; 'join' flattens it into a single optional LMAD.
    lmad' = join (fmap (`I.reshape` newShape) lmad)

-- | Coerce both representations to the given shape. A missing LMAD stays
-- missing; a present one is passed through 'I.coerce'.
coerce ::
  IxFun num ->
  Shape num ->
  IxFun num
coerce (lmad, alg) newShape =
  (fmap (`I.coerce` newShape) lmad, IA.coerce alg newShape)

-- | Slice both representations with the same 'Slice'. A missing LMAD stays
-- missing; a present one is passed through 'I.slice'.
slice ::
  (Eq num, IntegralExp num) =>
  IxFun num ->
  Slice num ->
  IxFun num
slice (lmad, alg) sl =
  (fmap (`I.slice` sl) lmad, IA.slice alg sl)

-- | Apply a 'FlatSlice' to both representations. A missing LMAD stays
-- missing; a present one is passed through 'I.flatSlice'.
flatSlice ::
  (IntegralExp num) =>
  IxFun num ->
  FlatSlice num ->
  IxFun num
flatSlice (lmad, alg) fsl =
  (fmap (`I.flatSlice` fsl) lmad, IA.flatSlice alg fsl)

-- | Expand both representations using the same two parameters, which are
-- forwarded unchanged to 'I.expand' and 'IA.expand'. A missing LMAD stays
-- missing.
--
-- The LMAD side is a plain 'fmap'. The previous @Just . f =<< m@ bound the
-- 'Maybe' only to re-wrap the result, which is equivalent to @f <$> m@ by
-- the monad laws.
expand ::
  (IntegralExp num) =>
  num ->
  num ->
  IxFun num ->
  IxFun num
expand o p (lf, af) = (I.expand o p <$> lf, IA.expand o p af)