{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE OverloadedStrings #-}
module Text.TeXMath.Readers.MathML (readMathML) where
import Text.XML.Light hiding (onlyText)
import Text.TeXMath.Types
import Text.TeXMath.Readers.MathML.MMLDict (getMathMLOperator)
import Text.TeXMath.Readers.MathML.EntityMap (getUnicode)
import Text.TeXMath.Shared (getTextType, readLength, getOperator, fixTree,
                            getSpaceWidth, isEmpty, empty)
import Text.TeXMath.Unicode.ToTeX (getSymbolType)
import Text.TeXMath.Unicode.ToUnicode (fromUnicode)
import Text.TeXMath.Compat (throwError, Except, runExcept, MonadError)
import Control.Applicative ((<$>), (<|>), (<*>))
import Control.Arrow ((&&&))
import Data.Maybe (fromMaybe, listToMaybe, isJust)
import Data.Monoid (mconcat, First(..), getFirst)
import Data.Semigroup ((<>))
import Data.List (transpose)
import qualified Data.Text as T
import Control.Monad (filterM, guard)
import Control.Monad.Reader (ReaderT, runReaderT, asks, local)
import Data.Either (rights)
-- | Parse a MathML document into a list of expressions.
--
-- Returns 'Left' with a message when the input is not parseable XML
-- (@"Invalid XML"@) or when any later parsing step raises an error.
-- Every resulting expression is post-processed with 'fixTree'.
readMathML :: T.Text -> Either T.Text [Exp]
readMathML inp = fmap (map fixTree) parsed
  where
    -- Run the reader/except stack starting from 'defaultState'.
    parsed = runExcept (runReaderT (root >>= parseMathML) defaultState)
    root = maybeToEither "Invalid XML" (parseXMLDoc inp)
-- | Environment threaded through the parser via 'ReaderT'.
data MMLState = MMLState { attrs :: [Attr]
                         -- ^ Attributes collected from enclosing elements
                         -- (see 'addAttrs'); 'findAttrQ' falls back to them.
                         , position :: Maybe FormType
                         -- ^ Operator form implied by the position within
                         -- the current row, if one has been set.
                         , inAccent :: Bool
                         -- ^ Set by 'enterAccent'; while True, 'op' never
                         -- produces a 'Stretchy' item.
                         , curStyle :: TextType }
                         -- ^ Text style currently in effect (see 'enterStyled').
-- | Parser monad: read-only 'MMLState' on top of 'T.Text' error reporting.
type MML = ReaderT MMLState (Except T.Text)
-- | Tag distinguishing subscripts from superscripts.
-- NOTE(review): not referenced anywhere in the visible part of this module.
data SupOrSub = Sub | Sup deriving (Show, Eq)
-- | Intermediate representation used while fences are still being matched.
--
-- * 'Stretchy': a stretchy operator with its symbol type, the constructor
--   to use if it ends up as an ordinary expression, and its text.
-- * 'Trailing': a script waiting to be applied to the delimited group that
--   precedes it (see 'trailingSup').
-- * 'E': an already finished item.
data IR a = Stretchy TeXSymbolType (T.Text -> Exp) T.Text
          | Trailing (Exp -> Exp -> Exp) Exp
          | E a
-- | Debug rendering of 'IR'; function-valued fields are omitted.
instance Show a => Show (IR a) where
  show ir = case ir of
    Stretchy t _ s -> unwords ["Stretchy", show t, show s]
    Trailing _ s   -> "Trailing " ++ show s
    E s            -> "E " ++ show s
-- | Parse the root element, which must be named @math@.
--
-- A top-level 'EGrouped' is flattened into its members; any other result
-- is returned as a singleton list.
parseMathML :: Element -> MML [Exp]
parseMathML e
  | name e == "math" = flatten <$> row e
  | otherwise = throwError "Root must be math element"
  where
    flatten (EGrouped es) = es
    flatten other = [other]
-- | Parse one element with its own attributes added to the inherited
-- attribute list for the duration of the parse.
expr :: Element -> MML [IR Exp]
expr e = local inherit (expr' e)
  where
    inherit = addAttrs (elAttribs e)
expr' :: Element -> MML [IR Exp]
expr' e =
  case name e of
    "mi" -> mkE <$> ident e
    "mn" -> mkE <$> number e
    "mo" -> (:[]) <$> op e
    "mtext" -> mkE <$> text e
    "ms" -> mkE <$> literal e
    "mspace" -> mkE <$> space e
    "mrow" -> mkE <$> row e
    "mstyle" -> mkE <$> style e
    "mfrac" -> mkE <$> frac e
    "msqrt" -> mkE <$> msqrt e
    "mroot" -> mkE <$> kroot e
    "merror" -> return (mkE empty)
    "mpadded" -> mkE <$> row e
    "mphantom" -> mkE <$> phantom e
    "mfenced" -> mkE <$> fenced e
    "menclose" -> mkE <$> enclosed e
    "msub" ->  sub e
    "msup" ->  sup e
    "msubsup" -> mkE <$> subsup e
    "munder" -> mkE <$> under e
    "mover" -> mkE <$> over e
    "munderover" -> mkE <$> underover e
    "mtable" -> mkE <$> table e
    "maction" -> mkE <$> action e
    "semantics" -> mkE <$> semantics e
    "maligngroup" -> return $ mkE empty
    "malignmark" -> return $ mkE empty
    _ -> throwError $ "Unexpected element " <> err e
  where
    mkE :: Exp -> [IR Exp]
    mkE = (:[]) . E
-- | Parse an @mi@ element.
--
-- The content becomes an 'EMathOperator' when 'getOperator' recognises it
-- and an 'EIdentifier' otherwise. If a @mathvariant@ attribute is found
-- and differs from the current style, the result is wrapped in 'EStyled'.
ident :: Element -> MML Exp
ident e = do
  content <- getString e
  variant <- findAttrQ "mathvariant" e
  ambient <- asks curStyle
  let asOperator = EMathOperator content
      base = maybe (EIdentifier content) (const asOperator)
                   (getOperator asOperator)
      restyle v
        | ambient == tt = base
        | otherwise = EStyled tt [base]
        where
          tt = getTextType v
  return (maybe base restyle variant)
-- | Parse an @mn@ element: its text content becomes an 'ENumber'.
number :: Element -> MML Exp
number e = fmap ENumber (getString e)
-- | Parse an @mo@ element into an intermediate item.
--
-- The result is a 'Stretchy' item when the operator is stretchy and we are
-- not inside a script/accent context; otherwise it is a finished 'E' item.
-- Raises an error when neither a @form@ attribute nor a position in the
-- state is available.
op :: Element -> MML (IR Exp)
op e = do
  -- An explicit "form" attribute wins over the position recorded in the state.
  mInferredPosition <- (<|>) <$> (getFormType <$> findAttrQ "form" e)
                            <*> asks position
  inferredPosition <- case mInferredPosition of
    Just inferredPosition -> pure inferredPosition
    Nothing               -> throwError "Did not find an inferred position"
  opString <- getString e
  -- Placeholder entry used when the operator dictionary has no match.
  -- NOTE(review): presumably the trailing [] is the property list — confirm
  -- against the Operator type.
  let dummy = Operator opString "" inferredPosition 0 0 0 []
  let opLookup = getMathMLOperator opString inferredPosition
  let opDict = fromMaybe dummy opLookup
  -- Keep the properties that hold for this operator (see checkAttr below);
  -- filterM preserves the listed order, so "fence" precedes "accent".
  props <- filterM (checkAttr (properties opDict))
            ["fence", "accent", "stretchy"]
  let objectPosition = getPosition $ form opDict
  inScript <- asks inAccent
  -- Constructors associated with properties; the first property in props
  -- that has an entry here decides the constructor.
  let ts =  [("accent", ESymbol Accent), ("fence", ESymbol objectPosition)]
  -- Without a deciding property: a single character is classified by
  -- getSymbolType; longer text is Ord if it was found in the dictionary and
  -- a math operator otherwise.
  let fallback = case T.unpack opString of
                   [t] -> ESymbol (getSymbolType t)
                   _   -> if isJust opLookup
                          then ESymbol Ord
                          else EMathOperator
  let constructor =
        fromMaybe fallback
          (getFirst . mconcat $ map (First . flip lookup ts) props)
  if ("stretchy" `elem` props) && not inScript
    then return $ Stretchy objectPosition constructor opString
    else do
      return $ (E . constructor) opString
  where
    -- A property holds if the attribute (own or inherited) equals "true";
    -- when the attribute is absent, the dictionary properties decide.
    checkAttr ps v = maybe (v `elem` ps) (=="true") <$> findAttrQ (T.unpack v) e
-- | Parse an @mtext@ element.
--
-- Normal-style content consisting of a single character for which
-- 'getSpaceWidth' returns a width becomes an 'ESpace'; everything else
-- becomes an 'EText' in the style given by @mathvariant@ (default
-- 'TextNormal').
text :: Element -> MML Exp
text e = do
  variant <- findAttrQ "mathvariant" e
  s <- getString e
  let textStyle = maybe TextNormal getTextType variant
  return $ fromMaybe (EText textStyle s) (asSpace textStyle (T.unpack s))
  where
    asSpace TextNormal [c] = ESpace <$> getSpaceWidth c
    asSpace _ _ = Nothing
-- | Parse an @ms@ element: the content is wrapped in the quotes given by
-- the @lquote@ / @rquote@ attributes (defaulting to U+201C and U+201D) and
-- returned as 'EText' in the style given by @mathvariant@.
literal :: Element -> MML Exp
literal e = do
  opening <- quoteAttr "lquote" "\x201C"
  closing <- quoteAttr "rquote" "\x201D"
  variant <- findAttrQ "mathvariant" e
  body <- getString e
  let textStyle = maybe TextNormal getTextType variant
  return (EText textStyle (T.concat [opening, body, closing]))
  where
    quoteAttr key def = fromMaybe def <$> findAttrQ key e
-- | Parse an @mspace@ element; a missing @width@ is treated as @0.0em@.
space :: Element -> MML Exp
space e = ESpace . widthToNum . fromMaybe "0.0em" <$> findAttrQ "width" e
-- | Parse an @mstyle@ element.
--
-- The children are parsed with the new style as current style and with
-- any inherited @mathvariant@ attribute removed. The result is wrapped in
-- 'EStyled' only when the style differs from the enclosing one.
style :: Element -> MML Exp
style e = do
  variant <- findAttrQ "mathvariant" e
  outer <- asks curStyle
  let tt = maybe TextNormal getTextType variant
      wrap
        | outer == tt = id
        | otherwise = EStyled tt . (: [])
  wrap <$> local (filterMathVariant . enterStyled tt) (row e)
-- | Parse the children of an element as a row and collapse them into a
-- single expression.
row :: Element -> MML Exp
row = fmap mkExp . group
-- | Collapse a list of intermediate items into one expression: match
-- fences, strip the intermediate wrappers, then build the expression.
mkExp :: [IR Exp] -> Exp
mkExp items = toExp (toEDelim (matchNesting items))
-- | Build one expression from delimiter-aware items.
--
-- A lone item is returned as is (a lone stretchy symbol becomes an 'Ord'
-- symbol). Several items become an 'EDelimited' without fences when any
-- of them is stretchy, and an 'EGrouped' otherwise.
toExp :: [InEDelimited] -> Exp
toExp [] = empty
toExp [Left s] = ESymbol Ord s
toExp [Right x] = x
toExp xs
  | any isStretchy xs = EDelimited "" "" xs
  | otherwise = EGrouped (rights xs)
-- | Strip the 'IR' wrapper from processed items. A single remaining
-- 'Stretchy' item is turned into an ordinary expression via its
-- constructor.
toEDelim :: [IR InEDelimited] -> [InEDelimited]
toEDelim items =
  case items of
    [Stretchy _ con s] -> [Right (con s)]
    _ -> map removeIR items
-- | Unwrap an 'E' item. Calls 'error' for any other constructor, so it
-- must only be applied to lists that have been fully processed.
removeIR :: IR a -> a
removeIR item =
  case item of
    E e -> e
    _ -> error "removeIR, should only be ever called on processed lists"
-- | Convert items into their delimiter-aware form.
--
-- A list consisting of one stretchy operator becomes that operator as an
-- ordinary expression; otherwise stretchy operators become 'Left' text and
-- finished expressions become 'Right'. 'Trailing' items pass through.
removeStretch :: [IR Exp] -> [IR InEDelimited]
removeStretch [Stretchy _ constructor s] = [E (Right (constructor s))]
removeStretch items = [convert item | item <- items]
  where
    convert item =
      case item of
        Stretchy _ _ s -> E (Left s)
        E x -> E (Right x)
        Trailing a b -> Trailing a b
-- | True for the 'Left' (stretchy symbol) alternative.
isStretchy :: InEDelimited -> Bool
isStretchy = either (const True) (const False)
-- | Build a delimited expression from optional open/close fences and a
-- body, applying any 'Trailing' scripts at the end of the body to the
-- whole delimited group (outermost script last in the list).
--
-- With an empty body only the fences themselves are produced, using their
-- own constructors.
trailingSup :: Maybe (T.Text, T.Text -> Exp)  -> Maybe (T.Text, T.Text -> Exp)  -> [IR InEDelimited] -> Exp
trailingSup open close = build
  where
    build [] = fencesOnly
    build items
      | Trailing script arg <- last items = script (build (init items)) arg
      | otherwise =
          EDelimited (fenceText open) (fenceText close) (toEDelim items)

    fencesOnly =
      case (render <$> open, render <$> close) of
        (Nothing, Nothing) -> empty
        (Just o, Nothing) -> o
        (Nothing, Just c) -> c
        (Just o, Just c) -> EGrouped [o, c]

    render (fence, con) = con fence
    fenceText = maybe "" fst
-- | Pair up opening and closing stretchy fences, turning each matched (or
-- unmatched) fence group into a single delimited expression.
--
-- The input is split at the first fence: everything before it ('inis') is
-- converted with 'removeStretch'; the remainder is handled according to
-- whether the fence is an opener or a closer.
matchNesting :: [IR Exp] -> [IR InEDelimited]
matchNesting ((break isFence) -> (inis, rest)) =
  let inis' = removeStretch inis in
  case rest of
    [] -> inis'
    ((Stretchy Open conOpen opens): rs) ->
      let jOpen = Just (opens, conOpen)
          -- body: items up to the matching closer; rems: what follows
          -- (starting with that closer, or empty if there is none).
          (body, rems) = go rs 0 []
          body' = matchNesting body in
        case rems of
          -- Unmatched opener: delimit with an open fence only.
          [] -> inis' ++ [E $ Right $ trailingSup jOpen Nothing body']
          (Stretchy Close conClose closes : rs') ->
            let jClose = Just (closes, conClose) in
            inis' ++ (E $ Right $ trailingSup jOpen jClose body') : matchNesting rs'
          -- 'go' only returns a non-empty remainder that starts with a closer.
          _ -> (error "matchNesting: Logical error 1")
    -- Closer without a preceding opener: everything before it becomes the
    -- body of a group delimited by the close fence only.
    ((Stretchy Close conClose closes): rs) ->
      let jClose = Just (closes, conClose) in
      (E $ Right $ trailingSup Nothing jClose (matchNesting inis)) : matchNesting rs
    -- 'rest' starts with an element satisfying 'isFence', i.e. Open or Close.
    _ -> error "matchNesting: Logical error 2"
  where
    isOpen (Stretchy Open _ _) = True
    isOpen _ = False
    isClose (Stretchy Close _ _) = True
    isClose _ = False
    -- Scan for the closer matching the current opener. The Int is the
    -- current nesting depth of inner fences; the last argument accumulates
    -- the consumed items in reverse.
    go :: [IR a] -> Int -> [IR a] -> ([IR a], [IR a])
    go (x:xs) 0 a | isClose x = (reverse a, x:xs)
    go (x:xs) n a | isOpen x  = go xs (n + 1) (x:a)
    go (x:xs) n a | isClose x = go xs (n - 1) (x:a)
    go (x:xs) n a = go xs n (x:a)
    go [] _ a = (reverse a, [])
-- | True for stretchy operators of type 'Open' or 'Close'.
isFence :: IR a -> Bool
isFence item =
  case item of
    Stretchy Open _ _ -> True
    Stretchy Close _ _ -> True
    _ -> False
-- | Parse the children of an element as a row.
--
-- Leading and trailing space-like children are parsed on their own; the
-- remaining core goes through 'row'' so that operator positions are
-- inferred relative to the non-space content. The position is reset for
-- the core and for the trailing part.
group :: Element -> MML [IR Exp]
group e = do
  front <- fmap concat (mapM expr leading)
  middle <- local resetPosition (row' core)
  end <- local resetPosition (fmap concat (mapM expr trailing))
  return (front ++ middle ++ end)
  where
    (leading, afterLeading) = span spacelike (elChildren e)
    (revTrailing, revCore) = span spacelike (reverse afterLeading)
    trailing = reverse revTrailing
    core = reverse revCore
-- | Parse a sequence of sibling elements while tracking operator position.
--
-- With no position set, the first of several elements is parsed as prefix
-- and a lone element as infix. Once a position is set, the last element
-- is parsed as postfix and all others as infix.
row' :: [Element] -> MML [IR Exp]
row' [] = return []
row' (x:xs) = do
  inherited <- asks position
  let pos
        | null xs = maybe FInfix (const FPostfix) inherited
        | otherwise = maybe FPrefix (const FInfix) inherited
      here = local (setPosition pos)
  (++) <$> here (expr x) <*> here (row' xs)
-- | Parse one element and collapse the result into a single expression.
safeExpr :: Element -> MML Exp
safeExpr = fmap mkExp . expr
-- | Parse an @mfrac@ element (exactly two children). A @linethickness@
-- that evaluates to zero yields a fraction without a line.
frac :: Element -> MML Exp
frac e = do
  (num, denom) <- checkArgs2 e >>= mapPairM safeExpr
  rawThick <- findAttrQ "linethickness" e
  let kind = if thicknessZero rawThick then NoLineFrac else NormalFrac
  return (EFraction kind num denom)
-- | Parse an @msqrt@ element: its children form the radicand row.
msqrt :: Element -> MML Exp
msqrt = fmap ESqrt . row
-- | Parse an @mroot@ element (exactly two children: base, then index).
-- 'ERoot' takes the index first, so the pair is flipped.
kroot :: Element -> MML Exp
kroot e = uncurry (flip ERoot) <$> (checkArgs2 e >>= mapPairM safeExpr)
-- | Parse an @mphantom@ element: its row is wrapped in 'EPhantom'.
phantom :: Element -> MML Exp
phantom = fmap EPhantom . row
-- | Parse an @mfenced@ element by rewriting it into an equivalent @mrow@
-- with explicit @mo@ fences and separators, then parsing that.
--
-- Defaults: @open="("@, @close=")"@, @separators=","@. Each character of
-- @separators@ is one separator; the last one is repeated as needed. Empty
-- fences are omitted, and an empty separator string inserts nothing.
fenced :: Element -> MML Exp
fenced e = do
  open <- attrOr "open" "("
  close <- attrOr "close" ")"
  sep <- attrOr "separators" ","
  let kids = elChildren e
      expanded
        | T.null sep = kids
        | otherwise =
            let seps = [unode "mo" [c] | c <- T.unpack sep]
            in fInterleave kids (seps ++ repeat (last seps))
      fenceNode f = [tunode "mo" f | not (T.null f)]
  safeExpr $
    unode "mrow" (fenceNode open ++ [unode "mrow" expanded] ++ fenceNode close)
  where
    attrOr key def = fromMaybe def <$> findAttrQ key e
-- | Parse an @menclose@ element. Only @notation="box"@ is represented
-- (as 'EBoxed'); any other notation yields the plain row.
enclosed :: Element -> MML Exp
enclosed e = do
  notation <- findAttrQ "notation" e
  let wrap = if notation == Just "box" then EBoxed else id
  wrap <$> row e
-- | Parse an @maction@ element by parsing only its selected child.
--
-- The @selection@ attribute is a 1-based index and defaults to 1. A value
-- that is not an integer raises an error (instead of crashing in 'read'),
-- as does an index beyond the last child.
action :: Element -> MML Exp
action e = do
  mbSelection <- findAttrQ "selection" e
  selection <- case mbSelection of
    Nothing -> return 1
    Just raw ->
      case reads (T.unpack raw) of
        -- Tolerate trailing whitespace, like 'read' did.
        [(n, rest)] | all isSpace rest -> return n
        _ -> throwError ("Invalid selection attribute for " <> err e)
  safeExpr =<< maybeToEither ("Selection out of range")
            (listToMaybe $ drop (selection - 1) (elChildren e))
-- | Parse an @msub@ element (exactly two children: base and subscript).
sub :: Element -> MML [IR Exp]
sub e = checkArgs2 e >>= \(base, script) -> reorderScripts base script ESub
-- | Attach a script to a base, taking care of bases that are a single
-- stretchy operator so that fence matching still sees the operator.
--
-- @e@ is the base element, @subs@ the script element and @c@ the
-- constructor combining base and script ('ESub' or 'ESuper').
--
-- * Closing fence: a 'Trailing' item is emitted /before/ the fence, so
--   that it ends up last in the fenced body and 'trailingSup' applies the
--   script to the whole delimited group.
-- * Any other stretchy operator: the operator is kept as is and the
--   script is attached to an empty base placed after it.
-- * Otherwise the base is collapsed and combined with the script.
reorderScripts :: Element -> Element -> (Exp -> Exp -> Exp) -> MML [IR Exp]
reorderScripts e subs c = do
  baseExpr <- expr e
  subExpr <- postfixExpr subs
  return $
    case baseExpr of
      [s@(Stretchy Close _ _)] -> [Trailing c subExpr, s]
      -- Uses the supplied constructor for every non-closing stretchy
      -- operator; the middle-operator case used to hard-code ESub, which
      -- turned superscripts into subscripts.
      [s@(Stretchy _ _ _)] -> [s, E $ c empty subExpr]
      _ -> [E $ c (mkExp baseExpr) subExpr]
-- | Parse an @msup@ element (exactly two children: base and superscript).
sup :: Element -> MML [IR Exp]
sup e = checkArgs2 e >>= \(base, script) -> reorderScripts base script ESuper
-- | Parse an @msubsup@ element (exactly three children: base, subscript,
-- superscript). Both scripts are parsed in postfix/script context.
subsup :: Element -> MML Exp
subsup e = do
  (base, subs, sups) <- checkArgs3 e
  b <- safeExpr base
  lower <- postfixExpr subs
  upper <- postfixExpr sups
  return (ESubsup b lower upper)
-- | Parse an @munder@ element (exactly two children: base and underscript).
under :: Element -> MML Exp
under e = do
  (base, below) <- checkArgs2 e
  b <- safeExpr base
  u <- postfixExpr below
  return (EUnder False b u)
-- | Parse an @mover@ element (exactly two children: base and overscript).
over :: Element -> MML Exp
over e = do
  (base, above) <- checkArgs2 e
  b <- safeExpr base
  o <- postfixExpr above
  return (EOver False b o)
-- | Parse an @munderover@ element (exactly three children: base,
-- underscript, overscript).
underover :: Element -> MML Exp
underover e = do
  (base, below, above) <- checkArgs3 e
  b <- safeExpr base
  u <- postfixExpr below
  o <- postfixExpr above
  return (EUnderover False b u o)
-- | Parse a @semantics@ element.
--
-- The first child is parsed as the presentation content. If that yields
-- an empty expression, the remaining children are tried as annotations
-- and the first usable one is taken (or an empty expression if none is).
-- Fails via 'guard' when the element has no children.
semantics :: Element -> MML Exp
semantics e = do
  guard (not (null children))
  primary <- safeExpr (head children)
  if isEmpty primary
    then pickAnnotation (tail children)
    else return primary
  where
    children = elChildren e
    pickAnnotation =
      fmap (fromMaybe empty . getFirst . mconcat) . mapM annotation
-- | Parse an annotation child of @semantics@. Only annotations whose
-- @encoding@ marks them as presentation MathML are parsed; all others
-- contribute nothing.
annotation :: Element -> MML (First Exp)
annotation e = do
  encoding <- findAttrQ "encoding" e
  if encoding `elem` map Just presentationEncodings
    then First . Just <$> row e
    else return (First Nothing)
  where
    presentationEncodings =
      [ "application/mathml-presentation+xml"
      , "MathML-Presentation"
      ]
-- | Parse an @mtable@ element into an 'EArray'.
--
-- Rows are padded with empty cells to the width of the longest row. The
-- alignment of a column is the common alignment of its cells, or
-- 'AlignCenter' when the cells disagree.
table :: Element -> MML Exp
table e = do
  defAlign <- maybe AlignCenter toAlignment <$> findAttrQ "columnalign" e
  rs <- mapM (tableRow defAlign) (elChildren e)
  let width = maximum (map length rs)
      cells = map (pad width . map snd) rs
      aligns = map columnAlign (transpose (map (map fst) rs))
  return (EArray aligns cells)
  where
    columnAlign [] = AlignCenter
    columnAlign (a:as)
      | all (== a) as = a
      | otherwise = AlignCenter
-- | Parse one table row (@mtr@ or @mlabeledtr@) into its aligned cells.
--
-- @a@ is the alignment inherited from the table; a @columnalign@
-- attribute found for the row overrides it. For @mlabeledtr@ the first
-- child (the label, per the MathML specification) is skipped. Any other
-- element name raises an error.
tableRow :: Alignment -> Element -> MML [(Alignment, [Exp])]
tableRow a e = do
  align <- maybe a toAlignment <$> (findAttrQ "columnalign" e)
  case name e of
    "mtr" -> mapM (tableCell align) (elChildren e)
    -- 'drop 1' instead of 'tail': an mlabeledtr without children must
    -- yield an empty row rather than crash.
    "mlabeledtr" -> mapM (tableCell align) (drop 1 $ elChildren e)
    _ -> throwError $ "Invalid Element: Only expecting mtr elements " <> err e
-- | Parse one @mtd@ cell, returning its alignment and content. @a@ is the
-- inherited alignment, overridden by a @columnalign@ attribute if found.
-- Any other element name raises an error.
tableCell :: Alignment -> Element -> MML (Alignment, [Exp])
tableCell a e
  | name e == "mtd" = do
      align <- maybe a toAlignment <$> findAttrQ "columnalign" e
      cell <- row e
      return (align, [cell])
  | otherwise =
      throwError $ "Invalid Element: Only expecting mtd elements " <> err e
-- | Lift a 'Maybe' into an error monad, throwing the given error on
-- 'Nothing'.
maybeToEither :: (MonadError e m) => e -> Maybe a -> m a
maybeToEither e = maybe (throwError e) return
-- | Alternate elements of two lists, starting with the first list, and
-- stop as soon as the list whose turn it is has run out (or the other
-- list is already empty).
fInterleave :: [a] -> [a] -> [a]
fInterleave (x:xs) ys@(_:_) = x : fInterleave ys xs
fInterleave _ _ = []
-- | Initial parser state: no inherited attributes, no position, not in an
-- accent/script context, normal text style.
defaultState :: MMLState
defaultState = MMLState
  { attrs = []
  , position = Nothing
  , inAccent = False
  , curStyle = TextNormal
  }
-- | Prepend attributes (after 'renameAttr') to the inherited attribute
-- list, so that they take precedence over outer ones.
addAttrs :: [Attr] -> MMLState -> MMLState
addAttrs new st =
  st { attrs = foldr (\a acc -> renameAttr a : acc) (attrs st) new }
-- | Rename an @accentunder@ attribute to @accent@; other attributes are
-- returned unchanged.
renameAttr :: Attr -> Attr
renameAttr a
  | qName (attrKey a) == "accentunder" = Attr (unqual "accent") (attrVal a)
  | otherwise = a
-- | Remove every inherited @mathvariant@ attribute from the state.
filterMathVariant :: MMLState -> MMLState
filterMathVariant st =
  st { attrs = [a | a <- attrs st, attrKey a /= unqual "mathvariant"] }
-- | Record the given operator position in the state.
setPosition :: FormType -> MMLState -> MMLState
setPosition form' st = st { position = pure form' }
-- | Clear the operator position in the state.
resetPosition :: MMLState -> MMLState
resetPosition st = st { position = Nothing }
-- | Mark the state as being inside an accent/script context.
enterAccent :: MMLState -> MMLState
enterAccent st = st { inAccent = True }
-- | Make the given text style the current one.
enterStyled :: TextType -> MMLState -> MMLState
enterStyled newStyle st = st { curStyle = newStyle }
-- | Extract the textual content of an element: text and character
-- references are concatenated, surrounding whitespace is stripped, and the
-- result is passed through 'fromUnicode' with the current style.
getString :: Element -> MML T.Text
getString e = do
  tt <- asks curStyle
  let raw = T.pack (concatMap cdData (onlyText (elContent e)))
  return (fromUnicode tt (stripSpaces raw))
-- | Keep only the textual pieces of element content. Character/entity
-- references are resolved through 'getUnicode'; unknown references keep
-- their name as text. Child elements are dropped.
onlyText :: [Content] -> [CData]
onlyText = concatMap pick
  where
    pick (Text c) = [c]
    pick (CRef ref) = [CData CDataText (resolve ref) Nothing]
    pick _ = []

    resolve ref = maybe ref T.unpack (getUnicode (T.pack ref))
-- | Return the two child elements of an element, or raise an error if it
-- does not have exactly two.
checkArgs2 :: Element -> MML (Element, Element)
checkArgs2 e
  | [a, b] <- elChildren e = return (a, b)
  | otherwise = throwError ("Incorrect number of arguments for " <> err e)
-- | Return the three child elements of an element, or raise an error if
-- it does not have exactly three.
checkArgs3 :: Element -> MML (Element, Element, Element)
checkArgs3 e
  | [a, b, c] <- elChildren e = return (a, b, c)
  | otherwise = throwError ("Incorrect number of arguments for " <> err e)
-- | Apply a monadic function to both components of a pair, left to right.
mapPairM :: Monad m => (a -> m b) -> (a, a) -> m (b, b)
mapPairM f (a, b) = do
  a' <- f a
  b' <- f b
  return (a', b')
-- | Describe an element for error messages: its name, followed by the
-- source line when the element carries one.
err :: Element -> T.Text
err e = name e <> lineInfo
  where
    lineInfo =
      case elLine e of
        Nothing -> ""
        Just n -> " line " <> T.pack (show n)
-- | Look up an unqualified attribute, first on the element itself and
-- then among the attributes inherited through the state.
findAttrQ :: String -> Element -> MML (Maybe T.Text)
findAttrQ s e = do
  inherited <- asks (lookupAttrQ s . attrs)
  let own = findAttr (QName s Nothing Nothing) e
  return (T.pack <$> (own <|> inherited))
-- | Look up an unqualified attribute name in an attribute list.
lookupAttrQ :: String -> [Attr] -> Maybe String
lookupAttrQ s as = lookupAttr key as
  where
    key = QName s Nothing Nothing
-- | Local (unqualified) name of an element.
name :: Element -> T.Text
name = T.pack . qName . elName
-- | Build an element with the given name and text content.
tunode :: String -> T.Text -> Element
tunode tag content = unode tag (T.unpack content)
-- | Remove leading and trailing whitespace (as defined by this module's
-- 'isSpace').
stripSpaces :: T.Text -> T.Text
stripSpaces = T.dropWhileEnd isSpace . T.dropWhile isSpace
-- | Interpret a @columnalign@ value. Anything other than @left@ or
-- @right@ (including @center@) is centered.
toAlignment :: T.Text -> Alignment
toAlignment t =
  case t of
    "left" -> AlignLeft
    "right" -> AlignRight
    _ -> AlignCenter
-- | Symbol type implied by an operator form: prefix operators open,
-- postfix operators close, infix operators are plain operators.
getPosition :: FormType -> TeXSymbolType
getPosition form' =
  case form' of
    FPrefix -> Open
    FPostfix -> Close
    FInfix -> Op
-- | Interpret the value of a @form@ attribute. Missing or unrecognised
-- values give 'Nothing'.
getFormType :: Maybe T.Text -> Maybe FormType
getFormType mb = mb >>= classify
  where
    classify "infix" = Just FInfix
    classify "prefix" = Just FPrefix
    classify "postfix" = Just FPostfix
    classify _ = Nothing
-- | Extend a list of lists with empty lists until it has at least @n@
-- entries. Lists that are already long enough are returned unchanged.
pad :: Int -> [[a]] -> [[a]]
pad n xs = xs ++ replicate missing []
  where
    missing = n - length xs
-- | Whitespace test used when trimming token content.
--
-- Recognises space, tab, newline and carriage return. The carriage return
-- is included so that content from documents with CRLF line endings is
-- trimmed as well.
isSpace :: Char -> Bool
isSpace ' '  = True
isSpace '\t' = True
isSpace '\n' = True
isSpace '\r' = True
isSpace _    = False
-- | 'spacelikeElems' are element names that 'spacelike' always treats as
-- space-like; 'cSpacelikeElems' are container names that are space-like
-- only when all of their children are.
spacelikeElems, cSpacelikeElems :: [T.Text]
spacelikeElems = ["mtext", "mspace", "maligngroup", "malignmark"]
cSpacelikeElems = ["mrow", "mstyle", "mphantom", "mpadded"]
-- | Whether an element is space-like: either one of 'spacelikeElems', or
-- one of the 'cSpacelikeElems' containers with only space-like children.
spacelike :: Element -> Bool
spacelike e
  | tag `elem` spacelikeElems = True
  | tag `elem` cSpacelikeElems = all spacelike (elChildren e)
  | otherwise = False
  where
    tag = name e
-- | True when a @linethickness@ value is present and evaluates to zero.
thicknessZero :: Maybe T.Text -> Bool
thicknessZero = maybe False ((== 0.0) . thicknessToNum)
-- | Convert a width attribute to a number.
--
-- Named math spaces map to multiples of 1/18 (negative for the
-- @negative...@ names). Any other value is parsed with 'readLength',
-- falling back to 0 when that fails.
widthToNum :: T.Text -> Rational
widthToNum s =
  maybe (fromMaybe 0 (readLength s)) (/ 18) (lookup s namedWidths)
  where
    -- Named space -> numerator over 18.
    namedWidths :: [(T.Text, Rational)]
    namedWidths =
      [ ("veryverythinmathspace", 1)
      , ("verythinmathspace", 2)
      , ("thinmathspace", 3)
      , ("mediummathspace", 4)
      , ("thickmathspace", 5)
      , ("verythickmathspace", 6)
      , ("veryverythickmathspace", 7)
      , ("negativeveryverythinmathspace", -1)
      , ("negativeverythinmathspace", -2)
      , ("negativethinmathspace", -3)
      , ("negativemediummathspace", -4)
      , ("negativethickmathspace", -5)
      , ("negativeverythickmathspace", -6)
      , ("negativeveryverythickmathspace", -7)
      ]
-- | Convert a @linethickness@ value to a number. The keywords @thin@,
-- @medium@ and @thick@ have fixed values; anything else is parsed with
-- 'readLength', falling back to 0.5 when that fails.
thicknessToNum :: T.Text -> Rational
thicknessToNum s
  | s == "thin" = 3 / 18
  | s == "medium" = 1 / 2
  | s == "thick" = 1
  | otherwise = fromMaybe 0.5 (readLength s)
-- | Parse a script/accent argument: the element is parsed in postfix
-- position with the accent flag set, then collapsed into one expression.
postfixExpr :: Element -> MML Exp
postfixExpr = local (setPosition FPostfix . enterAccent) . safeExpr