{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE RecordWildCards #-}

-- SPDX-FileCopyrightText: Copyright (c) 2025 Objectionary.com
-- SPDX-License-Identifier: MIT

module Rewriter (rewrite, rewrite', RewriteContext (..)) where

import Ast
import Builder
import Control.Exception (Exception, throwIO)
import Data.Foldable (foldlM)
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe, isJust)
import Debug.Trace (trace)
import Logger (logDebug)
import Matcher (MetaValue (MvAttribute, MvBindings, MvBytes, MvExpression), Subst (Subst), combine, combineMany, defaultScope, matchProgram, substEmpty, substSingle)
import Misc (ensuredFile)
import Parser (parseProgram, parseProgramThrows)
import Pretty (PrintMode (SWEET), prettyAttribute, prettyBytes, prettyExpression, prettyExpression', prettyProgram, prettyProgram', prettySubsts)
import Replacer (replaceProgram, replaceProgramThrows)
import Rule (RuleContext (RuleContext), meetMaybeCondition)
import qualified Rule as R
import Term
import Text.Printf
import Yaml (ExtraArgument (..))
import qualified Yaml as Y

-- | Settings shared by 'rewrite' and 'rewrite''.
data RewriteContext = RewriteContext
  { -- | Program handed to every rule through its 'RuleContext'.
    -- NOTE(review): presumably the original (not yet rewritten) program — confirm with callers.
    _program :: Program,
    -- | Upper bound on the number of rewriting cycles performed by 'rewrite''.
    _maxDepth :: Integer,
    -- | Term builder handed to every rule through its 'RuleContext'.
    _buildTerm :: BuildTermFunc,
    -- | Exact number of program-changing cycles required by @--must@; @0@ disables the check.
    _must :: Integer
  }

-- | Raised by 'rewrite'' when @--must@ is set and the number of rewriting
-- cycles that changed the program differs from the expected one.
--
-- Constructors are positional on purpose: a record selector for the actual
-- cycle count would be partial (it exists only on 'StoppedBefore'), and the
-- selector names would clash with local bindings in 'rewrite''.
data MustException
  = -- | Rewriting reached a fixpoint before the expected number of cycles.
    StoppedBefore
      Integer -- ^ expected number of cycles (the @--must@ value)
      Integer -- ^ number of cycles that actually changed the program
  | -- | Rewriting was still changing the program after the expected number of cycles.
    ContinueAfter
      Integer -- ^ expected number of cycles (the @--must@ value)

-- Uses the default methods, exactly as @deriving anyclass (Exception)@ would.
instance Exception MustException

-- | Human-readable message reported to the user when the @--must@ check fails.
instance Show MustException where
  show (StoppedBefore expected actual) =
    printf
      "With option --must=%d it's expected exactly %d rewriting cycles happened, but rewriting stopped after %d"
      expected
      expected
      actual
  show (ContinueAfter expected) =
    printf
      "With option --must=%d it's expected exactly %d rewriting cycles happened, but rewriting is still going"
      expected
      expected

-- | Build the pattern and the result expression with every substitution
-- (via 'buildExpressions') and replace the built patterns with the built
-- results in the given program.
--
-- Only the first component of each pair produced by 'buildExpressions' is
-- used. The pattern is built before the result.
-- NOTE(review): the two lists are handed to 'replaceProgramThrows' side by
-- side, presumably pairing the i-th pattern with the i-th result — confirm.
buildAndReplace :: Program -> Expression -> Expression -> [Subst] -> IO Program
buildAndReplace program ptn res substs = do
  patterns <- map fst <$> buildExpressions ptn substs
  results <- map fst <$> buildExpressions res substs
  replaceProgramThrows program patterns results

-- | Apply every rule once, in list order, to the program.
--
-- Each rule is matched against the program produced by the rules before it.
-- A rule with no matches leaves the program untouched. A matched rule is
-- applied through 'buildAndReplace', and the outcome is logged at debug level.
-- All rules receive the same 'RuleContext', built from the '_program' and
-- '_buildTerm' fields of the context.
rewrite :: Program -> [Y.Rule] -> RewriteContext -> IO Program
rewrite program rules ctx = foldlM step program rules
  where
    ruleCtx :: RuleContext
    ruleCtx = RuleContext (_program ctx) (_buildTerm ctx)

    -- Match one rule against the current program and apply it if it matched.
    step :: Program -> Y.Rule -> IO Program
    step current rule = do
      substs <- R.matchProgramWithRule current rule ruleCtx
      if null substs
        then pure current
        else do
          let title = fromMaybe "unknown" (Y.name rule)
          logDebug (printf "Rule '%s' has been matched, applying..." title)
          next <- buildAndReplace current (Y.pattern rule) (Y.result rule) substs
          logDebug $
            if current == next
              then printf "Applied '%s', no changes made" title
              else
                printf
                  "Applied '%s' (%d nodes -> %d nodes):\n%s"
                  title
                  (countNodes current)
                  (countNodes next)
                  (prettyProgram' next SWEET)
          pure next

-- @todo #169:30min Memorize previously rewritten programs. Right now, in order not to
--  get an infinite recursion during rewriting, we just count how many times we apply
--  rewriting rules. If we reach the given amount, we just stop. It's not idiomatic and may
--  not work on big programs. We need to introduce some mechanism which would memorize
--  all rewritten programs on each step, and if on some step we get a program that has already
--  been memorized, we fail because we got into an infinite recursion. Of course, we should keep
--  counting rewriting cycles if the program only grows on each rewriting.
-- | Run rewriting cycles until the program stops changing.
--
-- Each cycle is one 'rewrite' pass over all @rules@. Before every cycle,
-- three checks run in this order:
--
-- 1. If '_must' is non-zero and more than '_must' cycles have already changed
--    the program, 'ContinueAfter' is thrown.
-- 2. If '_maxDepth' cycles have changed the program, the current program is
--    returned as is.
-- 3. Otherwise another cycle runs.
--
-- A cycle that leaves the program unchanged ends the loop. If '_must' is
-- non-zero and the number of changing cycles differs from it,
-- 'StoppedBefore' is thrown. Otherwise the program is returned.
rewrite' :: Program -> [Y.Rule] -> RewriteContext -> IO Program
rewrite' prog rules ctx = loop prog 0
  where
    depth :: Integer
    depth = _maxDepth ctx

    expected :: Integer
    expected = _must ctx

    -- A zero @--must@ value disables the exact-cycle-count check.
    mustIsSet :: Bool
    mustIsSet = expected /= 0

    -- @done@ is the number of cycles so far that changed the program;
    -- the cycle about to start is therefore number @done + 1@.
    loop :: Program -> Integer -> IO Program
    loop current done
      | mustIsSet && done > expected =
          throwIO (ContinueAfter expected)
      | done == depth = do
          logDebug (printf "Max amount of rewriting cycles (%d) has been reached, rewriting is stopped" depth)
          pure current
      | otherwise = do
          logDebug (printf "Starting rewriting cycle %d out of %d" (done + 1) depth)
          next <- rewrite current rules ctx
          if next == current
            then do
              logDebug "No rule matched, rewriting is stopped"
              if mustIsSet && done /= expected
                then throwIO (StoppedBefore expected done)
                else pure next
            else loop next (done + 1)