{-|
  Copyright  :  (C) 2012-2016, University of Twente,
                (C) 2021-2026, QBayLogic B.V.
  License    :  BSD2 (see the file LICENSE)
  Maintainer :  QBayLogic B.V. <devops@qbaylogic.com>

  Transformation process for normalization
-}

{-# LANGUAGE CPP #-}

module Clash.Normalize.Strategy where

import Clash.Normalize.Transformations
import Clash.Normalize.Types
import Clash.Rewrite.Combinators
import Clash.Rewrite.Types
import Clash.Rewrite.Util

-- [Note: bottomup traversal reduceConst]
--
-- 2-May-2019: There is a bug in the evaluator where all data constructors are
-- considered lazy, even though their declaration says they have strict fields.
-- This causes some reductions to fail because the term under the constructor is
-- not in WHNF, which is what some of the evaluation rules for certain primitive
-- operations expect. Using a bottom-up traversal works around this bug by
-- ensuring that the values under the constructor are in WHNF.
--
-- Using a bottomup traversal ensures that constants are reduced to NF, even if
-- constructors are lazy, thus ensuring more sensible/smaller generated HDL.

-- | Normalisation transformation
--
-- The complete normalisation pipeline as a single 'NormRewrite'. It is one
-- chain of passes, executed in the order written below and glued together
-- with the '>->' and '>-!->' combinators. Apart from 'constantPropagation',
-- which is a top-level definition, every pass is a local definition in the
-- @where@ clause that wraps one or more transformations in a traversal
-- ('topdownR', 'bottomupR', 'topdownSucR' or 'innerMost').
--
-- Every transformation is handed to 'apply' (or 'applyMany') together with a
-- string label.
normalization :: NormRewrite
normalization =
        rmDeadcode
  >->   multPrim
  >->   constantPropagation
  >->   rmUnusedExpr
  >-!-> anf
  >-!-> rmDeadcode
  >->   bindConst
  >->   letTL
  >->   evalConst
  >-!-> cse
  >-!-> cleanup
  >->   elimCaseBigNum  -- see [Note] late elimCaseBigNum
  >->   xOptim
  >->   rmDeadcode
  >->   cleanup
  >->   bindSimIO
  >->   recLetRec
  >->   splitArgs
  where
    -- Top-down pass of 'setupMultiResultPrim'.
    multPrim =
      topdownR (apply "setupMultiResultPrim" setupMultiResultPrim)

    -- Three steps in sequence: a top-down 'nonRepANF' pass, 'makeANF' applied
    -- without an enclosing traversal, and a top-down 'caseCon' pass.
    anf =
      topdownR (apply "nonRepANF" nonRepANF)
        >-> apply "ANF" makeANF
        >-> topdownR (apply "caseCon" caseCon)

    -- 'topLet' in a top-down traversal that stops upon first success
    -- (see 'topdownSucR').
    letTL =
      topdownSucR (apply "topLet" topLet)

    -- 'recToLetRec' applied without an enclosing traversal.
    recLetRec =
      apply "recToLetRec" recToLetRec

    -- Bottom-up pass of 'removeUnusedExpr'.
    rmUnusedExpr =
      bottomupR (apply "removeUnusedExpr" removeUnusedExpr)

    -- Bottom-up pass of 'deadCode'.
    -- NOTE(review): the label is spelled "deadcode" here but "deadCode" in
    -- 'constantPropagation' -- confirm whether the difference is intended.
    rmDeadcode =
      bottomupR (apply "deadcode" deadCode)

    -- Top-down pass of 'bindConstantVar'.
    bindConst =
      topdownR (apply "bindConstantVar" bindConstantVar)

    -- See [Note] bottomup traversal reduceConst:
    evalConst =
      bottomupR (apply "reduceConst" reduceConst)

    -- Top-down pass of 'simpleCSE'.
    cse =
      topdownR (apply "CSE" simpleCSE)

    -- Top-down pass of 'elimCaseBigNumInternals'; its position in the
    -- pipeline is explained in [Note] late elimCaseBigNum.
    elimCaseBigNum =
      topdownR (apply "elimCaseBigNum" elimCaseBigNumInternals)

    -- Bottom-up pass of 'xOptimize'.
    xOptim =
      bottomupR (apply "xOptimize" xOptimize)

    -- Clean-up stage: a top-down 'etaExpandSyn' pass, then 'inlineCleanup'
    -- (top-down, stopping upon first success) combined through '!->' with an
    -- 'innerMost' traversal of 'caseCon', 'bindConstantVar' and 'flattenLet',
    -- and finally the local 'rmDeadcode' and 'letTL' passes.
    -- The operator sequence is kept exactly as in the original, so grouping
    -- is decided by the fixities of '>->' and '!->'.
    cleanup =
      topdownR (apply "etaExpandSyn" etaExpandSyn)
        >-> topdownSucR (apply "inlineCleanup" inlineCleanup)
        !-> innerMost
              ( applyMany
                  [ ("caseCon"        , caseCon)
                  , ("bindConstantVar", bindConstantVar)
                  , ("letFlat"        , flattenLet)
                  ]
              )
        >-> rmDeadcode
        >-> letTL

    -- Top-down pass of 'separateArguments', combined through '!->' with a
    -- bottom-up 'caseCon' pass.
    splitArgs =
      topdownR (apply "separateArguments" separateArguments)
        !-> bottomupR (apply "caseCon" caseCon)

    -- Top-down pass of 'inlineSimIO'.
    bindSimIO =
      topdownR (apply "bindSimIO" inlineSimIO)


-- | Constant-propagation stage of normalisation.
--
-- A chain of passes joined with '>->', executed in the order written below.
-- Note that @dec@ appears twice: once before and once after @spec@. All passes
-- are local definitions in the @where@ clause.
constantPropagation :: NormRewrite
constantPropagation =
      inlineAndPropagate
  >-> caseFlattening
  >-> etaTL
  >-> dec
  >-> spec
  >-> dec
  >-> conSpec
  where
    -- 'etaExpansionTL' applied without an enclosing traversal, combined
    -- through '!->' with a top-down 'appProp' pass.
    etaTL =
      apply "etaTL" etaExpansionTL
        !-> topdownR (apply "applicationPropagation" appProp)

    -- The loop ('repeatR') around one top-down pass of all the
    -- 'transPropagateAndInline' transformations followed by 'inlineNR'.
    inlineAndPropagate =
      repeatR (topdownR (applyMany transPropagateAndInline) >-> inlineNR)

    -- Bottom-up pass of all the 'specTransformations'.
    spec =
      bottomupR (applyMany specTransformations)

    -- 'repeatR' around a top-down 'caseFlat' pass.
    caseFlattening =
      repeatR (topdownR (apply "caseFlat" caseFlat))

    -- 'repeatR' around a top-down 'disjointExpressionConsolidation' pass.
    dec =
      repeatR (topdownR (apply "DEC" disjointExpressionConsolidation))

    -- Bottom-up traversal that, at each node, combines
    -- @appProp !-> (bottom-up constantSpec)@ through '>-!' with a direct
    -- application of 'constantSpec'.
    conSpec =
      bottomupR
        ( ( apply "appPropCS" appProp
              !-> bottomupR (apply "constantSpec" constantSpec)
          )
            >-! apply "constantSpec" constantSpec
        )

    -- The transformations of the propagate-and-inline loop, in the order in
    -- which 'applyMany' chains them.
    transPropagateAndInline :: [(String,NormRewrite)]
    transPropagateAndInline =
      [ ("applicationPropagation"   , appProp)
      , ("bindConstantVar"          , bindConstantVar)
      , ("caseLet"                  , caseLet)
      , ("caseCase"                 , caseCase)
      , ("caseCon"                  , caseCon)
      , ("elimExistentials"         , elimExistentials)
      , ("caseEliminateNonReachable", caseEliminateNonReachable)
      , ("removeUnusedExpr"         , removeUnusedExpr)
      -- These transformations can safely be applied in a top-down traversal as
      -- they themselves check whether the to-be-inlined binder is recursive or not.
      , ("inlineWorkFree"  , inlineWorkFree)
      , ("inlineSmall"     , inlineSmall)
      -- See: [Note] bindNonRep before liftNonRep
      -- See: [Note] bottom-up traversal for liftNonRep
      , ("bindOrLiftNonRep", inlineOrLiftNonRep)
      , ("reduceNonRepPrim", reduceNonRepPrim)

      , ("caseCast"        , caseCast)
      , ("letCast"         , letCast)
      , ("splitCastWork"   , splitCastWork)
      , ("argCastSpec"     , argCastSpec)
      , ("inlineCast"      , inlineCast)
      , ("elimCastCast"    , elimCastCast)
      ]

    -- 'inlineNonRep' cannot be applied in a top-down traversal, as the
    -- non-representable binder might be recursive. The idea is that if the
    -- recursive non-representable binder is inlined once, we can get rid of
    -- the recursive aspect using the case-of-known-constructor
    -- transformation.
    --
    -- Note that we first do a dead code removal pass, which makes sure that
    -- unused let-bindings get cleaned up. Only if no dead code is removed is
    -- 'inlineNonRep' executed. We do this for two reasons:
    --
    --   1. 'deadCode' is an expensive operation and is therefore left out of
    --      the hot loop 'transPropagateAndInline'.
    --
    --   2. In various situations 'transPropagateAndInline' can do more work
    --      after 'deadCode' was successful. This work in turn might remove a
    --      construct 'inlineNonRep' would fire on - saving the compiler work.
    --
    inlineNR :: NormRewrite
    inlineNR =
      bottomupR (apply "deadCode" deadCode)
        >-! apply "inlineNonRep" inlineNonRep

    -- The specialisation transformations run by @spec@.
    specTransformations :: [(String,NormRewrite)]
    specTransformations =
      [ ("typeSpec"     , typeSpec)
      , ("nonRepSpec"   , nonRepSpec)
      , ("zeroWidthSpec", zeroWidthSpec)
        -- See Note [zeroWidthSpec enabling transformations]
      ]

{-
[Note] late elimCaseBigNum

elimCaseBigNum is placed fairly late in the pipeline, after any constant folding and after caseCon,
because Naturals also hide inside SNat/Nat/KnownNats, and those can be bigger than 64 bits.
But they're ultimately monomorphic and thus static and will be evaluated.
We just have to make sure that evaluation happens before we do elimCaseBigNum.

For an example of where this can happen see foldableSNat in tests/shouldwork/Issues/T3157_IntegerNaturalInternals.hs
-}

{-
Note [zeroWidthSpec enabling transformations]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
When zeroWidthSpec fires, it can lead to better results in normalization, but
this is somewhat incidental. The extra transformations which fire are typically
from

  * calls to transformations like caseCon which occur after constantPropagation
    (e.g. caseCon run after ANF conversion).

  * flattening / inlining which happens late in normalization (after regular
    normalization has occurred)

  * normalizing another function due to being marked NOINLINE

If we consider the following:

    data AB = A | B

    ab :: KnownNat n => Index n -> AB -> AB
    ab n A = if n >  0 then A else B
    ab n B = if n == 0 then B else A
    {-# NOINLINE ab #-}

    topEntity = ab @1
    {-# NOINLINE topEntity #-}

The zeroWidthSpec transformation fires on the topEntity, giving a
post-normalization topEntity of

    \(x :: Index 1) ->
      \(y :: AB) ->
        letrec result :: AB = ab' y in result

where

    ab' = ab (fromInteger# 0)

The extra transformations which fire happen later when ab' is normalized.
Removing the NOINLINE from ab gives the same result, but the extra
transformations fire in flattening instead.
-}

{- [Note] bottom-up traversal for liftNonRep
We used to say:

"The liftNonRep transformation must be applied in a topDown traversal because
of what Clash considers tail calls in its join-point analysis."

Consider:

> let fail = \x -> ...
> in  case ... of
>       A -> let fail1 = \y -> case ... of
>                                 X -> fail ...
>                                 Y -> ...
>            in case ... of
>                 P -> fail1 ...
>                 Q -> ...
>       B -> fail ...

Under "normal" tail call rules, the local 'fail' function is not a join-point
because it is used in a let-binding. However, we apply "special" tail call rules
in Clash. Because 'fail' is used in a TC position within 'fail1', and 'fail1' is
only used in a TC position, in Clash, we consider 'fail' also only to be used
in a TC position.

Now imagine we apply 'liftNonRep' in a bottom-up traversal; we will end up with:

> fail1 = \fail y -> case ... of
>   X -> fail ...
>   Y -> ...

> let fail = \x -> ...
> in  case ... of
>       A -> case ... of
>                 P -> fail1 fail ...
>                 Q -> ...
>       B -> fail ...

Suddenly, 'fail' ends up in an argument position, because it occurred as a
_locally_ bound variable within 'fail1'. And because of that 'fail' stops being
a join-point.

However, when we apply 'liftNonRep' in a top down traversal we end up with:

> fail = \x -> ...
>
> fail1 = \y -> case ... of
>   X -> fail ...
>   Y -> ...
>
> let ...
> in  case ... of
>       A -> let
>            in case ... of
>                 P -> fail1 ...
>                 Q -> ...
>       B -> fail ...

and all is well with the world.

UPDATE:
We can now just perform liftNonRep in a bottom-up traversal again, because
liftNonRep no longer checks whether the binding that is lifted is a join-point.
However, for this to work, bindNonRep must always have been exhaustively applied
before liftNonRep. See also: [Note] bindNonRep before liftNonRep.
-}

{- [Note] bindNonRep before liftNonRep
The combination of liftNonRep and nonRepSpec can lead to non-termination in an
unchecked rewrite system (without termination measures in place) on the
following:

> main = f not
> f    = \a x -> (a x) && (f a x)

nonRepSpec will lead to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = (\a x -> (a x) && (f a x)) not

then lamApp leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = let a = not in (\x -> (a x) && (f a x))

then liftNonRep leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (g x) && (f g x)
> g    = not

and nonRepSpec leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (g x) && (f'' g x)
> g    = not
> f''  = (\a x -> (a x) && (f a x)) g

This cycle continues indefinitely, as liftNonRep creates a new global variable,
which is never alpha-equivalent to the previous global variable introduced by
liftNonRep.

That is why bindNonRep must always be applied before liftNonRep. When we end up
in the situation after lamApp:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = let a = not in (\x -> (a x) && (f a x))

bindNonRep will now lead to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (not x) && (f not x)

Because `f` has already been specialized on the alpha-equivalent-to-itself `not`
function, liftNonRep leads to:

> main = f'
> f    = \a x -> (a x) && (f a x)
> f'   = \x -> (not x) && (f' x)

And there is no non-terminating rewriting cycle.

That is why bindNonRep must always be exhaustively applied before we apply
liftNonRep.
-}

-- | Topdown traversal, stops upon first success
--
-- The rewrite @r@ is tried on the current term first. It is combined through
-- '>-!' with a recursive descent ('allR' applied to @topdownSucR r@), so the
-- descent is what happens when @r@ did not succeed at this node.
--
-- NOTE(review): presumably 'allR' visits the immediate sub-terms -- confirm
-- against its definition.
topdownSucR :: Rewrite extra -> Rewrite extra
topdownSucR r = r >-! allR (topdownSucR r)
{-# INLINE topdownSucR #-}

-- | Bottom-up traversal that, at every node, applies @r !-> innerMost r@:
-- the given rewrite combined through '!->' with another full 'innerMost'
-- traversal of the same rewrite.
--
-- NOTE(review): presumably '!->' runs its right-hand side only when the
-- left-hand side succeeded, which would make this "rewrite until @r@ no
-- longer applies" -- confirm against the definition of '!->'.
--
-- The @let go r = ... in go@ shape is kept as in the original; it may matter
-- for the INLINE pragma, so do not eta-contract it without checking.
innerMost :: Rewrite extra -> Rewrite extra
innerMost = let go r = bottomupR (r !-> innerMost r) in go
{-# INLINE innerMost #-}

-- | Combine a list of labelled rewrites into a single rewrite.
--
-- Each @(label, rewrite)@ pair is passed to 'apply' (via 'uncurry'), and the
-- results are chained with '>->' in list order.
--
-- The list must be non-empty: the chain is built with 'foldr1', which raises
-- a runtime error on an empty list.
applyMany :: [(String,Rewrite extra)] -> Rewrite extra
applyMany = foldr1 (>->) . map (uncurry apply)
{-# INLINE applyMany #-}