{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE FlexibleInstances #-}
module Data.Set.Ordered
        ( OSet
        
        , empty, singleton
        
        
        
        
        
        
        
        
        , (<|), (|<), (>|), (|>)
        , (<>|), (|<>)
        , Bias(Bias, unbiased), L, R
        
        , null, size, member, notMember
        
        , delete, filter, (\\), (|/\), (/\|)
        
        , Index, findIndex, elemAt
        
        , fromList, toAscList
        ) where
import Control.Monad (guard)
import Data.Data
import Data.Foldable (Foldable, foldl', foldMap, foldr, toList)
import Data.Function (on)
import Data.Map (Map)
import Data.Map.Util
import Data.Monoid (Monoid(..))
#if MIN_VERSION_base(4,9,0)
import Data.Semigroup (Semigroup(..))
#endif
import Data.Set (Set) 
import Prelude hiding (filter, foldr, lookup, null)
import qualified Data.Map as M
-- | A set whose elements carry an explicit order.
--
-- Representation invariant: the two maps are inverses of each other. The
-- first maps each element to its 'Tag'; the second maps each tag back to
-- its element. The set's order is the ascending order of tags.
data OSet a = OSet
        !(Map a Tag)  -- element -> tag
        !(Map Tag a)  -- tag -> element; iteration order defines the set's order
        deriving Typeable
-- | Folds visit elements in the set's order (ascending tag), not in the
-- elements' own 'Ord' order.
instance Foldable OSet where
        foldMap f (OSet _ byTag) = foldMap f byTag
-- Equality and ordering compare the element lists in the set's order, so
-- two sets holding the same elements in different orders are unequal.
instance Eq a => Eq (OSet a) where
        o == o' = toList o == toList o'
instance Ord a => Ord (OSet a) where
        compare o o' = compare (toList o) (toList o')
-- Printing and parsing are delegated to the shared list-based helpers,
-- going through 'toList' and 'fromList' respectively.
instance Show a => Show (OSet a) where
        showsPrec = showsPrecList toList
instance (Ord a, Read a) => Read (OSet a) where
        readsPrec = readsPrecList fromList
-- | Generic-programming view of an 'OSet': a single pseudo-constructor
-- "fromList" applied to the element list (in the set's order).
instance (Data a, Ord a) => Data (OSet a) where
        gfoldl f z set = f (z fromList) (toList set)
        toConstr       = const fromListConstr
        gunfold k z c
                | constrIndex c == 1 = k (z fromList)
                | otherwise          = error "gunfold"
        dataTypeOf     = const oSetDataType
        -- Kept eta-expanded: the argument has a rank-2 type.
        dataCast1 f    = gcast1 f
-- | The only constructor 'Data' exposes for 'OSet'; it takes no named fields.
fromListConstr :: Constr
fromListConstr = mkConstr oSetDataType "fromList" noFields Prefix
        where noFields = []
-- | 'DataType' descriptor for 'OSet'.
--
-- NOTE(review): the name "Data.Set.Ordered.Set" does not match the type
-- name 'OSet'; left unchanged because it is observable via 'dataTypeName'.
oSetDataType :: DataType
oSetDataType = mkDataType "Data.Set.Ordered.Set" [fromListConstr]
#if MIN_VERSION_base(4,9,0)
-- | Left-biased union: on overlap, positions from the left operand win.
instance Ord a => Semigroup (Bias L (OSet a)) where Bias o <> Bias o' = Bias (o |<> o')
-- | Right-biased union: on overlap, positions from the right operand win.
instance Ord a => Semigroup (Bias R (OSet a)) where Bias o <> Bias o' = Bias (o <>| o')
#endif
-- | The empty set is the identity for the left-biased union.
instance Ord a => Monoid (Bias L (OSet a)) where
        mempty = Bias empty
#if MIN_VERSION_base(4,9,0)
        -- Canonical definition; avoids -Wnoncanonical-monoid-instances.
        mappend = (<>)
#else
        mappend (Bias o) (Bias o') = Bias (o |<> o')
