{-# LANGUAGE Rank2Types          #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Data.SBV.Core.Floating (
         IEEEFloating(..), IEEEFloatConvertable(..)
       , sFloatAsSWord32, sDoubleAsSWord64, sWord32AsSFloat, sWord64AsSDouble
       , blastSFloat, blastSDouble
       ) where
import qualified Data.Numbers.CrackNum as CN (wordToFloat, wordToDouble, floatToWord, doubleToWord)
import Data.Int            (Int8,  Int16,  Int32,  Int64)
import Data.Word           (Word8, Word16, Word32, Word64)
import Data.SBV.Core.Data
import Data.SBV.Core.Model
import Data.SBV.Core.AlgReals (isExactRational)
import Data.SBV.Utils.Boolean
import Data.SBV.Utils.Numeric
-- | Operations on symbolic IEEE floating-point values. Every method comes with
-- a default definition written in terms of the @lift*@ helpers of this module,
-- so instances can be declared with an empty body.
--
-- Methods that accept an 'SRoundingMode' forward it to the generated
-- expression; the remaining methods pass no rounding mode at all.
class (SymWord a, RealFloat a) => IEEEFloating a where
  -- | Absolute value.
  fpAbs :: SBV a -> SBV a

  -- | Negation.
  fpNeg :: SBV a -> SBV a

  -- | Addition under the given rounding mode.
  fpAdd :: SRoundingMode -> SBV a -> SBV a -> SBV a

  -- | Subtraction under the given rounding mode.
  fpSub :: SRoundingMode -> SBV a -> SBV a -> SBV a

  -- | Multiplication under the given rounding mode.
  fpMul :: SRoundingMode -> SBV a -> SBV a -> SBV a

  -- | Division under the given rounding mode.
  fpDiv :: SRoundingMode -> SBV a -> SBV a -> SBV a

  -- | Fused multiply-add under the given rounding mode. The default supplies
  -- no concrete evaluator, so the result is always built as a symbolic
  -- expression, even for literal operands.
  fpFMA :: SRoundingMode -> SBV a -> SBV a -> SBV a -> SBV a

  -- | Square root under the given rounding mode.
  fpSqrt :: SRoundingMode -> SBV a -> SBV a

  -- | Remainder. Takes no rounding mode; literal operands are evaluated with
  -- 'fpRemH'.
  fpRem :: SBV a -> SBV a -> SBV a

  -- | Round to an integral value (still of the floating-point type) under the
  -- given rounding mode.
  fpRoundToIntegral :: SRoundingMode -> SBV a -> SBV a

  -- | Minimum of two values. The default goes through 'liftMM', which refuses
  -- to constant-fold when the operands are zeros of opposite sign.
  fpMin :: SBV a -> SBV a -> SBV a

  -- | Maximum of two values. The default goes through 'liftMM', which refuses
  -- to constant-fold when the operands are zeros of opposite sign.
  fpMax :: SBV a -> SBV a -> SBV a

  -- | Object-level equality test; literal operands are compared with
  -- 'fpIsEqualObjectH'.
  -- NOTE(review): presumably this treats NaN as equal to itself and
  -- distinguishes +0 from -0 -- confirm against 'fpIsEqualObjectH'.
  fpIsEqualObject :: SBV a -> SBV a -> SBool

  -- | Is the value a normal number? Literals are decided by 'fpIsNormalizedH'.
  fpIsNormal :: SBV a -> SBool

  -- | Is the value subnormal (denormalized)?
  fpIsSubnormal :: SBV a -> SBool

  -- | Is the value zero, of either sign?
  fpIsZero :: SBV a -> SBool

  -- | Is the value an infinity?
  fpIsInfinite :: SBV a -> SBool

  -- | Is the value a NaN?
  fpIsNaN :: SBV a -> SBool

  -- | Is the value negative? For literals, negative zero counts as negative
  -- and NaN does not.
  fpIsNegative :: SBV a -> SBool

  -- | Is the value positive? For literals, positive zero counts as positive,
  -- while negative zero and NaN do not.
  fpIsPositive :: SBV a -> SBool

  -- | Is the value negative zero?
  fpIsNegativeZero :: SBV a -> SBool

  -- | Is the value positive zero?
  fpIsPositiveZero :: SBV a -> SBool

  -- | Is the value a "point", i.e. neither NaN nor infinite?
  fpIsPoint :: SBV a -> SBool

  -- Default definitions. Operations with a rounding mode bind it explicitly
  -- and wrap it in 'Just' for the lifting helper.
  fpAbs                = lift1  FP_Abs (Just abs) Nothing
  fpNeg                = lift1  FP_Neg (Just negate) Nothing
  fpAdd rm             = lift2  FP_Add (Just (+)) (Just rm)
  fpSub rm             = lift2  FP_Sub (Just (-)) (Just rm)
  fpMul rm             = lift2  FP_Mul (Just (*)) (Just rm)
  fpDiv rm             = lift2  FP_Div (Just (/)) (Just rm)
  fpFMA rm             = lift3  FP_FMA Nothing (Just rm)
  fpSqrt rm            = lift1  FP_Sqrt (Just sqrt) (Just rm)
  fpRem                = lift2  FP_Rem (Just fpRemH) Nothing
  fpRoundToIntegral rm = lift1  FP_RoundToIntegral (Just fpRoundToIntegralH) (Just rm)
  fpMin                = liftMM FP_Min (Just fpMinH) Nothing
  fpMax                = liftMM FP_Max (Just fpMaxH) Nothing
  fpIsEqualObject      = lift2B FP_ObjEqual (Just fpIsEqualObjectH) Nothing
  fpIsNormal           = lift1B FP_IsNormal fpIsNormalizedH
  fpIsSubnormal        = lift1B FP_IsSubnormal isDenormalized
  fpIsZero             = lift1B FP_IsZero (\v -> v == 0)
  fpIsInfinite         = lift1B FP_IsInfinite isInfinite
  fpIsNaN              = lift1B FP_IsNaN isNaN
  fpIsNegative         = lift1B FP_IsNegative (\v -> isNegativeZero v || v < 0)
  fpIsPositive         = lift1B FP_IsPositive (\v -> not (isNegativeZero v) && v >= 0)
  fpIsNegativeZero v   = fpIsZero v &&& fpIsNegative v
  fpIsPositiveZero v   = fpIsZero v &&& fpIsPositive v
  fpIsPoint        v   = bnot (fpIsNaN v ||| fpIsInfinite v)
