{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE RecordWildCards #-}

-- SPDX-FileCopyrightText: Copyright (c) 2025 Objectionary.com
-- SPDX-License-Identifier: MIT

module Rewriter (rewrite, rewrite', RewriteContext (..)) where

import Ast
import Builder
import qualified Condition as C
import Data.Foldable (foldlM)
import qualified Data.Map.Strict as M
import Data.Maybe (catMaybes, fromMaybe, isJust)
import Debug.Trace (trace)
import Logger (logDebug)
import Matcher (MetaValue (MvAttribute, MvBindings, MvExpression, MvBytes), Subst (Subst), combine, combineMany, defaultScope, matchProgram, substEmpty, substSingle)
import Misc (ensuredFile)
import Parser (parseProgram, parseProgramThrows)
import Pretty (PrintMode (SWEET), prettyAttribute, prettyExpression, prettyExpression', prettyProgram, prettyProgram', prettySubsts, prettyBytes)
import Replacer (replaceProgram, replaceProgramThrows)
import Term
import Text.Printf
import Yaml (ExtraArgument (..))
import qualified Yaml as Y

-- | Environment shared by all rewriting functions of this module.
data RewriteContext = RewriteContext
  { -- | Program passed to '_buildTerm' when the @where@ section of a rule is evaluated.
    _program :: Program,
    -- | Upper bound on the number of rewriting cycles performed by 'rewrite''.
    _maxDepth :: Integer,
    -- | Evaluates a named function of a rule's @where@ section: it receives the
    -- function name, its arguments, the current substitution and '_program',
    -- and produces a 'Term'.
    _buildTerm :: BuildTermFunc
  }

-- | Build concrete pattern and result expressions from the given substitutions
-- and replace the built patterns with the built results in the given program.
--
-- Only the first component of every pair produced by 'buildExpressions' is
-- used; the pattern is built before the result. The actual replacement (and
-- any failure) is delegated to 'replaceProgramThrows'.
buildAndReplace :: Program -> Expression -> Expression -> [Subst] -> IO Program
buildAndReplace program ptn res substs = do
  patterns <- map fst <$> buildExpressions ptn substs
  results <- map fst <$> buildExpressions res substs
  replaceProgramThrows program patterns results

-- | Extend every given substitution with the extra meta values declared in the
-- @where@ section of a yaml rule.
--
-- When there is no @where@ section the substitutions are returned unchanged.
-- Otherwise, for every substitution the extras are applied left to right: the
-- function named by 'Y.function' is evaluated via '_buildTerm' and its result
-- is bound to the meta variable named by 'Y.meta'.
--
-- A substitution is dropped from the result when an extra's 'Y.meta' is not a
-- meta variable, or when 'combine' cannot merge the new binding into it. Once a
-- substitution is dropped, the remaining extras are not evaluated for it.
extraSubstitutions :: Maybe [Y.Extra] -> [Subst] -> RewriteContext -> IO [Subst]
extraSubstitutions extras substs RewriteContext {..} = case extras of
  Nothing -> pure substs
  Just extras' ->
    catMaybes <$> traverse (\subst -> foldlM extend (Just subst) extras') substs
  where
    -- Apply one extra to a substitution that is still alive. A 'Nothing'
    -- accumulator short-circuits, so a dropped substitution is never
    -- pattern-matched as 'Just' on a later extra.
    extend :: Maybe Subst -> Y.Extra -> IO (Maybe Subst)
    extend Nothing _ = pure Nothing
    extend (Just subst') extra = do
      let func = Y.function extra
      term <- _buildTerm func (Y.args extra) subst' _program
      meta <- case term of
        TeExpression expr -> do
          logDebug
            (printf "Function %s() returned expression:\n%s" func (prettyExpression' expr))
          pure (MvExpression expr defaultScope)
        TeAttribute attr -> do
          logDebug
            (printf "Function %s() returned attribute:\n%s" func (prettyAttribute attr))
          pure (MvAttribute attr)
        TeBytes bytes -> do
          logDebug
            (printf "Function %s() returned bytes: %s" func (prettyBytes bytes))
          pure (MvBytes bytes)
      pure $ case metaName (Y.meta extra) of
        Just name -> combine (substSingle name meta) subst'
        Nothing -> Nothing

    -- Name of the meta variable an extra's result is bound to, if the
    -- argument is a meta variable at all.
    metaName :: ExtraArgument -> Maybe String
    metaName (ArgExpression (ExMeta name)) = Just name
    metaName (ArgAttribute (AtMeta name)) = Just name
    metaName (ArgBinding (BiMeta name)) = Just name
    metaName (ArgBytes (BtMeta name)) = Just name
    metaName _ = Nothing

-- | Run every rule once, in order, threading the program through them.
--
-- A rule whose pattern (together with its optional @when@ condition) does not
-- match leaves the program untouched. Otherwise the rule's @where@ extras are
-- evaluated via 'extraSubstitutions' and the matches are rewritten with
-- 'buildAndReplace'; the outcome is logged at debug level.
rewrite :: Program -> [Y.Rule] -> RewriteContext -> IO Program
rewrite program rules ctx = foldlM applyRule program rules
  where
    -- Apply a single rule to the current program.
    applyRule :: Program -> Y.Rule -> IO Program
    applyRule current rule =
      C.matchProgramWithCondition ptn (Y.when rule) current
        >>= maybe (pure current) apply
      where
        ptn :: Expression
        ptn = Y.pattern rule

        label :: String
        label = fromMaybe "unknown" (Y.name rule)

        apply :: [Subst] -> IO Program
        apply matched = do
          logDebug (printf "Rule '%s' has been matched, applying..." label)
          substs <- extraSubstitutions (Y.where_ rule) matched ctx
          updated <- buildAndReplace current ptn (Y.result rule) substs
          logDebug (report updated)
          pure updated

        -- Debug message describing what applying the rule did.
        report :: Program -> String
        report updated
          | current == updated = printf "Applied '%s', no changes made" label
          | otherwise =
              printf
                "Applied '%s' (%d nodes -> %d nodes):\n%s"
                label
                (countNodes current)
                (countNodes updated)
                (prettyProgram' updated SWEET)

-- @todo #169:30min Memorize previously rewritten programs. Right now, in order not to
--  get an infinite recursion during rewriting, we just count how many times we apply
--  rewriting rules. If we reach the given amount - we just stop. It's not idiomatic and may
--  not work on big programs. We need to introduce some mechanism which would memorize
--  all rewritten programs on each step, and if on some step we get a program that has already
--  been memorized - we fail because we got into infinite recursion. Of course we should keep counting
--  rewriting cycles if the program just only grows on each rewriting.
-- | Repeatedly apply 'rewrite' until a fixpoint or the cycle limit is reached.
--
-- Each cycle runs all rules once. Rewriting stops when a cycle returns a
-- program equal to its input, or when the number of completed cycles reaches
-- '_maxDepth'. A non-positive '_maxDepth' returns the input program without
-- running any cycle. Comparing with @>=@ rather than @==@ keeps the cap
-- effective even for a negative limit.
rewrite' :: Program -> [Y.Rule] -> RewriteContext -> IO Program
rewrite' prog rules ctx = go prog 0
  where
    -- Loop-invariant cycle limit.
    depth :: Integer
    depth = _maxDepth ctx

    -- Run one rewriting cycle on @current@; @count@ cycles are already done.
    go :: Program -> Integer -> IO Program
    go current count = do
      logDebug (printf "Starting rewriting cycle %d out of %d" count depth)
      if count >= depth
        then do
          logDebug "Max amount of rewriting cycles has been reached, rewriting is stopped"
          pure current
        else do
          rewritten <- rewrite current rules ctx
          if rewritten == current
            then do
              logDebug "No rule matched, rewriting is stopped"
              pure rewritten
            else go rewritten (count + 1)