-- | Test values
--
-- Intended for unqualified import.
module Test.Tensor.TestValue (
    TestValue -- opaque
  ) where

import Data.List (sort)
import System.Random (Random)
import Test.QuickCheck
import Text.Printf (printf)

{-------------------------------------------------------------------------------
  Definition
-------------------------------------------------------------------------------}

-- | Floating point values with a tolerant notion of equality
--
-- A 'TestValue' wraps a 'Float' and is meant for QuickCheck tests involving
-- floating point arithmetic, where rounding errors should be ignored.
--
-- The numeric classes ('Num', 'Fractional', 'Real') and 'Random' are
-- inherited from 'Float' unchanged via newtype deriving; 'Eq', 'Ord', 'Show'
-- and 'Arbitrary' have hand-written instances in this module.
newtype TestValue = TestValue Float
  deriving newtype
    ( Num
    , Fractional
    , Real
    , Random
    )

-- | Crude equality: two test values are equal when their underlying 'Float's
-- are 'nearlyEqual'
--
-- >               (==)
-- > --------------------
-- > 1.0    1.1    False
-- > 1.00   1.01   True
-- > 10     11     False
-- > 10.0   10.1   True
-- > 100    110    False
-- > 100    101    True
--
-- Note that this relation is not transitive: @1.00 == 1.015@ and
-- @1.015 == 1.03@, but @1.00 /= 1.03@.
instance Eq TestValue where
  (==) (TestValue a) (TestValue b) = nearlyEqual a b

-- | Show instance
--
-- Values of smaller magnitude are shown with more decimals, because the crude
-- equality on 'TestValue' tolerates larger absolute differences for larger
-- values. The number of decimals is chosen from the /absolute/ value, so a
-- negative number is rendered like its positive counterpart plus a minus
-- sign. The exact values @0@ and @1@ are shown without decimals.
--
-- For larger values the output still overstates the precision: @1000@ and
-- @1001@ are shown as @1000@ and @1001@, even though they are considered to
-- be equal.
--
-- > show @TestValue 0      == "0"     -- True zero
-- > show @TestValue 1      == "1"     -- True one
-- > show @TestValue 0.001  == "0.00"
-- > show @TestValue 0.009  == "0.01"
-- > show @TestValue 1.001  == "1.0"
-- > show @TestValue 11     == "11"
-- > show @TestValue (-11)  == "-11"
instance Show TestValue where
  show (TestValue x)
    | x == 0     = "0"
    | x == 1     = "1"
    -- Compare the magnitude, not the signed value: otherwise every negative
    -- number (e.g. -1000) would fall into the two-decimal branch.
    | abs x < 1  = printf "%0.2f" x
    | abs x < 10 = printf "%0.1f" x
    | otherwise  = printf "%0.0f" x

-- | Arbitrary instance
--
-- The definition of 'arbitrary' simply piggy-backs on the definition for
-- 'Float'. Shrinking also starts from the 'Float' shrinks, but post-processes
-- them:
--
-- * neighbouring candidates (after sorting) that are nearly equal to each
--   other are collapsed into a single representative, keeping the one with the
--   smaller fractional part (@x - floor x@), i.e. preferring values closer to
--   integral values;
--
-- * every representative that is nearly equal to the value being shrunk is
--   dropped, since it is equal to it according to the 'Eq' instance. This
--   applies regardless of where such candidates end up after sorting. For
--   negative inputs they come first, not last.
--
-- A value that is exactly @0@ does not shrink; a value nearly equal to @0@
-- shrinks only to @0@. Compare:
--
-- >    shrink @TestValue 100.1
-- > == [0,50,75,88,94,97]
--
-- versus
--
-- >    shrink @Float 100.1
-- > == [100.0,0.0,50.0,75.0,88.0,94.0,97.0,99.0,0.0,50.1,75.1,87.6,93.9,97.0,98.6,99.4,99.8,100.0]
instance Arbitrary TestValue where
  arbitrary = TestValue <$> arbitrary

  shrink (TestValue x)
    | x == 0          = []
    | nearlyEqual x 0 = [TestValue 0]
    | otherwise       =
        [ TestValue y
        | y <- collapseNearby (sort (shrink x))
        , not (nearlyEqual y x)
        ]
    where
      -- Collapse chains of adjacent, nearly equal candidates into a single
      -- representative (the one with the smaller 'decimalPart').
      collapseNearby :: [Float] -> [Float]
      collapseNearby []     = []
      collapseNearby (y:ys) = go y ys

      go :: Float -> [Float] -> [Float]
      go y []     = [y]
      go y (z:zs)
        | nearlyEqual y z = if decimalPart y < decimalPart z
                              then go y zs
                              else go z zs
        | otherwise       = y : go z zs

-- | Ordering consistent with the crude 'Eq' instance
--
-- Values whose underlying 'Float's are 'nearlyEqual' compare as 'EQ'. All
-- other pairs are ordered by the underlying 'Float': 'LT' when strictly
-- smaller, and 'GT' otherwise (this includes comparisons involving NaN).
--
-- Because 'nearlyEqual' is not transitive, this is not a lawful total order.
instance Ord TestValue where
  compare (TestValue a) (TestValue b)
    | nearlyEqual a b = EQ
    | otherwise       = if a < b then LT else GT

{-------------------------------------------------------------------------------
  Internal auxiliary
-------------------------------------------------------------------------------}

-- | Compare for near equality
--
-- Two values are nearly equal if they are identical, or if their absolute
-- difference is below the larger of an absolute threshold (@0.01@) and a
-- relative threshold (1% of @abs a + abs b@).
--
-- Adapted from <https://stackoverflow.com/a/32334103/742991>. As in the
-- original algorithm, the sum @abs a + abs b@ is clamped to the largest finite
-- 'Float'. Without the clamp it overflows to infinity for large inputs,
-- which would make /any/ two such values nearly equal.
nearlyEqual :: Float -> Float -> Bool
nearlyEqual a b
  | a == b    = True
  | otherwise = diff < max abs_th (epsilon * norm)
  where
    diff, norm :: Float
    diff = abs (a - b)
    norm = min (abs a + abs b) maxFinite

    -- Define precision
    abs_th, epsilon, maxFinite :: Float
    epsilon   = 0.01
    abs_th    = 0.01
    maxFinite = 3.4028235e38 -- largest finite 'Float' (FLT_MAX)

-- | Fractional part of a value, measured from its floor
--
-- @decimalPart x@ is @x - floor x@, so for finite inputs the result lies
-- between 0 and 1, also for negative values
-- (e.g. @decimalPart (-1.25) == 0.75@).
--
-- The floor is computed as an 'Integer' rather than an 'Int'. Converting a
-- 'Float' whose magnitude exceeds @maxBound :: Int@ to 'Int' overflows and
-- would give a meaningless result for large values.
decimalPart :: Float -> Float
decimalPart x = x - fromIntegral (floor x :: Integer)