{-# language OverloadedStrings #-}
{-# language DataKinds #-}
{-# language BangPatterns #-}
module OptimizeTailRecursion where

import Control.Applicative ((<|>))
import Control.Lens.Cons (_last, _init)
import Control.Lens.Fold ((^..), (^?), (^?!), allOf, anyOf, folded, foldrOf)
import Control.Lens.Getter ((^.), to)
import Control.Lens.Plated (cosmos, transform, transformOn)
import Control.Lens.Prism (_Just)
import Control.Lens.Review ((#))
import Control.Lens.Setter ((%~), (.~))
import Control.Lens.Tuple (_2, _3)
import Data.Foldable (toList)
import Data.Function ((&))
import Data.Semigroup ((<>))

import Language.Python.Optics
import Language.Python.DSL
import Language.Python.Syntax.Expr (Expr (..), _Exprs, argExpr, paramName)
import Language.Python.Syntax.Statement (CompoundStatement (..), Statement (..), SmallStatement (..), SimpleStatement (..), _Statements)

-- | Turn a self tail-recursive Python function definition into a loop.
--
-- Returns 'Nothing' unless the statement is a function definition whose last
-- statement ends, on at least one path, in @return f(...)@ where @f@ is the
-- function itself. Only plain simple-statement lines and @if@/@else@ chains
-- without @elif@ are inspected.
--
-- For @def f(a, b): ...@ the rewritten body has this shape:
--
-- @
-- a__tr = a
-- b__tr = b
-- __res__tr = None
-- while True:
--     ...original body, with a and b renamed to a__tr and b__tr...
-- return __res__tr
-- @
--
-- Inside the loop, the final statement is rewritten (see @looped@) so that
-- every path either
--
-- * re-binds the @\<param\>__tr@ variables to the tail call's arguments
--   (going through @\<param\>__tr__old@ temporaries so the arguments see the
--   old values) and lets the loop run again, or
-- * stores the returned value, if any, in @__res__tr@ and leaves the loop
--   with @break@. Falling off the end therefore still yields @None@.
--
-- @return@ statements outside those tail positions are kept unchanged and
-- leave the function directly.
optimizeTailRecursion :: Raw Statement -> Maybe (Raw Statement)
optimizeTailRecursion st = do
  function <- st ^? _Fundef

  let
    functionName = function ^. fdName.identValue
    functionBody = function ^. body_
    paramNames = function ^.. fdParameters.folded.paramName.identValue

  -- Only the last statement of the body can be in tail position.
  (_, bodyLast, _) <- splitAtLastStatement functionBody

  if not $ hasTC functionName bodyLast
    then Nothing
    else
      Just $
      _Fundef #
        (function &
         body_ .~
           (fmap (\a -> line_ (var_ (a <> "__tr") .= var_ a)) paramNames <>
            [ line_ ("__res__tr" .= none_)
            , line_ . while_ true_ .
              transformOn (traverse._Exprs) (renameIn paramNames "__tr") $
                loopedLines functionName paramNames functionBody
            , line_ $ return_ "__res__tr"
            ]))

  where
    -- Split lines around the last line that carries a statement. Returns the
    -- lines before it, that statement, and the lines after it, none of which
    -- carries a statement. Returns Nothing when no line carries a statement.
    -- Unlike dropping the last line, this copes with trailing lines that have
    -- no statement.
    splitAtLastStatement
      :: [Raw Line] -> Maybe ([Raw Line], Raw Statement, [Raw Line])
    splitAtLastStatement = go [] . reverse
      where
        -- Walk the lines from the end; 'after' collects the skipped lines in
        -- their original order.
        go
          :: [Raw Line] -> [Raw Line]
          -> Maybe ([Raw Line], Raw Statement, [Raw Line])
        go _ [] = Nothing
        go after (l : revBefore) =
          case l ^? _Statements of
            Just s -> Just (reverse revBefore, s, after)
            Nothing -> go (l : after) revBefore

    -- An expression is a tail call to @name@ when the expression itself is a
    -- call whose callee is the plain identifier @name@. Calls nested inside
    -- another call, such as @g(f(x))@, do not count.
    isTailCall :: String -> Raw Expr -> Bool
    isTailCall name e = e ^? _Call.callFunction._Ident.identValue == Just name

    -- Does the statement end in a self tail call on at least one path?
    --
    -- * An @if@ without @elif@ qualifies when the last statement of its body,
    --   or of its @else@ branch, qualifies. 'anyOf' is used so that a missing
    --   @else@ means "no tail call" there, rather than being vacuously true.
    -- * A simple-statement line qualifies when its last statement is
    --   @return f(...)@.
    --
    -- A bare call statement @f(...)@ is deliberately not counted. Python
    -- discards that call's result and the function then returns None, so the
    -- call is not in tail position.
    hasTC :: String -> Raw Statement -> Bool
    hasTC name stmt =
      case stmt of
        CompoundStatement (If _ _ _ _ sts [] sts') ->
          anyOf _last (hasTC name) (sts ^.. _Statements) ||
          anyOf _last (hasTC name) (sts' ^.. _Just._3._Statements)
        SmallStatement _ (MkSmallStatement s ss _ _ _) ->
          case last (s : fmap (^. _2) ss) of
            Return _ _ (Just e) -> isTailCall name e
            _ -> False
        _ -> False

    -- Append @suffix@ to every identifier in the expression whose name is one
    -- of @params@.
    renameIn :: [String] -> String -> Raw Expr -> Raw Expr
    renameIn params suffix =
      transform
        (_Ident.identValue %~ (\a -> if a `elem` params then a <> suffix else a))

    -- Rewrite a block of lines so that its last statement is replaced by its
    -- looped form (see 'looped'). A block with no statement just leaves the
    -- loop.
    loopedLines :: String -> [String] -> [Raw Line] -> [Raw Line]
    loopedLines name params ls =
      case splitAtLastStatement ls of
        Nothing -> ls <> [ line_ break_ ]
        Just (before, lastSt, after) ->
          before <> looped name params lastSt <> after

    -- Rewrite a statement that sits in tail position of the loop body. Every
    -- path through the result either re-binds the @__tr@ variables, so the
    -- loop runs again, or ends in @break@, so the function returns
    -- @__res__tr@.
    --
    -- NOTE(review): call arguments are matched to parameters by position
    -- only. Keyword and star arguments are not distinguished. The count check
    -- only rules out calls that rely on default parameter values; those fall
    -- back to a real recursive call.
    looped :: String -> [String] -> Raw Statement -> [Raw Line]
    looped name params stmt
      | Just ifSt <- stmt ^? _If
      , hasTC name stmt =
          let
            condition = ifSt ^. ifCond
            thenLines = loopedLines name params (toList $ ifSt ^. body_)
            elseLines =
              case ifSt ^? to getElse._Just.body_ of
                -- Without an else branch the false path falls off the end of
                -- the function (returning None), so it must leave the loop.
                Nothing -> [ line_ break_ ]
                Just elseBody -> loopedLines name params (toList elseBody)
          in
            [ line_ $ if_ condition thenLines & else_ elseLines ]
      | otherwise =
          case stmt of
            SmallStatement idnts (MkSmallStatement s ss sc cmt nl) ->
              let
                -- Split the line into every simple statement but the last,
                -- which is kept as its own line, and the last one.
                (prefix, lastSimple) =
                  case reverse ss of
                    [] -> ([], s)
                    (_, final) : revInit ->
                      ( [ line_ $
                          SmallStatement idnts
                            (MkSmallStatement s (reverse revInit) sc cmt nl)
                        ]
                      , final
                      )
              in
                case lastSimple of
                  Return _ _ e
                    | Just call <- e ^? _Just._Call
                    , Just name' <- call ^? callFunction._Ident.identValue
                    , name' == name
                    , let args = call ^.. callArguments.folded.folded.argExpr
                    , length args == length params ->
                        prefix <>
                        fmap
                          (\a -> line_ (var_ (a <> "__tr__old") .= var_ (a <> "__tr")))
                          params <>
                        zipWith
                          (\a b -> line_ (var_ (a <> "__tr") .= b))
                          params
                          (fmap (renameIn params "__tr__old") args)
                    | otherwise ->
                        prefix <>
                        maybe [] (\e' -> [ line_ ("__res__tr" .= e') ]) e <>
                        [ line_ break_ ]
                  -- Anything else finishes the function without a value.
                  _ -> [ line_ stmt, line_ break_ ]
            -- A compound statement that is not a tail-calling if: run it, then
            -- leave the loop, matching the original falling off the end.
            _ -> [ line_ stmt, line_ break_ ]