{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DuplicateRecordFields #-}
{-# LANGUAGE RecordWildCards #-}

-- SPDX-FileCopyrightText: Copyright (c) 2025 Objectionary.com
-- SPDX-License-Identifier: MIT

-- The goal of the module is to traverse the Program, replacing
-- pattern subexpressions with target expressions.
module Replacer
  ( replaceProgram,
    replaceProgramThrows,
    replaceProgramFast,
    replaceProgramFastThrows,
    ReplaceProgramThrowsFunc,
    ReplaceProgramFunc,
    ReplaceProgramContext (..),
  )
where

import Ast
import Control.Exception (Exception, throwIO)
import Matcher (Tail (TaApplication, TaDispatch))
import Pretty (prettyExpression, prettyProgram)
import Text.Printf (printf)
import Data.List (isPrefixOf)

-- | Input of the program-level replacement functions.
data ReplaceProgramContext = ReplaceProgramContext
  { -- | Program whose root expression is rewritten.
    _program :: Program,
    -- | Depth limit that is copied into the expression-level context.
    _maxDepth :: Integer
  }

-- | Input of the expression-level replacement functions.
data ReplaceExpressionContext = ReplaceExpressionContext
  { -- | Expression currently being inspected.
    _expression :: Expression,
    -- | Depth limit; 'replaceExpressionFast' stops once its depth counter
    -- equals this value.
    _maxDepth :: Integer
  }

-- | Build a context that points at another expression while keeping the
-- depth limit of the given context.
updateExpressionContext :: ReplaceExpressionContext -> Expression -> ReplaceExpressionContext
updateExpressionContext ReplaceExpressionContext {..} expr = ReplaceExpressionContext expr _maxDepth

-- | Program replacement that reports failure by throwing in 'IO'.
-- Arguments: patterns, replacements, program context.
type ReplaceProgramThrowsFunc = [Expression] -> [Expression] -> ReplaceProgramContext -> IO Program

-- | Program replacement that reports failure as 'Nothing'.
-- Arguments: patterns, replacements, program context.
type ReplaceProgramFunc = [Expression] -> [Expression] -> ReplaceProgramContext -> Maybe Program

-- | Expression replacement. Takes patterns, replacements and the expression
-- context; returns the resulting expression together with a list of patterns
-- and a list of replacements to be used by the caller for further work.
type ReplaceExpressionFunc = [Expression] -> [Expression] -> ReplaceExpressionContext -> (Expression, [Expression], [Expression])

-- | Thrown by the @...Throws@ functions when a replacement cannot be
-- performed; carries the program that was supposed to be rewritten.
newtype ReplaceException = CouldNotReplace {prog :: Program}
  deriving (Exception)

-- | Renders a fixed failure message followed by the pretty-printed program.
instance Show ReplaceException where
  show CouldNotReplace {..} =
    printf
      "Couldn't replace expression in program, lists of patterns and targets has different lengths\nProgram: %s"
      (prettyProgram prog)

-- | Apply an expression replacer to the expression of every 'BiTau' binding
-- in a list of bindings.
--
-- The pattern and replacement lists are threaded from left to right: the
-- lists returned by @func@ for one binding are the ones handed to the
-- following bindings. Bindings that are not 'BiTau' are kept untouched.
--
-- Returns the rebuilt bindings together with the pattern and replacement
-- lists that came out of the last step.
replaceBindings :: [Binding] -> [Expression] -> [Expression] -> ReplaceExpressionContext -> ReplaceExpressionFunc -> ([Binding], [Expression], [Expression])
-- Nothing left to replace: the bindings are returned as they are.
replaceBindings bds [] [] _ _ = (bds, [], [])
-- No bindings left: hand the remaining patterns/replacements back.
replaceBindings [] ptns repls _ _ = ([], ptns, repls)
replaceBindings (BiTau attr expr : bds) ptns repls ctx func =
  let (expr', ptns', repls') = func ptns repls (updateExpressionContext ctx expr)
      (bds', ptns'', repls'') = replaceBindings bds ptns' repls' ctx func
   in (BiTau attr expr' : bds', ptns'', repls'')
-- Any other kind of binding carries nothing to rewrite here.
replaceBindings (bd : bds) ptns repls ctx func =
  let (bds', ptns', repls') = replaceBindings bds ptns repls ctx func
   in (bd : bds', ptns', repls')

-- | Structural replacement of patterns by their replacements.
--
-- The expression in the context is compared with the head pattern. On
-- equality the head replacement takes its place and the search continues on
-- that replacement with the remaining pattern/replacement pairs. Otherwise
-- the function descends into dispatches, applications and formations, and
-- leaves every other expression untouched.
--
-- Returns the resulting expression together with the patterns and
-- replacements that were not consumed.
replaceExpression :: ReplaceExpressionFunc
replaceExpression [] [] ReplaceExpressionContext {..} = (_expression, [], [])
replaceExpression ptns@(ptn : ptnsRest) repls@(repl : replsRest) ctx@ReplaceExpressionContext {..} =
  if _expression == ptn
    then replaceExpression ptnsRest replsRest (updateExpressionContext ctx repl)
    else case _expression of
      ExDispatch inner attr ->
        let (expr', ptns', repls') = replaceExpression ptns repls (updateExpressionContext ctx inner)
         in (ExDispatch expr' attr, ptns', repls')
      ExApplication inner tau ->
        let (expr', ptns', repls') = replaceExpression ptns repls (updateExpressionContext ctx inner)
            -- replaceBindings returns as many bindings as it was given, so
            -- exactly one binding comes back for the singleton list.
            ([tau'], ptns'', repls'') = replaceBindings [tau] ptns' repls' ctx replaceExpression
         in (ExApplication expr' tau', ptns'', repls'')
      ExFormation bds ->
        let (bds', ptns', repls') = replaceBindings bds ptns repls ctx replaceExpression
         in (ExFormation bds', ptns', repls')
      _ -> (_expression, ptns, repls)
-- Exactly one of the two lists is empty. The exported entry points check
-- that both lists have the same length, so this is a safety net that keeps
-- the function total (same fallback as in replaceExpressionFast').
replaceExpression ptns repls ReplaceExpressionContext {..} = (_expression, ptns, repls)

-- | Replace runs of bindings inside a binding list.
--
-- Every pattern/replacement pair is expected to be a pair of formations.
-- For each pair, in order, every occurrence of the pattern's bindings as a
-- contiguous run inside the current binding list is replaced by the
-- replacement's bindings; the result is the input for the next pair.
replaceBindingsFast :: [Binding] -> [Expression] -> [Expression] -> [Binding]
replaceBindingsFast bds [] [] = bds
replaceBindingsFast bds (ExFormation pbds : rptns) (ExFormation rbds : rrepls) =
  replaceBindingsFast (replaceRun bds pbds rbds) rptns rrepls
  where
    -- Rewrite one pattern; an empty pattern is handled separately because
    -- it would be a prefix of every list.
    -- NOTE(review): for an empty pattern the replacement is returned alone
    -- and the current bindings are discarded -- confirm this is intended.
    replaceRun :: [Binding] -> [Binding] -> [Binding] -> [Binding]
    replaceRun current pattern replacement
      | null pattern = replacement
      | otherwise = findAndReplace current pattern replacement
    -- Scan left to right; on a match emit the replacement and skip the
    -- matched run, otherwise keep the head binding and move on.
    findAndReplace :: [Binding] -> [Binding] -> [Binding] -> [Binding]
    findAndReplace [] _ _ = []
    findAndReplace xs@(x : xs') pattern replacement
      | pattern `isPrefixOf` xs = replacement ++ findAndReplace (drop (length pattern) xs) pattern replacement
      | otherwise = x : findAndReplace xs' pattern replacement
-- A pair that is not made of two formations, or lists of different lengths:
-- stop here and keep what has been produced so far instead of failing with
-- a pattern-match error.
-- NOTE(review): confirm that stopping (rather than skipping the pair) is
-- the desired behaviour for such input.
replaceBindingsFast bds _ _ = bds

-- | Replacement strategy for formation patterns.
--
-- When both the head pattern and the head replacement are formations:
--
-- * a formation has its bindings rewritten with 'replaceBindingsFast' and
--   then each tau-binding is processed recursively with the depth counter
--   increased by one;
-- * a dispatch or an application is processed by descending into its parts;
-- * any other expression is returned unchanged.
--
-- Once the depth counter equals @_maxDepth@ the expression is returned as is
-- together with empty pattern and replacement lists. For any other shape of
-- the head pattern/replacement the expression and both lists are returned
-- unchanged.
replaceExpressionFast :: ReplaceExpressionFunc
replaceExpressionFast = replaceExpressionFast' 0
  where
    -- The first argument counts how many formations have been entered.
    replaceExpressionFast' :: Integer -> ReplaceExpressionFunc
    replaceExpressionFast' _ [] [] ReplaceExpressionContext {..} = (_expression, [], [])
    replaceExpressionFast' depth ptns@(ExFormation _ : _) repls@(ExFormation _ : _) ctx@ReplaceExpressionContext {..} =
      if depth == _maxDepth
        then (_expression, [], [])
        else case _expression of
          ExFormation bds ->
            let replaced = replaceBindingsFast bds ptns repls
                (bds', ptns', repls') = replaceBindings replaced ptns repls ctx (replaceExpressionFast' (depth + 1))
             in (ExFormation bds', ptns', repls')
          -- NOTE(review): the two branches below call 'replaceExpressionFast',
          -- which restarts the depth counter at 0 -- confirm this is intended
          -- rather than continuing with the current depth.
          ExDispatch inner attr ->
            let (expr', ptns', repls') = replaceExpressionFast ptns repls (updateExpressionContext ctx inner)
             in (ExDispatch expr' attr, ptns', repls')
          ExApplication inner (BiTau attr texpr) ->
            let (expr', ptns', repls') = replaceExpressionFast ptns repls (updateExpressionContext ctx inner)
                (expr'', ptns'', repls'') = replaceExpressionFast ptns' repls' (updateExpressionContext ctx texpr)
             in (ExApplication expr' (BiTau attr expr''), ptns'', repls'')
          _ -> (_expression, ptns, repls)
    -- Head pattern or head replacement is not a formation, or only one of
    -- the lists is empty: nothing is rewritten.
    replaceExpressionFast' _ ptns repls ReplaceExpressionContext {..} = (_expression, ptns, repls)

-- | Lift an expression replacer to whole programs.
--
-- Returns 'Nothing' when the pattern and replacement lists differ in
-- length. Otherwise the replacer is run on the program's root expression
-- (with the depth limit taken from the program context) and the resulting
-- expression is wrapped back into a 'Program'.
replaceProgram' :: ReplaceExpressionFunc -> ReplaceProgramFunc
replaceProgram' func ptns repls ReplaceProgramContext {_program = Program expr, ..}
  | length ptns == length repls =
      let (expr', _, _) = func ptns repls (ReplaceExpressionContext expr _maxDepth)
       in Just (Program expr')
  | otherwise = Nothing

-- | Replace patterns in a program using 'replaceExpression'; 'Nothing' when
-- the pattern and replacement lists differ in length.
replaceProgram :: ReplaceProgramFunc
replaceProgram = replaceProgram' replaceExpression

-- | Same as @replaceProgram'@ but running in 'IO': a missing result is turned
-- into a 'CouldNotReplace' exception that carries the original program.
replaceProgramThrows' :: ReplaceExpressionFunc -> ReplaceProgramThrowsFunc
replaceProgramThrows' func ptns repls ctx =
  maybe
    (throwIO (CouldNotReplace (_program ctx)))
    pure
    (replaceProgram' func ptns repls ctx)

-- | Replace patterns in a program using 'replaceExpression'; throws
-- 'CouldNotReplace' when no result is produced.
replaceProgramThrows :: ReplaceProgramThrowsFunc
replaceProgramThrows = replaceProgramThrows' replaceExpression

-- | Replace patterns in a program using 'replaceExpressionFast'; 'Nothing'
-- when the pattern and replacement lists differ in length.
replaceProgramFast :: ReplaceProgramFunc
replaceProgramFast = replaceProgram' replaceExpressionFast

-- | Replace patterns in a program using 'replaceExpressionFast'; throws
-- 'CouldNotReplace' when no result is produced.
replaceProgramFastThrows :: ReplaceProgramThrowsFunc
replaceProgramFastThrows = replaceProgramThrows' replaceExpressionFast