-- | Single-precision instance; every method uses the class default.
instance IEEEFloating Float
-- | Double-precision instance; every method uses the class default.
instance IEEEFloating Double
-- | Types whose symbolic values can be converted to and from 'SFloat' and
-- 'SDouble'. Every conversion takes the rounding mode to use.
class IEEEFloatConvertable a where
  -- | Convert a symbolic float to a symbolic value of type @a@.
  fromSFloat  :: SRoundingMode -> SFloat  -> SBV a
  -- | Convert a symbolic value of type @a@ to a symbolic float.
  toSFloat    :: SRoundingMode -> SBV a   -> SFloat
  -- | Convert a symbolic double to a symbolic value of type @a@.
  fromSDouble :: SRoundingMode -> SDouble -> SBV a
  -- | Convert a symbolic value of type @a@ to a symbolic double.
  toSDouble   :: SRoundingMode -> SBV a   -> SDouble
-- | Shared worker behind the 'IEEEFloatConvertable' instances.
--
-- * The first argument optionally restricts which literal inputs may be
--   converted concretely; 'Nothing' accepts every literal.
-- * The second argument optionally supplies a symbolic validity test; when it
--   is present and the test fails, the result is the literal @0@.
-- * The third argument is the concrete conversion, used only when the input is
--   a literal, the rounding mode is the literal 'RoundNearestTiesToEven', and
--   the concrete restriction (if any) accepts the input.
--
-- In every other case an 'FP_Cast' expression carrying the rounding mode is
-- generated.
genericFPConverter
  :: forall a r. (SymWord a, HasKind r, SymWord r, Num r)
  => Maybe (a -> Bool) -> Maybe (SBV a -> SBool) -> (a -> r) -> SRoundingMode -> SBV a -> SBV r
