{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RecordWildCards #-}

-- SPDX-FileCopyrightText: Copyright (c) 2025 Objectionary.com
-- SPDX-License-Identifier: MIT

module Rewriter (rewrite, rewrite', RewriteContext (..)) where

import Ast
import Builder
import Control.Exception (Exception, throwIO)
import Data.Foldable (foldlM)
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe, isJust)
import Logger (logDebug)
import Matcher (MetaValue (MvAttribute, MvBindings, MvBytes, MvExpression), Subst (Subst), combine, combineMany, defaultScope, matchProgram, substEmpty, substSingle)
import Misc (ensuredFile)
import Must (Must (..), exceedsUpperBound, inRange)
import Parser (parseProgram, parseProgramThrows)
import Pretty (PrintMode (SWEET), prettyAttribute, prettyBytes, prettyExpression, prettyExpression', prettyProgram, prettyProgram', prettySubsts)
import Replacer (ReplaceProgramContext (ReplaceProgramContext), ReplaceProgramThrowsFunc, replaceProgramFastThrows, replaceProgramThrows)
import Rule (RuleContext (RuleContext), matchProgramWithRule)
import qualified Rule as R
import Term
import Text.Printf
import Yaml (ExtraArgument (..))
import qualified Yaml as Y

-- | Settings that drive 'rewrite' and 'rewrite''.
data RewriteContext = RewriteContext
  { -- | Program handed to 'RuleContext' whenever a rule is matched
    -- (NOTE(review): presumably the original, not-yet-rewritten program — confirm).
    _program :: Program,
    -- | Maximum number of rewriting cycles for a single rule; also passed on
    -- to 'ReplaceProgramContext'.
    _maxDepth :: Integer,
    -- | Maximum number of rewriting cycles over the whole set of rules.
    _maxCycles :: Integer,
    -- | When 'True', reaching '_maxDepth' or '_maxCycles' throws instead of
    -- returning the program rewritten so far.
    _depthSensitive :: Bool,
    -- | Term builder handed to 'RuleContext' whenever a rule is matched.
    _buildTerm :: BuildTermFunc,
    -- | Expected range for the number of rewriting cycles (the @--must@ option).
    _must :: Must
  }

-- | Reasons why rewriting is aborted with an exception.
data RewriteException
  = -- | Rewriting settled (a full pass left the program unchanged) after
    -- @count@ cycles, which is outside the range required by @--must@.
    MustBeGoing {must :: Must, count :: Integer}
  | -- | Rewriting has already made @count@ cycles, above the upper bound of
    -- the range required by @--must@, and is still going.
    MustStopBefore {must :: Must, count :: Integer}
  | -- | With @--depth-sensitive@, the limit @limit@ of the option named
    -- @flag@ (given without leading dashes) has been reached.
    StoppedOnLimit {flag :: String, limit :: Integer}
  deriving anyclass (Exception)

-- | User-facing message for each 'RewriteException' constructor.
instance Show RewriteException where
  show err = case err of
    MustBeGoing expected cycles ->
      printf
        "With option --must=%s it's expected rewriting cycles to be in range [%s], but rewriting stopped after %d cycles"
        (show expected)
        (show expected)
        cycles
    MustStopBefore expected cycles ->
      printf
        "With option --must=%s it's expected rewriting cycles to be in range [%s], but rewriting has already reached %d cycles and is still going"
        (show expected)
        (show expected)
        cycles
    StoppedOnLimit optName optValue ->
      printf
        "With option --depth-sensitive it's expected rewriting iterations amount does not reach the limit: --%s=%d"
        optName
        optValue

-- | Build the pattern and the result expression with the given substitutions
-- (keeping only the first component of each pair returned by
-- 'buildExpressions'), then let @func@ replace the built patterns with the
-- built results in the program described by @ctx@.
-- The pattern is built before the result.
buildAndReplace' :: Expression -> Expression -> [Subst] -> ReplaceProgramThrowsFunc -> ReplaceProgramContext -> IO Program
buildAndReplace' ptn res substs func ctx = do
  patterns <- map fst <$> buildExpressions ptn substs
  results <- map fst <$> buildExpressions res substs
  func patterns results ctx

-- | Build the pattern and the result from the substitutions and replace them
-- in the program, using fast replacing when it is applicable.
--
-- Pattern and replacement expressions can be used in fast replacing only if
-- 1. they are both formations;
-- 2. they start and end with the same meta bindings, e.g. @[!B1, ..., !B2]@;
-- 3. they do not have meta bindings between the first and the last meta bindings.
-- In such case we can just replace the inner bindings one by one without
-- building the whole expression. You can find more details in this ticket:
-- https://github.com/objectionary/phino/issues/321
-- If the conditions above are not met, a regular replacing is done.
tryBuildAndReplaceFast :: Expression -> Expression -> [Subst] -> ReplaceProgramContext -> IO Program
tryBuildAndReplaceFast (ExFormation pbds) (ExFormation rbds) substs ctx =
  case (splitMetaEdges pbds, splitMetaEdges rbds) of
    (Just (pOpen, pInner, pClose), Just (rOpen, rInner, rClose))
      | pOpen == rOpen,
        pClose == rClose,
        not (any isMetaBinding pInner),
        not (any isMetaBinding rInner) -> do
          logDebug "Applying fast replacing since 'pattern' and 'result' are suitable for this..."
          -- Only the bindings between the shared outer meta bindings are built and replaced.
          buildAndReplace' (ExFormation pInner) (ExFormation rInner) substs replaceProgramFastThrows ctx
    _ -> do
      logDebug "Applying regular replacing..."
      buildAndReplace' (ExFormation pbds) (ExFormation rbds) substs replaceProgramThrows ctx
  where
    -- Split bindings into (first, inner, last) when there are at least two of
    -- them and both the first and the last are meta bindings. Total pattern
    -- matching replaces the partial head/tail/init/last calls.
    splitMetaEdges :: [Binding] -> Maybe (Binding, [Binding], Binding)
    splitMetaEdges [] = Nothing
    splitMetaEdges (opening : rest) =
      case reverse rest of
        closing : revInner
          | isMetaBinding opening && isMetaBinding closing ->
              Just (opening, reverse revInner, closing)
        _ -> Nothing

    isMetaBinding :: Binding -> Bool
    isMetaBinding = \case
      BiMeta _ -> True
      _ -> False
-- At least one side is not a formation: fast replacing never applies.
tryBuildAndReplaceFast ptn res substs ctx =
  buildAndReplace' ptn res substs replaceProgramThrows ctx

-- | Apply the rules one after another, in list order. Each rule is applied
-- repeatedly until it no longer matches, no longer changes the program, or
-- '_maxDepth' cycles have been made for it. In the last case
-- 'StoppedOnLimit' is thrown if '_depthSensitive' is set; otherwise the
-- program rewritten so far is kept.
rewrite :: Program -> [Y.Rule] -> RewriteContext -> IO Program
rewrite program rules ctx = foldlM applyRule program rules
  where
    depth :: Integer
    depth = _maxDepth ctx

    -- Rewrite the program with a single rule until it settles or hits the limit.
    applyRule :: Program -> Y.Rule -> IO Program
    applyRule start rule = go start 1
      where
        ruleName :: String
        ruleName = fromMaybe "unknown" (Y.name rule)

        -- `count` is the 1-based number of the cycle about to start.
        go :: Program -> Integer -> IO Program
        go prog count
          | count - 1 == depth = do
              logDebug (printf "Max amount of rewriting cycles (%d) for rule '%s' has been reached, rewriting is stopped" depth ruleName)
              if _depthSensitive ctx
                then throwIO (StoppedOnLimit "max-depth" depth)
                else pure prog
          | otherwise = do
              logDebug (printf "Starting rewriting cycle for rule '%s': %d out of %d" ruleName count depth)
              matched <- R.matchProgramWithRule prog rule (RuleContext (_program ctx) (_buildTerm ctx))
              if null matched
                then do
                  logDebug (printf "Rule '%s' does not match, rewriting is stopped" ruleName)
                  pure prog
                else do
                  logDebug (printf "Rule '%s' has been matched, applying..." ruleName)
                  prog' <- tryBuildAndReplaceFast (Y.pattern rule) (Y.result rule) matched (ReplaceProgramContext prog depth)
                  if prog == prog'
                    then do
                      logDebug (printf "Applied '%s', no changes made" ruleName)
                      pure prog
                    else do
                      logDebug
                        ( printf
                            "Applied '%s' (%d nodes -> %d nodes)"
                            ruleName
                            (countNodes prog)
                            (countNodes prog')
                        )
                      go prog' (count + 1)

-- @todo #169:30min Memorize previous rewritten programs. Right now in order not to
--  get an infinite recursion during rewriting we just count have many times we apply
--  rewriting rules. If we reach given amount - we just stop. It's not idiomatic and may
--  not work on big programs. We need to introduce some mechanism which would memorize
--  all rewritten program on each step and if on some step we get the program that have already
--  been memorized - we fail because we got into infinite recursion. Ofc we should keep counting
--  rewriting cycles if program just only grows on each rewriting.
-- | Run whole passes of 'rewrite' over all the rules until a pass leaves the
-- program unchanged or '_maxCycles' passes have been completed.
--
-- Throws:
--
-- * 'MustStopBefore' when the number of completed passes is positive, outside
--   '_must' and above its upper bound;
-- * 'StoppedOnLimit' when '_maxCycles' is reached while '_depthSensitive' is set;
-- * 'MustBeGoing' when a pass changes nothing and the number of completed
--   passes is outside '_must'.
rewrite' :: Program -> [Y.Rule] -> RewriteContext -> IO Program
rewrite' prog rules ctx = loop prog 1
  where
    cycles :: Integer
    cycles = _maxCycles ctx

    expected :: Must
    expected = _must ctx

    -- `count` is the 1-based number of the pass about to start.
    loop :: Program -> Integer -> IO Program
    loop program count
      | not (inRange expected done) && done > 0 && exceedsUpperBound expected done =
          throwIO (MustStopBefore expected done)
      | done == cycles = do
          logDebug (printf "Max amount of rewriting cycles for all rules (%d) has been reached, rewriting is stopped" cycles)
          if _depthSensitive ctx
            then throwIO (StoppedOnLimit "max-cycles" cycles)
            else pure program
      | otherwise = do
          logDebug (printf "Starting rewriting cycle for all rules: %d out of %d" count cycles)
          rewritten <- rewrite program rules ctx
          if rewritten == program
            then do
              logDebug "No rule matched, rewriting is stopped"
              if inRange expected done
                then pure rewritten
                else throwIO (MustBeGoing expected done)
            else loop rewritten (count + 1)
      where
        -- Passes completed so far; each of them changed the program.
        done :: Integer
        done = count - 1