-- |
-- Module:      Data.Poly.Internal.Dense.Field
-- Copyright:   (c) 2019 Andrew Lelechenko
-- Licence:     BSD3
-- Maintainer:  Andrew Lelechenko <andrew.lelechenko@gmail.com>
--
-- 'Euclidean' instance with a 'Field' constraint on the coefficient type.
--

{-# LANGUAGE ConstraintKinds            #-}
{-# LANGUAGE FlexibleInstances          #-}
{-# LANGUAGE ScopedTypeVariables        #-}
{-# LANGUAGE TypeFamilies               #-}

{-# OPTIONS_GHC -fno-warn-orphans #-}

module Data.Poly.Internal.Dense.Field
  ( quotRemFractional
  ) where

import Prelude hiding (quotRem, quot, rem, gcd)
import Control.Exception
import Control.Monad
import Control.Monad.ST
import Data.Euclidean (Euclidean(..), Field)
import Data.Semiring (times, minus, zero, one)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG

import Data.Poly.Internal.Dense
import Data.Poly.Internal.Dense.GcdDomain ()

-- | Note that 'degree' 0 = 0.
--
-- @since 0.3.0.0
instance (Eq a, Field a, G.Vector v a) => Euclidean (Poly v a) where
  -- The degree is the index of the last stored coefficient. An empty
  -- coefficient vector is assigned degree 0 rather than left undefined.
  degree (Poly xs)
    | G.null xs = 0
    | otherwise = fromIntegral (G.length xs - 1)

  -- Long division on the raw coefficient vectors. The generic worker is
  -- handed the semiring operations of the coefficient type, with
  -- (one `quot`) acting as the inversion function. Both results are wrapped
  -- with toPoly' (presumably normalising trailing zero coefficients --
  -- confirm in Data.Poly.Internal.Dense).
  quotRem (Poly xs) (Poly ys) = (toPoly' qs, toPoly' rs)
    where
      (qs, rs) = quotientAndRemainder zero (== one) minus times (one `quot`) xs ys
  {-# INLINE quotRem #-}

  -- Dedicated remainder path: 'remainder' never allocates a quotient vector.
  rem (Poly xs) (Poly ys) = toPoly' $ remainder xs ys
  {-# INLINE rem #-}

-- | Polynomial division with remainder.
--
-- Works for any 'Fractional' coefficient type: coefficients are inverted
-- with 'recip' and combined with the ordinary 'Num' operations.
-- Throws 'DivideByZero' if the divisor has no coefficients.
--
-- >>> quotRemFractional (X^3 + 2) (X^2 - 1 :: UPoly Double)
-- (1.0 * X + 0.0,1.0 * X + 2.0)
--
-- @since 0.5.0.0
quotRemFractional
  :: (Eq a, Fractional a, G.Vector v a)
  => Poly v a
  -> Poly v a
  -> (Poly v a, Poly v a)
quotRemFractional (Poly xs) (Poly ys) = (toPoly qs, toPoly rs)
  where
    -- Same generic worker as the 'Euclidean' instance, instantiated with
    -- the 'Num' / 'Fractional' operations instead of the semiring ones.
    (qs, rs) = quotientAndRemainder 0 (== 1) (-) (*) recip xs ys
{-# INLINE quotRemFractional #-}

-- | Long division of two coefficient vectors, abstracted over the
-- arithmetic of the coefficient type so that it can serve both the
-- semiring-based and the 'Num'-based entry points.
--
-- The last element of each vector is treated as the leading coefficient.
--
-- * If the dividend is shorter than the divisor, the quotient is empty and
--   the remainder is the dividend itself.
-- * If the divisor is empty, 'DivideByZero' is thrown.
-- * If the divisor has a single element, every dividend coefficient is
--   multiplied by its inverse and the remainder is empty.
--
-- The returned remainder is the first @length divisor@ entries of the
-- working buffer. Its last entry has always been overwritten with the
-- supplied zero, so callers are expected to normalise the result.
quotientAndRemainder
  :: (Eq a, G.Vector v a)
  => a             -- ^ zero
  -> (a -> Bool)   -- ^ is one?
  -> (a -> a -> a) -- ^ subtract
  -> (a -> a -> a) -- ^ multiply
  -> (a -> a)      -- ^ invert
  -> v a           -- ^ dividend
  -> v a           -- ^ divisor
  -> (v a, v a)
quotientAndRemainder zer isOne sub mul inv xs ys
  | lenXs < lenYs = (G.empty, xs)
  | lenYs == 0 = throw DivideByZero
  | lenYs == 1 = let invY = inv (G.unsafeHead ys) in
                 (G.map (`mul` invY) xs, G.empty)
  | otherwise = runST $ do
    -- Here lenXs >= lenYs >= 2, hence lenQs >= 1 and every unsafe index
    -- below stays within [0, lenXs - 1] for rs and [0, lenQs - 1] for qs.
    qs <- MG.unsafeNew lenQs
    rs <- MG.unsafeNew lenXs
    G.unsafeCopy rs xs
    let yLast = G.unsafeLast ys
        invYLast = inv yLast
    -- Produce quotient coefficients from the highest index downwards.
    forM_ [lenQs - 1, lenQs - 2 .. 0] $ \i -> do
      r <- MG.unsafeRead rs (lenYs - 1 + i)
      -- Skip the multiplication when the divisor is monic.
      let q = if isOne yLast then r else r `mul` invYLast
      MG.unsafeWrite qs i q
      -- The leading term cancels exactly; store zero instead of computing it.
      MG.unsafeWrite rs (lenYs - 1 + i) zer
      -- Subtract q * (divisor without its leading term), shifted by i.
      forM_ [0 .. lenYs - 2] $ \k -> do
        let y = G.unsafeIndex ys k
        when (y /= zer) $
          MG.unsafeModify rs (\c -> c `sub` (q `mul` y)) (i + k)
    -- Everything at index lenYs - 1 and above has been zeroed by now.
    let rs' = MG.unsafeSlice 0 lenYs rs
    (,) <$> G.unsafeFreeze qs <*> G.unsafeFreeze rs'
  where
    lenXs = G.length xs
    lenYs = G.length ys
    lenQs = lenXs - lenYs + 1
{-# INLINABLE quotientAndRemainder #-}

-- | Remainder of dividing the first coefficient vector by the second,
-- computed without materialising the quotient.
--
-- Throws 'DivideByZero' if the divisor is empty. The result has
-- @min (length dividend) (length divisor)@ entries and may end in zeros.
remainder
  :: (Eq a, Field a, G.Vector v a)
  => v a
  -> v a
  -> v a
remainder xs ys
  | G.null ys = throw DivideByZero
  | otherwise = runST $ do
    -- The dividend is copied because 'remainderM' overwrites it in place.
    rs <- G.thaw xs
    -- The divisor is only ever read by 'remainderM', so it is thawed
    -- without copying.
    ys' <- G.unsafeThaw ys
    remainderM rs ys'
    G.unsafeFreeze $ MG.unsafeSlice 0 (G.length xs `min` G.length ys) rs
{-# INLINABLE remainder #-}

-- | In-place remainder: overwrites the first (dividend) buffer with the
-- remainder of its division by the second (divisor) buffer. The divisor
-- buffer is only read.
--
-- * If the dividend is shorter than the divisor, it is left untouched.
-- * If the divisor is empty, 'DivideByZero' is thrown.
-- * If the divisor has a single element, the whole dividend is set to zero.
-- * Otherwise every entry at index @length divisor - 1@ and above ends up
--   zero, and the entries below hold the remainder coefficients.
remainderM
  :: (Eq a, Field a, G.Vector v a)
  => G.Mutable v s a
  -> G.Mutable v s a
  -> ST s ()
remainderM xs ys
  | lenXs < lenYs = pure ()
  | lenYs == 0 = throw DivideByZero
  | lenYs == 1 = MG.set xs zero
  | otherwise = do
    -- Here lenXs >= lenYs >= 2, so all unsafe indices below are in bounds.
    yLast <- MG.unsafeRead ys (lenYs - 1)
    let invYLast = one `quot` yLast
    forM_ [lenQs - 1, lenQs - 2 .. 0] $ \i -> do
      r <- MG.unsafeRead xs (lenYs - 1 + i)
      -- The leading term cancels exactly; store zero instead of computing it.
      MG.unsafeWrite xs (lenYs - 1 + i) zero
      -- Skip the multiplication when the divisor is monic.
      let q = if yLast == one then r else r `times` invYLast
      forM_ [0 .. lenYs - 2] $ \k -> do
        y <- MG.unsafeRead ys k
        when (y /= zero) $
          -- Parenthesised explicitly so the intent does not depend on the
          -- fixity declarations of 'minus' and 'times'.
          MG.unsafeModify xs (\c -> c `minus` (q `times` y)) (i + k)
  where
    lenXs = MG.length xs
    lenYs = MG.length ys
    lenQs = lenXs - lenYs + 1
{-# INLINABLE remainderM #-}