{-# LANGUAGE CPP #-}
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE FlexibleInstances #-}
module Data.Map.Ordered
        ( OMap
        
        , empty, singleton
        
        
        
        
        
        
        
        
        
        
        , (<|), (|<), (>|), (|>)
        , (<>|), (|<>), unionWithL, unionWithR
        , Bias(Bias, unbiased), L, R
        
        , delete, filter, (\\)
        , (|/\), (/\|), intersectionWith
        
        , null, size, member, notMember, lookup
        
        , Index, findIndex, elemAt
        
        , fromList, assocs, toAscList
        ) where
import Control.Applicative ((<|>))
import Control.Monad (guard)
import Data.Data
import Data.Foldable (Foldable, foldl', foldMap)
import Data.Function (on)
import Data.Map (Map)
import Data.Map.Util
import Data.Monoid (Monoid(..))
#if MIN_VERSION_base(4,9,0)
import Data.Semigroup (Semigroup(..))
#endif
#if !(MIN_VERSION_base(4,8,0))
import Control.Applicative ((<$>))
import Data.Traversable
#endif
import Prelude hiding (filter, lookup, null)
import qualified Data.Map as M
-- | A map that remembers the order of its entries.
--
-- Every entry is stored twice, and every operation in this module updates
-- both fields together:
--
-- * the first field is indexed by key and holds the entry's position tag
--   together with its value (used for lookups by key);
--
-- * the second field is indexed by position tag and holds the key together
--   with its value; walking it in ascending tag order gives the map's
--   element order (this is what 'assocs' and 'Foldable' traverse).
--
-- Both fields are strict.
data OMap k v = OMap
        !(Map k (Tag, v))
        !(Map Tag (k, v))
        deriving (Functor, Typeable)
-- | Folds over the values in element order (ascending position tag); keys
-- are not visited.
instance Foldable (OMap k) where
        foldMap f (OMap _ byTag) = foldMap (\entry -> f (snd entry)) byTag
-- Equality, ordering, display and parsing all work on the association list
-- in element order, so two maps holding the same bindings in different
-- orders compare as different.
instance (Eq k, Eq v) => Eq (OMap k v) where
        o == o' = assocs o == assocs o'
instance (Ord k, Ord v) => Ord (OMap k v) where
        compare o o' = compare (assocs o) (assocs o')
instance (Show k, Show v) => Show (OMap k v) where
        showsPrec d = showsPrecList assocs d
instance (Ord k, Read k, Read v) => Read (OMap k v) where
        readsPrec d = readsPrecList fromList d
-- | Generic-programming support. An 'OMap' is presented as if it were built
-- by a single constructor, @fromList@, applied to its association list in
-- element order.
instance (Data k, Data a, Ord k) => Data (OMap k a) where
        gfoldl app lift' om = lift' fromList `app` assocs om
        toConstr _ = fromListConstr
        gunfold build lift' con
                | constrIndex con == 1 = build (lift' fromList)
                | otherwise            = error "gunfold"
        dataTypeOf _ = oMapDataType
        -- Cast through both type parameters with 'gcast2'.
        dataCast2 f = gcast2 f

-- | The only (pseudo-)constructor reported for 'OMap' by its 'Data' instance.
fromListConstr :: Constr
fromListConstr = mkConstr oMapDataType "fromList" [] Prefix

-- | 'DataType' description used by the 'Data' instance of 'OMap'.
-- NOTE(review): the name string says @Map@ although the type is 'OMap';
-- it is left as is because 'dataTypeName' exposes it to callers.
oMapDataType :: DataType
oMapDataType = mkDataType "Data.Map.Ordered.Map" [fromListConstr]
#if MIN_VERSION_base(4,9,0)
-- | Left-biased union: a key present in both maps keeps the position it had
-- in the left map, and its values are combined as @left <> right@.
instance (Ord k, Semigroup v) => Semigroup (Bias L (OMap k v)) where
        (<>) (Bias l) (Bias r) = Bias (unionWithL (\_ x y -> x <> y) l r)
-- | Right-biased union: a key present in both maps takes the position it
-- had in the right map, and its values are combined as @left <> right@.
instance (Ord k, Semigroup v) => Semigroup (Bias R (OMap k v)) where
        (<>) (Bias l) (Bias r) = Bias (unionWithR (\_ x y -> x <> y) l r)
#endif
-- | Left-biased union with 'empty' as identity: a key present in both maps
-- keeps the position it had in the left map, and its values are combined
-- left-then-right with the value type's monoid operation.
instance (Ord k, Monoid v) => Monoid (Bias L (OMap k v)) where
        mempty = Bias empty
#if MIN_VERSION_base(4,11,0)
        -- From base-4.11 on Semigroup is a superclass of Monoid, so use the
        -- canonical definition; GHC reports any other definition under
        -- -Wnoncanonical-monoid-instances.
        mappend = (<>)
#else
        mappend (Bias o) (Bias o') = Bias (unionWithL (const mappend) o o')
#endif
-- | Right-biased union with 'empty' as identity: a key present in both maps
-- takes the position it had in the right map, and its values are combined
-- left-then-right with the value type's monoid operation.
instance (Ord k, Monoid v) => Monoid (Bias R (OMap k v)) where
        mempty = Bias empty
#if MIN_VERSION_base(4,11,0)
        -- Canonical definition; see the left-biased instance.
        mappend = (<>)
#else
        mappend (Bias o) (Bias o') = Bias (unionWithR (const mappend) o o')
#endif
-- | Traverses the values in element order. Keys and position tags are kept,
-- and the result is rebuilt from the traversed tag-indexed map.
instance Ord k => Traversable (OMap k) where
        traverse f (OMap _ byTag) = fmap fromKV (traverse visit byTag) where
                visit (key, val) = fmap ((,) key) (f val)
infixr 5 <|, |<
infixl 5 >|, |>
infixr 6 <>|, |<>
-- Front/back placement relies on 'nextLowerTag' / 'nextHigherTag' (defined
-- in Data.Map.Util) returning tags below / above every existing tag.
-- NOTE(review): that helper contract is not visible here; confirm it there.

-- | Add an entry on the front side. If the key is already present, its
-- value is replaced and its position is kept.
(<|) :: Ord k => (k, v) -> OMap k v -> OMap k v
-- | Add an entry on the front side, moving the key there if it is already
-- present.
(|<) :: Ord k => (k, v) -> OMap k v -> OMap k v
-- | Add an entry on the back side, moving the key there if it is already
-- present.
(>|) :: Ord k => OMap k v -> (k, v) -> OMap k v
-- | Add an entry on the back side. If the key is already present, its value
-- is replaced and its position is kept.
(|>) :: Ord k => OMap k v -> (k, v) -> OMap k v
-- | Union keeping the left map's values; keys found in both maps take their
-- position from the right map.
(<>|) :: Ord k => OMap k v -> OMap k v -> OMap k v
-- | Union keeping the left map's values; keys found in both maps keep their
-- position from the left map.
(|<>) :: Ord k => OMap k v -> OMap k v -> OMap k v
(k, v) <| OMap tvs kvs = OMap tvs' kvs' where
        -- An existing key keeps its tag, and therefore its position; a new
        -- key gets a fresh tag from 'nextLowerTag'.
        t = case M.lookup k tvs of
                Just entry -> fst entry
                Nothing    -> nextLowerTag kvs
        tvs' = M.insert k (t, v) tvs
        kvs' = M.insert t (k, v) kvs
(k, v) |< o = case delete k o of
        OMap tvs kvs ->
                -- Any previous binding of k was removed above, so the key
                -- always receives a fresh tag from 'nextLowerTag'.
                let t = nextLowerTag kvs
                in  OMap (M.insert k (t, v) tvs) (M.insert t (k, v) kvs)
o >| (k, v) = case delete k o of
        OMap tvs kvs ->
                -- Any previous binding of k was removed above, so the key
                -- always receives a fresh tag from 'nextHigherTag'.
                let t = nextHigherTag kvs
                in  OMap (M.insert k (t, v) tvs) (M.insert t (k, v) kvs)
OMap tvs kvs |> (k, v) = OMap tvs' kvs' where
        -- An existing key keeps its tag, and therefore its position; a new
        -- key gets a fresh tag from 'nextHigherTag'.
        t = case M.lookup k tvs of
                Just entry -> fst entry
                Nothing    -> nextHigherTag kvs
        tvs' = M.insert k (t, v) tvs
        kvs' = M.insert t (k, v) kvs
-- Both unions keep the value from the left map for shared keys; they differ
-- only in which map's position a shared key keeps.
o <>| o' = unionWithR (\_ v _ -> v) o o'
o |<> o' = unionWithL (\_ v _ -> v) o o'
-- | Union with a combining function (given the key, the left value and the
-- right value) for keys present in both maps. Such keys keep the position
-- they had in the left map.
unionWithL :: Ord k => (k -> v -> v -> v) -> OMap k v -> OMap k v -> OMap k v
unionWithL = unionWithInternal const
-- | Union with a combining function (given the key, the left value and the
-- right value) for keys present in both maps. Such keys take the position
-- they had in the right map.
unionWithR :: Ord k => (k -> v -> v -> v) -> OMap k v -> OMap k v -> OMap k v
unionWithR = unionWithInternal (\_ t' -> t')
-- | Shared implementation of the biased unions.
--
-- Both inputs are re-tagged before merging. The left map is shifted so that
-- its largest tag becomes -1, and the right map so that its smallest tag
-- becomes 0. This assumes 'maxTag' / 'minTag' return the extreme tags
-- (defined in Data.Map.Util; confirm there). As a result every left entry
-- precedes every right-only entry.
--
-- For a key present in both maps, @fT@ picks the tag from the two shifted
-- tags (left first) and @fKV@ merges the values (key, left value, right
-- value). The tag-indexed map of the result is rebuilt with 'fromTV'.
unionWithInternal :: Ord k => (Tag -> Tag -> Tag) -> (k -> v -> v -> v) -> OMap k v -> OMap k v -> OMap k v
unionWithInternal fT fKV (OMap tvs kvs) (OMap tvs' kvs') = fromTV merged where
        shiftLeft  = maybe 0 (\hi -> -hi-1) (maxTag kvs)
        shiftRight = maybe 0 (\lo -> -lo)   (minTag kvs')
        retag delta = fmap (\(t, v) -> (delta + t, v))
        merged = M.unionWithKey
                (\k (t, v) (t', v') -> (fT t t', fKV k v v'))
                (retag shiftLeft  tvs )
                (retag shiftRight tvs')
-- | Difference: remove from the first map every key that occurs in the
-- second. Surviving entries keep their values and positions.
(\\) :: Ord k => OMap k v -> OMap k v' -> OMap k v
o \\ o'@(OMap tvs' _)
        -- Let the smaller map drive the work: either test each key of o
        -- against o', or delete each key of o' from o.
        | size o < size o' = filter (\k _ -> k `notMember` o') o
        -- A strict left fold deletes in constant stack space, where a right
        -- fold would nest one 'delete' per key. Keys are distinct, so the
        -- order of deletion does not affect the result.
        | otherwise        = foldl' (flip delete) o (M.keys tvs')
-- | The map with no entries.
empty :: OMap k v
empty = OMap M.empty M.empty
-- | A map holding exactly one entry, stored under position tag 0.
singleton :: (k, v) -> OMap k v
singleton (k, v) = OMap (M.singleton k (0, v)) (M.singleton 0 (k, v))
-- | Build a map by appending the pairs left to right with '(|>)'. A repeated
-- key keeps the position of its first occurrence and the value of its last.
fromList :: Ord k => [(k, v)] -> OMap k v
fromList pairs = foldl' (|>) empty pairs
-- | Does the map have no entries?
null :: OMap k v -> Bool
null (OMap byKey _) = M.null byKey
-- | Number of entries.
size :: OMap k v -> Int
size (OMap byKey _) = M.size byKey
-- | Is the key bound in the map (regardless of position)?
member, notMember :: Ord k => k -> OMap k v -> Bool
member    key (OMap byKey _) = key `M.member` byKey
notMember key o              = not (member key o)
-- | The value bound to a key, if any.
lookup :: Ord k => k -> OMap k v -> Maybe v
lookup key (OMap byKey _) = snd <$> M.lookup key byKey
-- | Keep only the entries satisfying the predicate; survivors keep their
-- positions. The predicate is applied once per entry in each of the two
-- internal maps.
filter :: Ord k => (k -> v -> Bool) -> OMap k v -> OMap k v
filter keep (OMap tvs kvs) = OMap (M.filterWithKey byKey tvs) (M.filterWithKey byTag kvs) where
        byKey k (_, v) = keep k v
        byTag _ (k, v) = keep k v
-- | Remove a key and its entry. The input map is returned unchanged when the
-- key is absent.
delete :: Ord k => k -> OMap k v -> OMap k v
delete k o@(OMap tvs kvs) = maybe o dropEntry (M.lookup k tvs) where
        dropEntry (t, _) = OMap (M.delete k tvs) (M.delete t kvs)
-- | Intersection keeping the left map's values, with positions taken from
-- the right map.
(/\|) :: Ord k => OMap k v -> OMap k v' -> OMap k v
l /\| r = intersectionWith (\_ _ v -> v) r l
-- | Intersection keeping the left map's values and positions.
(|/\) :: Ord k => OMap k v -> OMap k v' -> OMap k v
l |/\ r = intersectionWith (\_ v _ -> v) l r
-- | Intersection with a combining function, given the key, the first map's
-- value and the second map's value. Each surviving key keeps its position
-- (tag) from the first map.
intersectionWith ::
        Ord k =>
        (k -> v -> v' -> v'') ->
        OMap k v -> OMap k v' -> OMap k v''
intersectionWith f (OMap tvs _) (OMap tvs' _) = fromTV (M.intersectionWithKey both tvs tvs') where
        both k (t, v) (_, v') = (t, f k v v')
-- | Build an 'OMap' from its key-indexed map alone; the tag-indexed map is
-- derived from it.
fromTV :: Ord k => Map k (Tag, v) -> OMap k v
fromTV tvs = OMap tvs (M.fromList (map flipEntry (M.toList tvs))) where
        flipEntry (k, (t, v)) = (t, (k, v))
-- | Build an 'OMap' from its tag-indexed map alone; the key-indexed map is
-- derived from it.
fromKV :: Ord k => Map Tag (k, v) -> OMap k v
fromKV kvs = OMap (M.fromList (map flipEntry (M.toList kvs))) kvs where
        flipEntry (t, (k, v)) = (k, (t, v))
-- | Zero-based position of a key in element order (ascending tag), or
-- 'Nothing' when the key is absent.
findIndex :: Ord k => k -> OMap k v -> Maybe Index
findIndex k (OMap tvs kvs) = M.lookup k tvs >>= \(t, _) -> M.lookupIndex t kvs
-- | The entry at a zero-based position in element order, or 'Nothing' when
-- the position is out of range (unlike 'M.elemAt', this never fails).
elemAt :: OMap k v -> Index -> Maybe (k, v)
elemAt (OMap _ kvs) i
        | i < 0 || i >= M.size kvs = Nothing
        | otherwise                = Just (snd (M.elemAt i kvs))
-- | The entries in element order (ascending position tag).
assocs :: OMap k v -> [(k, v)]
assocs (OMap _ kvs) = M.elems kvs
-- | The entries in ascending key order, ignoring positions.
toAscList :: OMap k v -> [(k, v)]
toAscList (OMap tvs _) = map dropTag (M.toAscList tvs) where
        dropTag (k, (_, v)) = (k, v)