genericFPConverter mbConcreteOK mbSymbolicOK converter rm f =
  case (unliteral f, unliteral rm) of
    (Just w, Just RoundNearestTiesToEven)
      | concreteOK w -> literal (converter w)
    _                -> guarded
  where
    -- A missing concrete restriction accepts everything.
    concreteOK w = maybe True ($ w) mbConcreteOK

    -- Wrap the cast in the symbolic validity test, when one was given.
    guarded = case mbSymbolicOK of
                Just symCheck -> ite (symCheck f) symbolic (literal 0)
                Nothing       -> symbolic

    symbolic = SBV (SVal kTo (Right (cache build)))

    kFrom = kindOf f
    kTo   = kindOf (undefined :: r)

    -- The rounding mode is registered before the operand.
    build st = do
      msw <- sbvToSW st rm
      xsw <- sbvToSW st f
      newExpr st kTo (SBVApp (IEEEFP (FP_Cast kFrom kTo msw)) [xsw])
-- | Symbolic validity test handed to 'genericFPConverter' by the conversions
-- out of floating-point: the source must be neither NaN nor infinite.
ptCheck :: IEEEFloating a => Maybe (SBV a -> SBool)
ptCheck = pure fpIsPoint
-- | Conversions between @SBV Int8@ and the symbolic floating-point types.
instance IEEEFloatConvertable Int8 where
  -- Out of floating-point: literals go through 'fpRound0' at 'Integer'
  -- precision and are then narrowed; non-point inputs are guarded by 'ptCheck'.
  fromSFloat  = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Float)  :: Integer))
  fromSDouble = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Double) :: Integer))
  -- Into floating-point: literals are converted via an exact rational.
  toSFloat    = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
  toSDouble   = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
-- | Conversions between @SBV Int16@ and the symbolic floating-point types.
instance IEEEFloatConvertable Int16 where
  -- Out of floating-point: literals go through 'fpRound0' at 'Integer'
  -- precision and are then narrowed; non-point inputs are guarded by 'ptCheck'.
  fromSFloat  = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Float)  :: Integer))
  fromSDouble = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Double) :: Integer))
  -- Into floating-point: literals are converted via an exact rational.
  toSFloat    = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
  toSDouble   = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
-- | Conversions between @SBV Int32@ and the symbolic floating-point types.
instance IEEEFloatConvertable Int32 where
  -- Out of floating-point: literals go through 'fpRound0' at 'Integer'
  -- precision and are then narrowed; non-point inputs are guarded by 'ptCheck'.
  fromSFloat  = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Float)  :: Integer))
  fromSDouble = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Double) :: Integer))
  -- Into floating-point: literals are converted via an exact rational.
  toSFloat    = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
  toSDouble   = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
-- | Conversions between @SBV Int64@ and the symbolic floating-point types.
instance IEEEFloatConvertable Int64 where
  -- Out of floating-point: literals go through 'fpRound0' at 'Integer'
  -- precision and are then narrowed; non-point inputs are guarded by 'ptCheck'.
  fromSFloat  = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Float)  :: Integer))
  fromSDouble = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Double) :: Integer))
  -- Into floating-point: literals are converted via an exact rational.
  toSFloat    = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
  toSDouble   = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
-- | Conversions between @SBV Word8@ and the symbolic floating-point types.
instance IEEEFloatConvertable Word8 where
  -- Out of floating-point: literals go through 'fpRound0' at 'Integer'
  -- precision and are then narrowed; non-point inputs are guarded by 'ptCheck'.
  fromSFloat  = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Float)  :: Integer))
  fromSDouble = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Double) :: Integer))
  -- Into floating-point: literals are converted via an exact rational.
  toSFloat    = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
  toSDouble   = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