#endif
-- | The empty set is the identity for the right-biased union.
instance Ord a => Monoid (Bias R (OSet a)) where
        mempty = Bias empty
#if MIN_VERSION_base(4,9,0)
        -- Canonical definition; avoids -Wnoncanonical-monoid-instances.
        mappend = (<>)
#else
        mappend (Bias o) (Bias o') = Bias (o <>| o')
#endif
infixr 5 <|, |<
infixl 5 >|, |>
infixr 6 <>|, |<>

-- | Add an element at the front (via 'nextLowerTag'); no-op if already present.
(<|)  :: Ord a => a -> OSet a -> OSet a
-- | Move or add an element to the front (via 'nextLowerTag').
(|<)  :: Ord a => a -> OSet a -> OSet a
-- | Move or add an element to the back (via 'nextHigherTag').
(>|)  :: Ord a => OSet a -> a -> OSet a
-- | Add an element at the back (via 'nextHigherTag'); no-op if already present.
(|>)  :: Ord a => OSet a -> a -> OSet a
-- | Union; elements shared by both operands take the right operand's position.
(<>|) :: Ord a => OSet a -> OSet a -> OSet a
-- | Union; elements shared by both operands take the left operand's position.
(|<>) :: Ord a => OSet a -> OSet a -> OSet a
-- Prepend: an element that is already present keeps its current position.
x <| set@(OSet byElem byTag)
        | x `notMember` set = OSet (M.insert x tag byElem) (M.insert tag x byTag)
        | otherwise         = set
        where
        tag = nextLowerTag byTag
-- Move to the front: drop any existing copy, then re-add the element with
-- a tag from 'nextLowerTag' over the remaining tags.
v |< o = case delete v o of
        OSet ts vs ->
                let t = nextLowerTag vs
                in  OSet (M.insert v t ts) (M.insert t v vs)
-- Append: an element that is already present keeps its current position.
set@(OSet byElem byTag) |> x
        | x `notMember` set = OSet (M.insert x tag byElem) (M.insert tag x byTag)
        | otherwise         = set
        where
        tag = nextHigherTag byTag
-- Move to the back: drop any existing copy, then re-add the element with
-- a tag from 'nextHigherTag' over the remaining tags.
o >| v = case delete v o of
        OSet ts vs ->
                let t = nextHigherTag vs
                in  OSet (M.insert v t ts) (M.insert t v vs)
-- Right-biased union: strip the shared elements out of the left operand so
-- they keep their right-hand position, then concatenate.
(<>|) left right = unsafeMappend (left \\ right) right
-- Left-biased union: strip the shared elements out of the right operand so
-- they keep their left-hand position, then concatenate.
(|<>) left right = unsafeMappend left (right \\ left)
-- | Concatenate two sets so that every element of the first operand comes
-- before every element of the second.
--
-- Precondition (hence "unsafe"): the operands share no elements. On overlap,
-- the left-biased 'M.union' on the element maps would keep only one tag per
-- element, while the tag maps would keep both, breaking the invariant.
--
-- Both operands are re-tagged by adding a constant. The left operand is
-- shifted so its largest tag becomes -1, and the right operand so its
-- smallest tag becomes 0. This assumes 'maxTag'/'minTag' return the
-- extreme keys of the tag map. The two tag ranges then cannot collide.
unsafeMappend :: Ord a => OSet a -> OSet a -> OSet a
unsafeMappend (OSet ts vs) (OSet ts' vs')
        = OSet (M.union tsBumped tsBumped')
               (M.union vsBumped vsBumped')
        where
        bump  = case maxTag vs  of
                Nothing -> 0
                Just k  -> -k-1
        bump' = case minTag vs' of
                Nothing -> 0
                Just k  -> -k
        tsBumped  = fmap (bump +) ts
        tsBumped' = fmap (bump'+) ts'
        -- Adding a constant preserves key order, so the cheaper
        -- 'M.mapKeysMonotonic' is valid here.
        vsBumped  = (bump +) `M.mapKeysMonotonic` vs
        vsBumped' = (bump'+) `M.mapKeysMonotonic` vs'
