module Futhark.Optimise.ArrayLayout
  ( optimiseArrayLayoutGPU,
    optimiseArrayLayoutMC,
  )
where

import Control.Monad.State.Strict
import Futhark.Analysis.AccessPattern (Analyse, analyseDimAccesses)
import Futhark.Analysis.PrimExp.Table (primExpTable)
import Futhark.Builder
import Futhark.IR.GPU (GPU)
import Futhark.IR.MC (MC)
import Futhark.Optimise.ArrayLayout.Layout (layoutTableFromIndexTable)
import Futhark.Optimise.ArrayLayout.Transform (Transform, transformStms)
import Futhark.Pass

-- | Construct the array-layout optimisation pass for a representation.
--
-- The string argument is appended to the pass name
-- (@"optimise array layout " <> s@), which distinguishes the
-- instantiations for different representations. The pass:
--
--   1. analyses the array index accesses of the program
--      ('analyseDimAccesses');
--   2. computes 'PrimExp's for all variables ('primExpTable');
--   3. derives a layout table (the permutations to apply) from the
--      index table ('layoutTableFromIndexTable');
--   4. rewrites the statements of each function to insert those
--      permutations ('transformStms', applied through
--      'intraproceduralTransformation').
optimiseArrayLayout ::
  (Analyse rep, Transform rep, BuilderOps rep) =>
  String ->
  Pass rep rep
optimiseArrayLayout s =
  Pass
    ("optimise array layout " <> s)
    "Transform array layout for locality optimisations."
    $ \prog -> do
      -- Analyse the program
      let index_table = analyseDimAccesses prog
      -- Compute primExps for all variables
      let table = primExpTable prog
      -- Compute permutations to achieve coalescence for all arrays
      let permutation_table = layoutTableFromIndexTable table index_table
      -- Insert permutations in the AST
      intraproceduralTransformation (onStms permutation_table) prog
  where
    -- Run 'transformStms' (starting from an empty expression map) inside a
    -- builder over the given scope, threading the fresh-name source through
    -- 'modifyNameSource'.
    onStms layout_table scope stms = do
      let m = transformStms layout_table mempty stms
      -- 'runBuilderT' yields (result, statements added to the builder); only
      -- the result of 'transformStms' is kept.
      -- NOTE(review): builder-emitted statements are dropped here, presumably
      -- because 'transformStms' returns everything it produces — confirm.
      fmap fst $ modifyNameSource $ runState $ runBuilderT m scope

-- | The optimisation performed on the GPU representation.
--
-- This is 'optimiseArrayLayout' instantiated at 'GPU', with the pass name
-- suffixed by @"gpu"@.
optimiseArrayLayoutGPU :: Pass GPU GPU
optimiseArrayLayoutGPU = optimiseArrayLayout "gpu"

-- | The optimisation performed on the MC representation.
--
-- This is 'optimiseArrayLayout' instantiated at 'MC', with the pass name
-- suffixed by @"mc"@.
optimiseArrayLayoutMC :: Pass MC MC
optimiseArrayLayoutMC = optimiseArrayLayout "mc"