-- | Conversions between @SBV Word16@ and the symbolic floating-point types.
instance IEEEFloatConvertable Word16 where
  -- Out of floating-point: literals go through 'fpRound0' at 'Integer'
  -- precision and are then narrowed; non-point inputs are guarded by 'ptCheck'.
  fromSFloat  = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Float)  :: Integer))
  fromSDouble = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Double) :: Integer))
  -- Into floating-point: literals are converted via an exact rational.
  toSFloat    = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
  toSDouble   = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
-- | Conversions between @SBV Word32@ and the symbolic floating-point types.
instance IEEEFloatConvertable Word32 where
  -- Out of floating-point: literals go through 'fpRound0' at 'Integer'
  -- precision and are then narrowed; non-point inputs are guarded by 'ptCheck'.
  fromSFloat  = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Float)  :: Integer))
  fromSDouble = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Double) :: Integer))
  -- Into floating-point: literals are converted via an exact rational.
  toSFloat    = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
  toSDouble   = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
-- | Conversions between @SBV Word64@ and the symbolic floating-point types.
instance IEEEFloatConvertable Word64 where
  -- Out of floating-point: literals go through 'fpRound0' at 'Integer'
  -- precision and are then narrowed; non-point inputs are guarded by 'ptCheck'.
  fromSFloat  = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Float)  :: Integer))
  fromSDouble = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Double) :: Integer))
  -- Into floating-point: literals are converted via an exact rational.
  toSFloat    = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
  toSDouble   = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
-- | Conversions for 'Float'. Float-to-float is the identity and ignores the
-- rounding mode; conversions involving 'SDouble' go through 'fp2fp' for
-- literals and a cast expression otherwise.
instance IEEEFloatConvertable Float where
  fromSFloat _ = id
  toSFloat   _ = id
  fromSDouble  = genericFPConverter Nothing Nothing fp2fp
  toSDouble    = genericFPConverter Nothing Nothing fp2fp
-- | Conversions for 'Double'. Double-to-double is the identity and ignores the
-- rounding mode; conversions involving 'SFloat' go through 'fp2fp' for
-- literals and a cast expression otherwise.
instance IEEEFloatConvertable Double where
  fromSFloat    = genericFPConverter Nothing Nothing fp2fp
  toSFloat      = genericFPConverter Nothing Nothing fp2fp
  fromSDouble _ = id
  toSDouble   _ = id
-- | Conversions between @SBV Integer@ and the symbolic floating-point types.
instance IEEEFloatConvertable Integer where
  -- Out of floating-point: literals go through 'fpRound0' at 'Integer'
  -- precision; non-point inputs are guarded by 'ptCheck'.
  fromSFloat  = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Float)  :: Integer))
  fromSDouble = genericFPConverter Nothing ptCheck (\v -> fromIntegral (fpRound0 (v :: Double) :: Integer))
  -- Into floating-point: literals are converted via an exact rational.
  toSFloat    = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
  toSDouble   = genericFPConverter Nothing Nothing (\i -> fromRational (fromIntegral i))
-- | Conversions between @SBV AlgReal@ and the symbolic floating-point types.
instance IEEEFloatConvertable AlgReal where
  -- Out of floating-point: literals are converted through 'fpRatio0';
  -- non-point inputs are guarded by 'ptCheck'.
  fromSFloat  = genericFPConverter Nothing ptCheck (\v -> fromRational (fpRatio0 v))
  fromSDouble = genericFPConverter Nothing ptCheck (\v -> fromRational (fpRatio0 v))
  -- Into floating-point: only literals accepted by 'isExactRational' are
  -- converted concretely; all others become a cast expression.
  toSFloat    = genericFPConverter (Just isExactRational) Nothing (\r -> fromRational (toRational r))
  toSDouble   = genericFPConverter (Just isExactRational) Nothing (\r -> fromRational (toRational r))