-- | Set difference: the elements of the first set that are not in the
-- second, kept in the first set's order.
--
-- Picks a strategy by size. When the first set is smaller, it is filtered
-- against the second; otherwise the second set's elements are deleted from
-- the first one by one.
(\\) :: Ord a => OSet a -> OSet a -> OSet a
o \\ o'@(OSet _ vs') = if size o < size o'
        then filter (`notMember` o') o
        else foldr delete o vs'
-- | Intersection; the result follows the left operand's order.
(|/\) :: Ord a => OSet a -> OSet a -> OSet a
OSet ts _ |/\ OSet ts' _ = OSet ts'' vs'' where
        -- 'M.intersection' is left-biased, so the surviving tags are the
        -- left operand's and the result keeps its order.
        ts'' = M.intersection ts ts'
        -- Invert the *intersected* map so the two maps stay inverses. The
        -- previous code inverted 'ts', which kept elements that had been
        -- dropped from the element map.
        vs'' = M.fromList [(t, v) | (v, t) <- M.toList ts'']
-- | Intersection; the result follows the right operand's order.
--
-- Defined by swapping the operands of '|/\'. The previous definition
-- referred to '/\|' itself and never terminated.
(/\|) :: Ord a => OSet a -> OSet a -> OSet a
o /\| o' = o' |/\ o
-- | The set with no elements: both internal maps are empty.
empty :: OSet a
empty = OSet byElem byTag
        where
        byElem = M.empty
        byTag  = M.empty
-- | Membership tests; they consult only the element-to-tag map.
member, notMember :: Ord a => a -> OSet a -> Bool
member x (OSet byElem _) = x `M.member` byElem
notMember x = not . member x
-- | Number of elements, read from the element-to-tag map.
size :: OSet a -> Int
size o = case o of
        OSet byElem _ -> M.size byElem
-- | Keep only the elements satisfying the predicate; survivors keep their
-- relative order and their tags. The predicate is evaluated once per
-- internal map, i.e. twice per element.
filter :: Ord a => (a -> Bool) -> OSet a -> OSet a
filter f (OSet ts vs) = OSet (M.filterWithKey (\v _ -> f v) ts)
                             (M.filter f vs)
-- | Remove an element if present; both maps are updated together so they
-- stay inverses. Absent elements leave the set unchanged.
delete :: Ord a => a -> OSet a -> OSet a
delete x set@(OSet byElem byTag) =
        maybe set
              (\tag -> OSet (M.delete x byElem) (M.delete tag byTag))
              (M.lookup x byElem)
-- | A set with exactly one element, which is given tag 0.
singleton :: a -> OSet a
singleton x = OSet (M.singleton x tag) (M.singleton tag x)
        where tag = 0
-- | Build a set by appending the list's elements left to right with '|>'.
-- A duplicated element keeps the position of its first occurrence.
fromList :: Ord a => [a] -> OSet a
fromList xs = foldl' snoc empty xs
        where snoc acc x = acc |> x
-- | Whether the set has no elements.
null :: OSet a -> Bool
null = \(OSet byElem _) -> M.null byElem
-- | The 0-based position of an element in the set's order (ascending tag),
-- or 'Nothing' if the element is absent.
findIndex :: Ord a => a -> OSet a -> Maybe Index
findIndex v (OSet ts vs) = do
        t <- M.lookup v ts
        M.lookupIndex t vs
-- | The element at a 0-based position in the set's order (ascending tag),
-- or 'Nothing' when the position is out of range.
elemAt :: OSet a -> Index -> Maybe a
elemAt (OSet _ vs) i = do
        -- 'M.elemAt' is partial, so check bounds before calling it.
        guard (0 <= i && i < M.size vs)
        return . snd $ M.elemAt i vs
-- | The elements in ascending order of the elements themselves, ignoring
-- the set's own order; use 'toList' for the set's order.
toAscList :: OSet a -> [a]
toAscList (OSet ts _) = M.keys ts