-- | 'Vector' sort routines.
module Harpie.Sort
  ( sortV,
    sortByV,
    orderV,
    orderByV,
  )
where

import Data.Ord
import Data.Vector (Vector, convert, unsafeIndex)
import Data.Vector qualified as V
import Data.Vector.Algorithms.Intro (sortBy)
import Data.Vector.Unboxed (generate, modify)
import Prelude

-- $setup
-- >>> :m -Prelude
-- >>> import Harpie.Sort
-- >>> import Data.Vector qualified as V
-- >>> import Data.Ord (Down (..))
-- >>> import Prelude (Int)
-- >>> :set -XDataKinds
-- >>> :set -XTypeFamilies
-- >>> :set -XFlexibleContexts

-- | Return the vector sorted in ascending order.
--
-- The result is the input permuted by the indices from 'orderV'.
--
-- >>> sortV (V.fromList [3,1,4,2,0,5::Int])
-- [0,1,2,3,4,5]
sortV :: (Ord a) => Vector a -> Vector a
sortV a =
  -- unsafeBackpermute skips bounds checks; 'orderV' only yields indices
  -- drawn from [0 .. V.length a - 1].
  V.unsafeBackpermute a (orderV a)

-- | Return the vector sorted in ascending order of a key function.
--
-- @sortByV c@ orders elements by comparing @c x@ (not by a comparison
-- function); pass 'Down' to sort in descending order. The result is the
-- input permuted by the indices from 'orderByV'.
--
-- >>> sortByV Down (V.fromList [3,1,4,2,0,5::Int])
-- [5,4,3,2,1,0]
sortByV :: (Ord b) => (a -> b) -> Vector a -> Vector a
sortByV c a =
  -- unsafeBackpermute skips bounds checks; 'orderByV' only yields indices
  -- drawn from [0 .. V.length a - 1].
  V.unsafeBackpermute a (orderByV c a)

-- | Return the indices that would sort the vector in ascending order
-- (an \"argsort\").
--
-- Equal elements keep their original relative order: ties are broken by
-- index, so the result is stable and deterministic even though the
-- underlying introsort is not.
--
-- >>> orderV (V.fromList [0..5::Int])
-- [0,1,2,3,4,5]
--
-- >>> orderV (V.fromList [1,0,1,0::Int])
-- [1,3,0,2]
orderV :: (Ord a) => Vector a -> Vector Int
orderV a = convert $ modify (sortBy comp) init0
  where
    -- Compare two indices by the elements they point at; introsort is not
    -- stable, so fall back to the indices themselves on ties.
    comp i j = comparing (unsafeIndex a) i j <> compare i j
    -- The identity permutation [0 .. length a - 1], unboxed so it can be
    -- sorted in place.
    init0 = generate (V.length a) id

-- | Return the indices that would sort the vector in ascending order of a
-- key function (an \"argsort by key\").
--
-- @orderByV c@ compares elements by @c x@ (not by a comparison function);
-- pass 'Down' for descending order. Elements with equal keys keep their
-- original relative order: ties are broken by index, so the result is
-- stable and deterministic even though the underlying introsort is not.
--
-- >>> orderByV Down (V.fromList [0..5::Int])
-- [5,4,3,2,1,0]
--
-- >>> orderByV Down (V.fromList [1,0,1,0::Int])
-- [0,2,1,3]
orderByV :: (Ord b) => (a -> b) -> Vector a -> Vector Int
orderByV c a = convert $ modify (sortBy comp) init0
  where
    -- Compare two indices by the keys of the elements they point at;
    -- introsort is not stable, so fall back to the indices on ties.
    comp i j = comparing (c . unsafeIndex a) i j <> compare i j
    -- The identity permutation [0 .. length a - 1], unboxed so it can be
    -- sorted in place.
    init0 = generate (V.length a) id