-- | Try to constant-fold a unary floating-point operation.
--
-- Folding happens only when a concrete implementation is supplied, the operand
-- is a literal, and the rounding mode is either absent (the operation takes
-- none) or the literal 'RoundNearestTiesToEven'. Returns 'Nothing' otherwise.
--
-- A rounding mode that is present but symbolic is /not/ folded. Collapsing
-- "no rounding mode" and "unknown rounding mode" into one case would bake the
-- round-nearest-ties-to-even result into an expression whose rounding mode is
-- still undetermined.
concEval1 :: SymWord a => Maybe (a -> a) -> Maybe SRoundingMode -> SBV a -> Maybe (SBV a)
concEval1 mbOp mbRm a = do op <- mbOp
                           v  <- unliteral a
                           if rmAllowsFolding
                              then Just (literal (op v))
                              else Nothing
  where rmAllowsFolding = case mbRm of
                            Nothing -> True   -- operation has no rounding mode
                            Just rm -> case unliteral rm of
                                         Just RoundNearestTiesToEven -> True
                                         _                           -> False  -- symbolic, or another literal mode
-- | Try to constant-fold a binary floating-point operation.
--
-- Folding happens only when a concrete implementation is supplied, both
-- operands are literals, and the rounding mode is either absent (the operation
-- takes none) or the literal 'RoundNearestTiesToEven'. Returns 'Nothing'
-- otherwise.
--
-- A rounding mode that is present but symbolic is /not/ folded. Collapsing
-- "no rounding mode" and "unknown rounding mode" into one case would bake the
-- round-nearest-ties-to-even result into an expression whose rounding mode is
-- still undetermined.
concEval2 :: SymWord a => Maybe (a -> a -> a) -> Maybe SRoundingMode -> SBV a -> SBV a -> Maybe (SBV a)
concEval2 mbOp mbRm a b  = do op <- mbOp
                              v1 <- unliteral a
                              v2 <- unliteral b
                              if rmAllowsFolding
                                 then Just (literal (v1 `op` v2))
                                 else Nothing
  where rmAllowsFolding = case mbRm of
                            Nothing -> True   -- operation has no rounding mode
                            Just rm -> case unliteral rm of
                                         Just RoundNearestTiesToEven -> True
                                         _                           -> False  -- symbolic, or another literal mode
-- | Try to constant-fold a binary floating-point predicate.
--
-- Folding happens only when a concrete implementation is supplied, both
-- operands are literals, and the rounding mode is either absent (the predicate
-- takes none) or the literal 'RoundNearestTiesToEven'. Returns 'Nothing'
-- otherwise.
--
-- A rounding mode that is present but symbolic is /not/ folded, so that
-- "no rounding mode" and "unknown rounding mode" are never confused.
concEval2B :: SymWord a => Maybe (a -> a -> Bool) -> Maybe SRoundingMode -> SBV a -> SBV a -> Maybe SBool
concEval2B mbOp mbRm a b  = do op <- mbOp
                               v1 <- unliteral a
                               v2 <- unliteral b
                               if rmAllowsFolding
                                  then Just (literal (v1 `op` v2))
                                  else Nothing
  where rmAllowsFolding = case mbRm of
                            Nothing -> True   -- predicate has no rounding mode
                            Just rm -> case unliteral rm of
                                         Just RoundNearestTiesToEven -> True
                                         _                           -> False  -- symbolic, or another literal mode
-- | Try to constant-fold a ternary floating-point operation.
--
-- Folding happens only when a concrete implementation is supplied, all three
-- operands are literals, and the rounding mode is either absent (the operation
-- takes none) or the literal 'RoundNearestTiesToEven'. Returns 'Nothing'
-- otherwise.
--
-- A rounding mode that is present but symbolic is /not/ folded. Collapsing
-- "no rounding mode" and "unknown rounding mode" into one case would bake the
-- round-nearest-ties-to-even result into an expression whose rounding mode is
-- still undetermined.
concEval3 :: SymWord a => Maybe (a -> a -> a -> a) -> Maybe SRoundingMode -> SBV a -> SBV a -> SBV a -> Maybe (SBV a)
concEval3 mbOp mbRm a b c = do op <- mbOp
                               v1 <- unliteral a
                               v2 <- unliteral b
                               v3 <- unliteral c
                               if rmAllowsFolding
                                  then Just (literal (op v1 v2 v3))
                                  else Nothing
  where rmAllowsFolding = case mbRm of
                            Nothing -> True   -- operation has no rounding mode
                            Just rm -> case unliteral rm of
                                         Just RoundNearestTiesToEven -> True
                                         _                           -> False  -- symbolic, or another literal mode
