-----------------------------------------------------------------------------
-- |
-- Module    : Data.SBV.Utils.Numeric
-- Copyright : (c) Levent Erkok
-- License   : BSD3
-- Maintainer: erkokl@gmail.com
-- Stability : experimental
--
-- Various number-related utilities
-----------------------------------------------------------------------------

{-# LANGUAGE FlexibleContexts #-}

{-# OPTIONS_GHC -Wall -Werror #-}

module Data.SBV.Utils.Numeric (
           fpMaxH, fpMinH, fp2fp, fpRemH, fpRoundToIntegralH, fpIsEqualObjectH, fpCompareObjectH, fpIsNormalizedH
         , floatToWord, wordToFloat, doubleToWord, wordToDouble
         , RoundingMode(..), smtRoundingMode
         ) where

import Data.Word
import Data.Array.ST     (newArray, readArray, MArray, STUArray)
import Data.Array.Unsafe (castSTUArray)
import GHC.ST            (runST, ST)

import Test.QuickCheck  (Arbitrary(..), elements)

-- | SMTLib-compliant maximum of two floating-point values, codifying what Z3
-- (SMTLib) implements for @fp.max@. Neither Haskell's 'max' nor the hardware
-- agree with SMTLib here. See:
--      <https://gitlab.haskell.org/ghc/ghc/-/issues/10378>
--      <http://github.com/Z3Prover/z3/issues/68>
--
-- * Difference from Haskell: if either argument is NaN, the other argument is
--   returned, so the result is NaN only when both arguments are NaN.
--
-- * Difference from x86: given +0 and -0, x86 returns the second argument,
--   while SMTLib is non-deterministic. This function calls 'error' on that input.
fpMaxH :: RealFloat a => a -> a -> a
fpMaxH x y
  | isNaN x    = y
  | isNaN y    = x
  | mixedZeros = error "fpMaxH: Called with alternating-sign 0's. Not supported"
  | otherwise  = if x > y then x else y
  where -- +0 is a zero that is not negative
        isPosZero v = v == 0 && not (isNegativeZero v)
        -- One argument is -0 and the other is +0, in either order
        mixedZeros  = (isNegativeZero x && isPosZero y) || (isNegativeZero y && isPosZero x)

-- | SMTLib-compliant definition for 'Data.SBV.fpMin'. See the comments for
-- 'Data.SBV.fpMax'.
--
-- If either argument is NaN, the other argument is returned. Calling this with
-- one -0 and one +0 (in either order) raises an 'error', since SMTLib leaves
-- that case unspecified.
fpMinH :: RealFloat a => a -> a -> a
fpMinH x y
  | isNaN x           = y
  | isNaN y           = x
  | oppositeZeros x y = error "fpMinH: Called with alternating-sign 0's. Not supported"
  | x < y             = x
  | otherwise         = y
  where -- True when one argument is -0 and the other is +0
        oppositeZeros u v = (isNegativeZero u && isPosZero v) || (isNegativeZero v && isPosZero u)
        -- +0 is a zero that is not negative
        isPosZero w       = w == 0 && not (isNegativeZero w)

-- | Convert between floating-point types (e.g., 'Double' to 'Float' and back).
-- This is essentially @fromRational . toRational@, except that NaN, the
-- infinities, and -0 are mapped explicitly. These values cannot be carried
-- faithfully through a 'Rational'.
fp2fp :: (RealFloat a, RealFloat b) => a -> b
fp2fp x
  | isNaN x          = nan
  | isInfinite x     = if x < 0 then negate inf else inf
  | isNegativeZero x = negate 0
  | True             = fromRational (toRational x)
  where inf = 1 / 0
        nan = 0 / 0

-- | Compute the IEEE-754 floating-point remainder of @x@ divided by @y@, i.e.,
-- @x - n * y@, where @n@ is the integer nearest to @x / y@ (ties go to even).
-- There are strict rules around 0's, infinities, and NaN's:
--
--   * @x@ infinite or NaN: the result is NaN
--
--   * @y@ zero or NaN: the result is NaN
--
--   * @y@ infinite (and @x@ finite): the result is @x@
--
--   * a zero result carries the sign of @x@
--
-- The remainder is computed exactly over 'Rational' and converted to the
-- floating-point type only once, at the end. IEEE-754 guarantees the true
-- remainder is representable, so that conversion is exact. Rounding @n * y@ to
-- the float type first and then subtracting is not exact: for example,
-- @2^53 `rem` 3@ in 'Double' would come out as 0 instead of -1.
fpRemH :: RealFloat a => a -> a -> a
fpRemH x y
  | isInfinite x || isNaN x = 0 / 0
  | y == 0       || isNaN y = 0 / 0
  | isInfinite y            = x
  | True                    = pSign (fromRational (rx - fromInteger d * ry))
  where rx, ry :: Rational
        rx = toRational x
        ry = toRational y
        -- Quotient rounded to the nearest integer. 'round' breaks ties to even,
        -- which is what the IEEE-754 remainder requires.
        d :: Integer
        d = round (rx / ry)
        -- If the result is 0, make sure we preserve the sign of x
        pSign r
          | r == 0 = if x < 0 || isNegativeZero x then -0.0 else 0.0
          | True   = r

-- | Round a float to the nearest integral value representable in that type,
-- breaking ties to even (Haskell's 'round').
--
-- NaN, infinities, and zeros are returned unchanged. If a non-zero input
-- rounds to zero, the zero keeps the sign of the input.
fpRoundToIntegralH :: RealFloat a => a -> a
fpRoundToIntegralH x
  | isNaN x || x == 0 || isInfinite x = x
  | nearest /= 0                       = fromInteger nearest
  | x < 0 || isNegativeZero x          = -0.0
  | True                               = 0.0
  where nearest :: Integer
        nearest = round x

-- | Check whether two floats are the exact same object: +0 and -0 are
-- considered different, and NaN is considered equal to NaN. Otherwise, this
-- is ordinary '=='.
fpIsEqualObjectH :: RealFloat a => a -> a -> Bool
fpIsEqualObjectH a b
  | eitherNaN  = isNaN a && isNaN b
  | eitherNeg0 = isNegativeZero a && isNegativeZero b
  | True       = a == b
  where eitherNaN  = isNaN a || isNaN b
        eitherNeg0 = isNegativeZero a || isNegativeZero b

-- | A total ordering for floats that avoids the +0/-0/NaN issues of 'compare'.
-- This is essentially used for indexing into a map, so it needs to be total.
-- The order we pick is:
--
-- >    NaN -oo -0 +0 +oo
--
-- Two values that are the same object (see 'fpIsEqualObjectH') compare 'EQ'.
-- The placement of NaN here is questionable, but immaterial.
fpCompareObjectH :: RealFloat a => a -> a -> Ordering
fpCompareObjectH a b
  | fpIsEqualObjectH a b                = EQ
  | isNaN a || negZeroAgainstZero a b   = LT
  | isNaN b || negZeroAgainstZero b a   = GT
  | True                                = compare a b
  where -- u is -0 while v is a zero (by the first guard, then necessarily +0)
        negZeroAgainstZero u v = isNegativeZero u && v == 0

-- | Check whether a number is normal: finite, non-zero, not NaN, and not
-- denormalized. Because +0/-0 (as well as infinities and NaN) are not normal,
-- this is not simply the negation of 'isDenormalized'.
fpIsNormalizedH :: RealFloat a => a -> Bool
fpIsNormalizedH x = not (any ($ x) [isDenormalized, isInfinite, isNaN, (== 0)])

-------------------------------------------------------------------------
-- Reinterpreting float/double as word32/64 and back. Here, we use the
-- definitions from the reinterpret-cast package:
--
--     http://hackage.haskell.org/package/reinterpret-cast
--
-- We copy these definitions rather than depending on that package, to keep
-- dependencies minimal and avoid any FFI requirements.
-------------------------------------------------------------------------
-- | Reinterpret-casts a `Float` to a `Word32`: the result has the same bit
-- pattern as the input; no numeric conversion is performed.
floatToWord :: Float -> Word32
floatToWord f = runST $ cast f
{-# INLINEABLE floatToWord #-}

-- | Reinterpret-casts a `Word32` to a `Float`: the bits of the word are read
-- back as a float; no numeric conversion is performed.
wordToFloat :: Word32 -> Float
wordToFloat w = runST $ cast w
{-# INLINEABLE wordToFloat #-}

-- | Reinterpret-casts a `Double` to a `Word64`: the result has the same bit
-- pattern as the input; no numeric conversion is performed.
doubleToWord :: Double -> Word64
doubleToWord d = runST $ cast d
{-# INLINEABLE doubleToWord #-}

-- | Reinterpret-casts a `Word64` to a `Double`: the bits of the word are read
-- back as a double; no numeric conversion is performed.
wordToDouble :: Word64 -> Double
wordToDouble w = runST $ cast w
{-# INLINEABLE wordToDouble #-}

{-# INLINE cast #-}
-- | Reinterpret the bits of a value of type @a@ as type @b@. This stores the
-- value in a one-element unboxed 'STUArray', casts the array's element type
-- with 'castSTUArray', and reads the single element back.
-- NOTE(review): only meaningful when @a@ and @b@ have the same size, as in the
-- Float/Word32 and Double/Word64 pairs used in this module.
cast :: (MArray (STUArray s) a (ST s), MArray (STUArray s) b (ST s)) => a -> ST s b
cast x = do
  cell  <- newArray (0 :: Int, 0) x
  cell' <- castSTUArray cell
  readArray cell' 0

-- | Rounding mode to be used for the IEEE floating-point operations.
-- Note that Haskell's default is 'RoundNearestTiesToEven'. If you pick a
-- different rounding mode, the counter-examples you get may not match what
-- you observe when evaluating in Haskell.
data RoundingMode = RoundNearestTiesToEven  -- ^ Round to nearest representable floating point value.
                                            -- If precisely at half-way, pick the even number.
                                            -- (In this context, /even/ means the lowest-order bit is zero.)
                  | RoundNearestTiesToAway  -- ^ Round to nearest representable floating point value.
                                            -- If precisely at half-way, pick the number further away from 0.
                                            -- (That is, for positive values, pick the greater; for negative values, pick the smaller.)
                  | RoundTowardPositive     -- ^ Round towards positive infinity. (Also known as rounding-up or ceiling.)
                  | RoundTowardNegative     -- ^ Round towards negative infinity. (Also known as rounding-down or floor.)
                  | RoundTowardZero         -- ^ Round towards zero. (Also known as truncation.)
                  -- Constructor order matters: it determines the derived
                  -- 'Enum' numbering and the 'Bounded' limits.
                  deriving ( Show
                           , Enum
                           , Bounded
                           )

-- | Arbitrary instance for 'RoundingMode': generates one of the constructors
-- in the range from 'minBound' to 'maxBound'.
instance Arbitrary RoundingMode where
  arbitrary = elements everyMode
    where everyMode = [minBound .. maxBound]

-- | Render a rounding mode as the SMT-Lib2 identifier for it: the constructor
-- name with a lowercase first letter.
smtRoundingMode :: RoundingMode -> String
smtRoundingMode mode = case mode of
  RoundNearestTiesToEven -> "roundNearestTiesToEven"
  RoundNearestTiesToAway -> "roundNearestTiesToAway"
  RoundTowardPositive    -> "roundTowardPositive"
  RoundTowardNegative    -> "roundTowardNegative"
  RoundTowardZero        -> "roundTowardZero"