{-# LANGUAGE FlexibleInstances    #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE TypeSynonymInstances #-}
module Data.SBV.Utils.PrettyNum (
        PrettyNum(..), readBin, shex, chex, shexI, sbin, sbinI
      , showCFloat, showCDouble, showHFloat, showHDouble
      , showSMTFloat, showSMTDouble, smtRoundingMode, cwToSMTLib, mkSkolemZero
      ) where
import Data.Char  (intToDigit, ord)
import Data.Int   (Int8, Int16, Int32, Int64)
import Data.List  (intercalate, isPrefixOf)
import Data.Maybe (fromJust, fromMaybe, listToMaybe)
import Data.Ratio (numerator, denominator)
import Data.Word  (Word8, Word16, Word32, Word64)
import Numeric    (showIntAtBase, showHex, readInt)
import Data.Numbers.CrackNum (floatToFP, doubleToFP)
import Data.SBV.Core.Data
import Data.SBV.Core.Kind (smtType)
import Data.SBV.Core.AlgReals (algRealToSMTLib2)
import Data.SBV.Utils.Lib (stringToQFS)
-- | Types whose values can be rendered in hexadecimal and binary notation.
--
-- For the integral instances in this module, the @S@-suffixed methods add a
-- radix prefix (@0x@ / @0b@) and a trailing type annotation, while the plain
-- methods print only the digits.
class PrettyNum a where
  -- | Hex / binary rendering including prefix and type annotation.
  hexS, binS :: a -> String
  -- | Hex / binary rendering without prefix or type annotation.
  hex, bin :: a -> String
-- | Booleans have no radix; every rendering is plain 'show'.
instance PrettyNum Bool where
  hexS = show
  binS = show
  hex  = show
  bin  = show
-- | Strings are rendered with 'show' (i.e. quoted) in every variant.
instance PrettyNum String where
  hexS = show
  binS = show
  hex  = show
  bin  = show
-- | Fixed-width words and ints. Each passes its (signedness, bit width) to
-- 'shex' / 'sbin', which zero-pad to the full width; the S variants also
-- print the prefix and the type annotation.
instance PrettyNum Word8 where
  hexS = shex True  True  (False, 8)
  binS = sbin True  True  (False, 8)
  hex  = shex False False (False, 8)
  bin  = sbin False False (False, 8)
instance PrettyNum Int8 where
  hexS = shex True  True  (True, 8)
  binS = sbin True  True  (True, 8)
  hex  = shex False False (True, 8)
  bin  = sbin False False (True, 8)
instance PrettyNum Word16 where
  hexS = shex True  True  (False, 16)
  binS = sbin True  True  (False, 16)
  hex  = shex False False (False, 16)
  bin  = sbin False False (False, 16)
instance PrettyNum Int16 where
  hexS = shex True  True  (True, 16)
  binS = sbin True  True  (True, 16)
  hex  = shex False False (True, 16)
  bin  = sbin False False (True, 16)
instance PrettyNum Word32 where
  hexS = shex True  True  (False, 32)
  binS = sbin True  True  (False, 32)
  hex  = shex False False (False, 32)
  bin  = sbin False False (False, 32)
instance PrettyNum Int32 where
  hexS = shex True  True  (True, 32)
  binS = sbin True  True  (True, 32)
  hex  = shex False False (True, 32)
  bin  = sbin False False (True, 32)
instance PrettyNum Word64 where
  hexS = shex True  True  (False, 64)
  binS = sbin True  True  (False, 64)
  hex  = shex False False (False, 64)
  bin  = sbin False False (False, 64)
instance PrettyNum Int64 where
  hexS = shex True  True  (True, 64)
  binS = sbin True  True  (True, 64)
  hex  = shex False False (True, 64)
  bin  = sbin False False (True, 64)
-- | Unbounded integers: no padding, since there is no fixed width.
instance PrettyNum Integer where
  hexS = shexI True  True
  binS = sbinI True  True
  hex  = shexI False False
  bin  = sbinI False False
-- | Pretty-printing of concrete values, dispatching on the value's kind.
--
-- Bounded integral values are printed via 'shex' / 'sbin' using the kind's
-- signedness and width, unbounded ones via 'shexI' / 'sbinI'. Other kinds are
-- printed with 'show'. The @S@ variants append type information; the plain
-- variants ('hex' / 'bin') never do.
--
-- Every kind handled by 'cwToSMTLib' (including characters and lists) gets
-- an explicit branch here, so no value reaches the @CWInteger@ pattern
-- bindings unless it really is an integer.
instance PrettyNum CW where
  hexS cw | isUninterpreted cw = show cw ++ " :: " ++ show (kindOf cw)
          | isBoolean cw       = hexS (cwToBool cw) ++ " :: Bool"
          | isFloat cw         = let CWFloat   f = cwVal cw in show f ++ " :: Float\n"  ++ show (floatToFP f)
          | isDouble cw        = let CWDouble  d = cwVal cw in show d ++ " :: Double\n" ++ show (doubleToFP d)
          | isReal cw          = let CWAlgReal w = cwVal cw in show w ++ " :: Real"
          | isString cw        = let CWString  s = cwVal cw in show s ++ " :: String"
          | isChar cw          = let CWChar    c = cwVal cw in show c ++ " :: Char"
          | isList cw          = prettyListCW hex cw ++ " :: " ++ show (kindOf cw)
          | not (isBounded cw) = let CWInteger w = cwVal cw in shexI True True w
          | True               = let CWInteger w = cwVal cw in shex  True True (hasSign cw, intSizeOf cw) w
  binS cw | isUninterpreted cw = show cw  ++ " :: " ++ show (kindOf cw)
          | isBoolean cw       = binS (cwToBool cw)  ++ " :: Bool"
          | isFloat cw         = let CWFloat   f = cwVal cw in show f ++ " :: Float\n"  ++ show (floatToFP f)
          | isDouble cw        = let CWDouble  d = cwVal cw in show d ++ " :: Double\n" ++ show (doubleToFP d)
          | isReal cw          = let CWAlgReal w = cwVal cw in show w ++ " :: Real"
          | isString cw        = let CWString  s = cwVal cw in show s ++ " :: String"
          | isChar cw          = let CWChar    c = cwVal cw in show c ++ " :: Char"
          | isList cw          = prettyListCW bin cw ++ " :: " ++ show (kindOf cw)
          | not (isBounded cw) = let CWInteger w = cwVal cw in sbinI True True w
          | True               = let CWInteger w = cwVal cw in sbin  True True (hasSign cw, intSizeOf cw) w
  -- Plain variants: no type annotations, including for booleans.
  hex cw  | isUninterpreted cw = show cw
          | isBoolean cw       = hex (cwToBool cw)
          | isFloat cw         = let CWFloat   f = cwVal cw in show f
          | isDouble cw        = let CWDouble  d = cwVal cw in show d
          | isReal cw          = let CWAlgReal w = cwVal cw in show w
          | isString cw        = let CWString  s = cwVal cw in show s
          | isChar cw          = let CWChar    c = cwVal cw in show c
          | isList cw          = prettyListCW hex cw
          | not (isBounded cw) = let CWInteger w = cwVal cw in shexI False False w
          | True               = let CWInteger w = cwVal cw in shex  False False (hasSign cw, intSizeOf cw) w
  bin cw  | isUninterpreted cw = show cw
          | isBoolean cw       = bin (cwToBool cw)
          | isFloat cw         = let CWFloat   f = cwVal cw in show f
          | isDouble cw        = let CWDouble  d = cwVal cw in show d
          | isReal cw          = let CWAlgReal w = cwVal cw in show w
          | isString cw        = let CWString  s = cwVal cw in show s
          | isChar cw          = let CWChar    c = cwVal cw in show c
          | isList cw          = prettyListCW bin cw
          | not (isBounded cw) = let CWInteger w = cwVal cw in sbinI False False w
          | True               = let CWInteger w = cwVal cw in sbin  False False (hasSign cw, intSizeOf cw) w

-- | Render a list-valued 'CW' as @[e1, e2, ...]@, printing each element
-- (re-wrapped as a 'CW' of the element kind) with the given printer.
-- Falls back to 'show' if the kind and the value do not agree.
prettyListCW :: (CW -> String) -> CW -> String
prettyListCW shElem cw = case (kindOf cw, cwVal cw) of
                           (KList ek, CWList xs) -> "[" ++ intercalate ", " (map (shElem . CW ek) xs) ++ "]"
                           _                     -> show cw
-- | Symbolic values: when 'unliteral' yields a value, render it with the
-- underlying instance; otherwise fall back to the symbolic value's 'show'.
instance (SymWord a, PrettyNum a) => PrettyNum (SBV a) where
  hexS s = case unliteral s of
             Just v  -> hexS v
             Nothing -> show s
  binS s = case unliteral s of
             Just v  -> binS v
             Nothing -> show s
  hex s  = case unliteral s of
             Just v  -> hex v
             Nothing -> show s
  bin s  = case unliteral s of
             Just v  -> bin v
             Nothing -> show s
-- | Show a fixed-width value in hex, zero-padded to @ceil(size/4)@ digits.
--
-- * @shType@: append a @:: IntN@ / @:: WordN@ annotation.
-- * @shPre@: prefix the digits with @0x@.
-- * @(signed, size)@: signedness and bit width (used for padding and the annotation).
--
-- Negative values are printed as a @-@ sign followed by the magnitude.
shex :: (Show a, Integral a) => Bool -> Bool -> (Bool, Int) -> a -> String
shex shType shPre (signed, size) a
  | a < 0     = '-' : render (s16 (abs (fromIntegral a :: Integer)))
  | otherwise = render (s16 a)
  where render digits = radix ++ pad ((size + 3) `div` 4) digits ++ typeTag
        radix   | shPre     = "0x"
                | otherwise = ""
        typeTag | shType    = " :: " ++ (if signed then "Int" else "Word") ++ show size
                | otherwise = ""
-- | Show a fixed-width value as a C hex literal.
--
-- The minimum value of a signed 8/16/32/64-bit type is printed as the
-- corresponding @INTn_MIN@ macro. Anything else is rendered by 'shex',
-- followed by a C integer suffix chosen from the signedness and width.
chex :: (Show a, Integral a) => Bool -> Bool -> (Bool, Int) -> a -> String
chex shType shPre (signed, size) a =
    fromMaybe (shex shType shPre (signed, size) a ++ cSuffix)
              (lookup (signed, size, fromIntegral a) minBoundMacros)
  where -- Signed minimum value for each supported width, paired with its macro name.
        minBoundMacros :: [((Bool, Int, Integer), String)]
        minBoundMacros = [ ((True, sz, negate (2 ^ (sz - 1))), "INT" ++ show sz ++ "_MIN")
                         | sz <- [8, 16, 32, 64]
                         ]
        cSuffix = fromMaybe "" $ lookup (signed, size) [ ((False, 16), "U")
                                                       , ((False, 32), "UL")
                                                       , ((True,  32), "L")
                                                       , ((False, 64), "ULL")
                                                       , ((True,  64), "LL")
                                                       ]
-- | Show an unbounded 'Integer' in hex (no padding).
-- @shType@ appends @:: Integer@; @shPre@ adds the @0x@ prefix.
-- Negative values are printed as @-@ followed by the magnitude.
shexI :: Bool -> Bool -> Integer -> String
shexI shType shPre a
  | a < 0     = '-' : render (abs a)
  | otherwise = render a
  where render v = (if shPre then "0x" else "") ++ s16 v ++ (if shType then " :: Integer" else "")
-- | Show a fixed-width value in binary, zero-padded to @size@ digits.
--
-- * @shType@: append a @:: IntN@ / @:: WordN@ annotation.
-- * @shPre@: prefix the digits with @0b@.
--
-- Negative values are printed as a @-@ sign followed by the magnitude.
sbin :: (Show a, Integral a) => Bool -> Bool -> (Bool, Int) -> a -> String
sbin shType shPre (signed,size) a
  | a < 0     = '-' : render (s2 (abs (fromIntegral a :: Integer)))
  | otherwise = render (s2 a)
  where render bits = radix ++ pad size bits ++ typeTag
        radix   | shPre     = "0b"
                | otherwise = ""
        typeTag | shType    = " :: " ++ (if signed then "Int" else "Word") ++ show size
                | otherwise = ""
-- | Show an unbounded 'Integer' in binary (no padding).
-- @shType@ appends @:: Integer@; @shPre@ adds the @0b@ prefix.
-- Negative values are printed as @-@ followed by the magnitude.
sbinI :: Bool -> Bool -> Integer -> String
sbinI shType shPre a
  | a < 0     = '-' : render (abs a)
  | otherwise = render a
  where render v = (if shPre then "0b" else "") ++ s2 v ++ (if shType then " :: Integer" else "")
-- | Left-pad a string with @'0'@ up to length @l@. Strings already at least
-- @l@ long are returned unchanged.
pad :: Int -> String -> String
pad l s = drop (length s) (replicate l '0') ++ s
-- | Binary digits of a non-negative number (no prefix, no padding).
s2 :: (Show a, Integral a) => a -> String
s2 v = showIntAtBase 2 bitChar v ""
  where bitChar d = fromJust (lookup d (zip [0, 1] "01"))
-- | Lower-case hex digits of a non-negative number (no prefix, no padding).
s16 :: (Show a, Integral a) => a -> String
s16 = (`showHex` "")
-- | Parse a binary numeral, optionally prefixed with @0b@ and/or a leading
-- @-@. Calls 'error' if the remaining text is not entirely binary digits.
readBin :: Num a => String -> a
readBin str = case str of
                '-':rest -> negate (readBin rest)
                _        -> parsed
  where parsed = case readInt 2 (`elem` "01") (\c -> ord c - ord '0') digits of
                   [(v, "")] -> v
                   _         -> error $ "SBV.readBin: Cannot read a binary number from: " ++ show str
        digits | "0b" `isPrefixOf` str = drop 2 str
               | True                  = str
-- | Render a 'Float' as a C expression. NaN and infinities use the C
-- @NAN@ / @INFINITY@ macros cast to @float@; finite values get an @F@ suffix.
showCFloat :: Float -> String
showCFloat f
   | isNaN f      = asFloat "NAN"
   | isInfinite f = asFloat (if f < 0 then "(-INFINITY)" else "INFINITY")
   | otherwise    = show f ++ "F"
   where asFloat e = "((float) " ++ e ++ ")"
-- | Render a 'Double' as a C expression. NaN and infinities use the C
-- @NAN@ / @INFINITY@ macros cast to @double@; finite values use 'show'.
showCDouble :: Double -> String
showCDouble f
   | isNaN f      = asDouble "NAN"
   | isInfinite f = asDouble (if f < 0 then "(-INFINITY)" else "INFINITY")
   | otherwise    = show f
   where asDouble e = "((double) " ++ e ++ ")"
-- | Render a 'Float' as a Haskell expression. NaN and infinities become
-- type-annotated divisions (e.g. @((1/0) :: Float)@), since 'show' would not
-- produce valid source for them.
showHFloat :: Float -> String
showHFloat f
   | isNaN f      = annotate "0/0"
   | isInfinite f = annotate (if f < 0 then "-1/0" else "1/0")
   | otherwise    = show f
   where annotate e = "((" ++ e ++ ") :: Float)"
-- | Render a 'Double' as a Haskell expression. NaN and infinities become
-- type-annotated divisions (e.g. @((1/0) :: Double)@), since 'show' would not
-- produce valid source for them.
showHDouble :: Double -> String
showHDouble d
   | isNaN d      = annotate "0/0"
   | isInfinite d = annotate (if d < 0 then "-1/0" else "1/0")
   | otherwise    = show d
   where annotate e = "((" ++ e ++ ") :: Double)"
-- | Render a 'Float' as an SMT-Lib floating-point term with 8 exponent bits
-- and 24 significand bits.
--
-- NaN, the infinities and the signed zeros use the dedicated @(_ … 8 24)@
-- constants. Any other value is converted from its exact rational value
-- with @to_fp@ under the given rounding mode.
showSMTFloat :: RoundingMode -> Float -> String
showSMTFloat rm f = maybe finite special specialName
   where specialName
           | isNaN f             = Just "NaN"
           | isInfinite f, f < 0 = Just "-oo"
           | isInfinite f        = Just "+oo"
           | isNegativeZero f    = Just "-zero"
           | f == 0              = Just "+zero"
           | otherwise           = Nothing
         special s = "(_ " ++ s ++ " 8 24)"
         finite    = "((_ to_fp 8 24) " ++ smtRoundingMode rm ++ " " ++ toSMTLibRational (toRational f) ++ ")"
-- | Render a 'Double' as an SMT-Lib floating-point term with 11 exponent
-- bits and 53 significand bits.
--
-- NaN, the infinities and the signed zeros use the dedicated @(_ … 11 53)@
-- constants. Any other value is converted from its exact rational value
-- with @to_fp@ under the given rounding mode.
showSMTDouble :: RoundingMode -> Double -> String
showSMTDouble rm d = maybe finite special specialName
   where specialName
           | isNaN d             = Just "NaN"
           | isInfinite d, d < 0 = Just "-oo"
           | isInfinite d        = Just "+oo"
           | isNegativeZero d    = Just "-zero"
           | d == 0              = Just "+zero"
           | otherwise           = Nothing
         special s = "(_ " ++ s ++ " 11 53)"
         finite    = "((_ to_fp 11 53) " ++ smtRoundingMode rm ++ " " ++ toSMTLibRational (toRational d) ++ ")"
-- | Render a rational as an SMT-Lib real division @(/ n.0 d.0)@. Negative
-- values wrap the division of the magnitude in a unary minus, @(- (/ …))@.
toSMTLibRational :: Rational -> String
toSMTLibRational r
   | num < 0   = "(- " ++ frac (abs num) ++ ")"
   | otherwise = frac num
  where num    = numerator r
        frac m = "(/ " ++ show m ++ ".0 " ++ show (denominator r) ++ ".0)"
-- | SMT-Lib name of each IEEE-754 rounding mode.
smtRoundingMode :: RoundingMode -> String
smtRoundingMode m = case m of
  RoundNearestTiesToEven -> "roundNearestTiesToEven"
  RoundNearestTiesToAway -> "roundNearestTiesToAway"
  RoundTowardPositive    -> "roundTowardPositive"
  RoundTowardNegative    -> "roundTowardNegative"
  RoundTowardZero        -> "roundTowardZero"
-- | Render a concrete value as an SMT-Lib2 literal.
--
-- The rounding mode is used when emitting floating-point constants.
-- Uninterpreted values whose name equals the 'show' of some 'RoundingMode'
-- are translated to that mode's SMT-Lib name. Calls 'error' if the value's
-- kind and payload disagree.
cwToSMTLib :: RoundingMode -> CW -> String
cwToSMTLib rm x
  | isBoolean       x, CWInteger  w      <- cwVal x = if w == 0 then "false" else "true"
  | isUninterpreted x, CWUserSort (_, s) <- cwVal x = roundModeConvert s
  | isReal          x, CWAlgReal  r      <- cwVal x = algRealToSMTLib2 r
  | isFloat         x, CWFloat    f      <- cwVal x = showSMTFloat  rm f
  | isDouble        x, CWDouble   d      <- cwVal x = showSMTDouble rm d
  -- Unbounded integers: negatives are written with unary minus, e.g. (- 5).
  | not (isBounded x), CWInteger  w      <- cwVal x = if w >= 0 then show w else "(- " ++ show (abs w) ++ ")"
  | not (hasSign x)  , CWInteger  w      <- cwVal x = smtLibHex (intSizeOf x) w
  -- Signed bit-vectors: print the magnitude and wrap negatives in bvneg.
  -- The most negative n-bit value is -(2^(n-1)); it is written directly as
  -- its two's-complement bit pattern #b10...0 instead.
  | hasSign x        , CWInteger  w      <- cwVal x = if w == negate (2 ^ (intSizeOf x - 1))
                                                      then mkMinBound (intSizeOf x)
                                                      else negIf (w < 0) $ smtLibHex (intSizeOf x) (abs w)
  -- NOTE(review): characters are emitted as 8-bit literals; this assumes the
  -- character code fits in 8 bits (a larger code would yield a wider literal).
  | isChar x         , CWChar c          <- cwVal x = smtLibHex 8 (fromIntegral (ord c))
  | isString x       , CWString s        <- cwVal x = '\"' : stringToQFS s ++ "\""
  | isList x         , CWList xs         <- cwVal x = smtLibSeq (kindOf x) xs
  | True = error $ "SBV.cvtCW: Impossible happened: Kind/Value disagreement on: " ++ show (kindOf x, x)
  where
        -- Map an uninterpreted value to an SMT-Lib rounding-mode name if its
        -- name matches a 'RoundingMode' constructor; otherwise keep it as-is.
        roundModeConvert s = fromMaybe s (listToMaybe [smtRoundingMode m | m <- [minBound .. maxBound] :: [RoundingMode], show m == s])

        -- Bit-vector literal of width sz, zero-padded to the full width:
        -- hex when sz is a multiple of 4, binary otherwise.
        smtLibHex :: Int -> Integer -> String
        smtLibHex 1  v = "#b" ++ show v
        smtLibHex sz v
          | sz `mod` 4 == 0 = "#x" ++ pad (sz `div` 4) (showHex v "")
          | True            = "#b" ++ pad sz (showBin v "")
           where showBin = showIntAtBase 2 intToDigit

        -- Optionally wrap a bit-vector term in bvneg.
        negIf :: Bool -> String -> String
        negIf True  a = "(bvneg " ++ a ++ ")"
        negIf False a = a

        -- Sequence literal: seq.empty (annotated with the sequence sort) for
        -- an empty list, a single seq.unit for one element, seq.++ otherwise.
        smtLibSeq :: Kind -> [CWVal] -> String
        smtLibSeq k          [] = "(as seq.empty " ++ smtType k ++ ")"
        smtLibSeq (KList ek) xs = let mkSeq  [e]   = e
                                      mkSeq  es    = "(seq.++ " ++ unwords es ++ ")"
                                      mkUnit inner = "(seq.unit " ++ inner ++ ")"
                                  in mkSeq (mkUnit . cwToSMTLib rm . CW ek <$> xs)
        smtLibSeq k _ = error $ "SBV.cwToSMTLib: Impossible case (smtLibSeq), received kind: " ++ show k

        -- Bit pattern of the most negative i-bit signed value: a 1 followed
        -- by i-1 zeros.
        mkMinBound :: Int -> String
        mkMinBound i = "#b1" ++ replicate (i-1) '0'
-- | SMT-Lib text for a default ("zero") value of the given kind.
--
-- An uninterpreted sort with a non-empty list of constructor names yields the
-- first name; any other uninterpreted sort is an error. Every other kind uses
-- the SMT-Lib rendering of the constant built from @0@.
mkSkolemZero :: RoundingMode -> Kind -> String
mkSkolemZero rm k = case k of
  KUserSort _ (Right (f:_)) -> f
  KUserSort s _             -> error $ "SBV.mkSkolemZero: Unexpected uninterpreted sort: " ++ s
  _                         -> cwToSMTLib rm (mkConstCW k (0::Integer))