-- | Prepend the rounding mode, when there is one, to an argument list. Without
-- a rounding mode the list is returned untouched and the state is not used.
addRM :: State -> Maybe SRoundingMode -> [SW] -> IO [SW]
addRM st mbRm as =
  case mbRm of
    Nothing -> return as
    Just rm -> fmap (: as) (sbvToSW st rm)
-- | Lift a unary floating-point operation to symbolic values. The operation is
-- constant-folded when 'concEval1' allows it; otherwise an 'IEEEFP' expression
-- of the operand's kind is generated, with the rounding mode (if any) placed
-- in front of the operand.
lift1 :: SymWord a => FPOp -> Maybe (a -> a) -> Maybe SRoundingMode -> SBV a -> SBV a
lift1 w mbOp mbRm a =
  case concEval1 mbOp mbRm a of
    Just folded -> folded
    Nothing     -> SBV (SVal resKind (Right (cache build)))
  where
    resKind = kindOf a

    build st = do
      operand <- sbvToSW st a
      swArgs  <- addRM st mbRm [operand]
      newExpr st resKind (SBVApp (IEEEFP w) swArgs)
-- | Lift a unary floating-point predicate to symbolic values. A literal
-- operand is decided directly by the concrete predicate; anything else
-- becomes a boolean 'IEEEFP' expression. No rounding mode is involved.
lift1B :: SymWord a => FPOp -> (a -> Bool) -> SBV a -> SBool
lift1B w f a =
  case unliteral a of
    Just v  -> literal (f v)
    Nothing -> SBV (SVal KBool (Right (cache build)))
  where
    build st = do
      operand <- sbvToSW st a
      newExpr st KBool (SBVApp (IEEEFP w) [operand])
-- | Lift a binary floating-point operation to symbolic values. The operation
-- is constant-folded when 'concEval2' allows it; otherwise an 'IEEEFP'
-- expression of the first operand's kind is generated, with the rounding mode
-- (if any) placed in front of the operands.
lift2 :: SymWord a => FPOp -> Maybe (a -> a -> a) -> Maybe SRoundingMode -> SBV a -> SBV a -> SBV a
lift2 w mbOp mbRm a b =
  case concEval2 mbOp mbRm a b of
    Just folded -> folded
    Nothing     -> SBV (SVal resKind (Right (cache build)))
  where
    resKind = kindOf a

    build st = do
      lhs    <- sbvToSW st a
      rhs    <- sbvToSW st b
      swArgs <- addRM st mbRm [lhs, rhs]
      newExpr st resKind (SBVApp (IEEEFP w) swArgs)
-- | Lift a min/max style operation to symbolic values. Works like 'lift2',
-- except that literal operands are never constant-folded when one is positive
-- zero and the other is negative zero; that combination always produces a
-- symbolic expression.
liftMM :: (SymWord a, RealFloat a) => FPOp -> Maybe (a -> a -> a) -> Maybe SRoundingMode -> SBV a -> SBV a -> SBV a
liftMM w mbOp mbRm a b =
  case (unliteral a, unliteral b) of
    (Just v1, Just v2)
      | not (oppositeZeros v1 v2)
      , Just folded <- concEval2 mbOp mbRm a b -> folded
    _                                          -> SBV (SVal resKind (Right (cache build)))
  where
    negZero v = isNegativeZero v
    posZero v = v == 0 && not (isNegativeZero v)

    -- True exactly when the two values are zeros carrying different signs.
    oppositeZeros x y = (negZero x && posZero y) || (posZero x && negZero y)

    resKind = kindOf a

    build st = do
      lhs    <- sbvToSW st a
      rhs    <- sbvToSW st b
      swArgs <- addRM st mbRm [lhs, rhs]
      newExpr st resKind (SBVApp (IEEEFP w) swArgs)
-- | Lift a binary floating-point predicate to symbolic values. The predicate
-- is constant-folded when 'concEval2B' allows it; otherwise a boolean 'IEEEFP'
-- expression is generated, with the rounding mode (if any) placed in front of
-- the operands.
lift2B :: SymWord a => FPOp -> Maybe (a -> a -> Bool) -> Maybe SRoundingMode -> SBV a -> SBV a -> SBool
lift2B w mbOp mbRm a b =
  case concEval2B mbOp mbRm a b of
    Just folded -> folded
    Nothing     -> SBV (SVal KBool (Right (cache build)))
  where
    build st = do
      lhs    <- sbvToSW st a
      rhs    <- sbvToSW st b
      swArgs <- addRM st mbRm [lhs, rhs]
      newExpr st KBool (SBVApp (IEEEFP w) swArgs)
-- | Lift a ternary floating-point operation to symbolic values. The operation
-- is constant-folded when 'concEval3' allows it; otherwise an 'IEEEFP'
-- expression of the first operand's kind is generated, with the rounding mode
-- (if any) placed in front of the operands.
lift3 :: SymWord a => FPOp -> Maybe (a -> a -> a -> a) -> Maybe SRoundingMode -> SBV a -> SBV a -> SBV a -> SBV a
lift3 w mbOp mbRm a b c =
  case concEval3 mbOp mbRm a b c of
    Just folded -> folded
    Nothing     -> SBV (SVal resKind (Right (cache build)))
  where
    resKind = kindOf a

    build st = do
      arg1   <- sbvToSW st a
      arg2   <- sbvToSW st b
      arg3   <- sbvToSW st c
      swArgs <- addRM st mbRm [arg1, arg2, arg3]
      newExpr st resKind (SBVApp (IEEEFP w) swArgs)
-- | Reinterpret a symbolic float as the 32-bit word holding its bit pattern.
--
-- A literal that is not NaN is converted directly. A literal NaN is not
-- folded and takes the symbolic route (NOTE(review): presumably because NaN
-- has no single bit-level representation -- confirm).
--
-- On the symbolic route, code-generation mode emits a direct float-to-word
-- reinterpretation. Otherwise a fresh internal word variable is created,
-- reinterpreted as a float, and constrained to be object-equal to the input;
-- that variable is the result.
sFloatAsSWord32 :: SFloat -> SWord32
sFloatAsSWord32 fVal =
  case unliteral fVal of
    Just f | not (isNaN f) -> literal (CN.floatToWord f)
    _                      -> SBV (SVal w32 (Right (cache build)))
  where
    w32 = KBounded False 32

    build st = do
      inCodeGen <- isCodeGenMode st
      if inCodeGen then direct st else viaConstraint st

    -- Code generation: reinterpret the float itself.
    direct st = do
      fsw <- sbvToSW st fVal
      newExpr st w32 (SBVApp (IEEEFP (FP_Reinterpret KFloat w32)) [fsw])

    -- Otherwise: pick a word whose float reading equals the input.
    viaConstraint st = do
      bits    <- internalVariable st w32
      asFloat <- newExpr st KFloat (SBVApp (IEEEFP (FP_Reinterpret w32 KFloat)) [bits])
      let rebuilt = SBV (SVal KFloat (Right (cache (\_ -> return asFloat))))
      internalConstraint st False [] (unSBV (fVal `fpIsEqualObject` rebuilt))
      return bits
-- | Reinterpret a symbolic double as the 64-bit word holding its bit pattern.
--
-- A literal that is not NaN is converted directly. A literal NaN is not
-- folded and takes the symbolic route (NOTE(review): presumably because NaN
-- has no single bit-level representation -- confirm).
--
-- On the symbolic route, code-generation mode emits a direct double-to-word
-- reinterpretation. Otherwise a fresh internal word variable is created,
-- reinterpreted as a double, and constrained to be object-equal to the input;
-- that variable is the result.
sDoubleAsSWord64 :: SDouble -> SWord64
sDoubleAsSWord64 fVal =
  case unliteral fVal of
    Just f | not (isNaN f) -> literal (CN.doubleToWord f)
    _                      -> SBV (SVal w64 (Right (cache build)))
  where
    w64 = KBounded False 64

    build st = do
      inCodeGen <- isCodeGenMode st
      if inCodeGen then direct st else viaConstraint st

    -- Code generation: reinterpret the double itself.
    direct st = do
      fsw <- sbvToSW st fVal
      newExpr st w64 (SBVApp (IEEEFP (FP_Reinterpret KDouble w64)) [fsw])

    -- Otherwise: pick a word whose double reading equals the input.
    viaConstraint st = do
      bits     <- internalVariable st w64
      asDouble <- newExpr st KDouble (SBVApp (IEEEFP (FP_Reinterpret w64 KDouble)) [bits])
      let rebuilt = SBV (SVal KDouble (Right (cache (\_ -> return asDouble))))
      internalConstraint st False [] (unSBV (fVal `fpIsEqualObject` rebuilt))
      return bits
-- | Split a symbolic float into three bit groups taken from its 32-bit word
-- form (see 'sFloatAsSWord32'): bit 31, then bits 30 down to 23, then bits 22
-- down to 0. In the single-precision layout these are the sign, the 8-bit
-- exponent and the 23-bit significand field.
blastSFloat :: SFloat -> (SBool, [SBool], [SBool])
blastSFloat fVal = (signBit, expoBits, fracBits)
  where
    w        = sFloatAsSWord32 fVal
    signBit  = sTestBit w 31
    expoBits = sExtractBits w [30, 29 .. 23]
    fracBits = sExtractBits w [22, 21 .. 0]
-- | Split a symbolic double into three bit groups taken from its 64-bit word
-- form (see 'sDoubleAsSWord64'): bit 63, then bits 62 down to 52, then bits 51
-- down to 0. In the double-precision layout these are the sign, the 11-bit
-- exponent and the 52-bit significand field.
blastSDouble :: SDouble -> (SBool, [SBool], [SBool])
blastSDouble dVal = (signBit, expoBits, fracBits)
  where
    w        = sDoubleAsSWord64 dVal
    signBit  = sTestBit w 63
    expoBits = sExtractBits w [62, 61 .. 52]
    fracBits = sExtractBits w [51, 50 .. 0]
-- | Reinterpret the bits of a symbolic 32-bit word as a float. Literals are
-- converted directly with 'CN.wordToFloat'; otherwise a word-to-float
-- reinterpretation expression is generated.
sWord32AsSFloat :: SWord32 -> SFloat
sWord32AsSFloat fVal =
  case unliteral fVal of
    Just bits -> literal (CN.wordToFloat bits)
    Nothing   -> SBV (SVal KFloat (Right (cache build)))
  where
    build st = do
      wsw <- sbvToSW st fVal
      newExpr st KFloat (SBVApp (IEEEFP (FP_Reinterpret (kindOf fVal) KFloat)) [wsw])
-- | Reinterpret the bits of a symbolic 64-bit word as a double. Literals are
-- converted directly with 'CN.wordToDouble'; otherwise a word-to-double
-- reinterpretation expression is generated.
sWord64AsSDouble :: SWord64 -> SDouble
sWord64AsSDouble dVal =
  case unliteral dVal of
    Just bits -> literal (CN.wordToDouble bits)
    Nothing   -> SBV (SVal KDouble (Right (cache build)))
  where
    build st = do
      wsw <- sbvToSW st dVal
      newExpr st KDouble (SBVApp (IEEEFP (FP_Reinterpret (kindOf dVal) KDouble)) [wsw])
{-# ANN module ("HLint: ignore Reduce duplication" :: String) #-}