{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE MagicHash #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
{-# OPTIONS_GHC -Wno-unused-matches #-}

module Simple.TopSort where

import Data.Bifunctor.Linear (second)
import qualified Data.Functor.Linear as Data
import Data.HashMap.Mutable.Linear (HashMap)
import qualified Data.HashMap.Mutable.Linear as HMap
import Data.Maybe.Linear (catMaybes)
import Data.Unrestricted.Linear
import qualified Prelude.Linear as Linear

-- # The topological sort of a DAG
-------------------------------------------------------------------------------

-- | Graph vertices are identified by plain 'Int's.
type Node = Int

-- | Mutable (linear) hash map keyed by node. Each value pairs a list of
-- nodes (the neighbours the key node points to, as supplied to 'topsort')
-- with an 'Int' counter that the functions below use as the node's
-- in-degree.
type InDegGraph = HashMap Node ([Node], Int)

-- | Topologically sort a DAG given as an adjacency list.
--
-- Each input pair is a node together with the nodes its outgoing edges
-- point to. Every entry is first tagged with an in-degree of 0; the real
-- in-degrees are filled in by 'postOrderHM'. The traversal conses each
-- processed node onto its result, so the list comes back with the
-- first-processed sources last; the final 'reverse' puts sources first.
--
-- An empty input yields an empty result without allocating a hash map.
topsort :: [(Node, [Node])] -> [Node]
topsort = reverse . postOrder . fmap (\(n, nbrs) -> (n, (nbrs, 0)))
  where
    -- Load the tagged adjacency list into a freshly created linear hash map
    -- and run the traversal on it. The 'Int' handed to 'HMap.empty' is twice
    -- the number of entries (presumably an initial capacity hint).
    postOrder :: [(Node, ([Node], Int))] -> [Node]
    postOrder [] = []
    postOrder xs =
      let nodes = map fst xs
       in unur Linear.$
            HMap.empty (length xs * 2) Linear.$
              \hm -> postOrderHM nodes (HMap.insertAll xs hm)

-- | Order the given nodes using the mutable graph @dag@.
--
-- The in-degree stored next to every node is computed first, then the nodes
-- whose in-degree is 0 are collected, and finally 'pluckSources' removes
-- sources one at a time. The returned list has the most recently plucked
-- node at its head.
postOrderHM :: [Node] -> InDegGraph %1 -> Ur [Node]
postOrderHM nodes dag =
  case findSources nodes (computeInDeg nodes dag) of
    (dag', Ur sources) -> pluckSources sources [] dag'
  where
    -- O(V + N)
    -- Visit every node once and bump the in-degree of each of its neighbours.
    computeInDeg :: [Node] -> InDegGraph %1 -> InDegGraph
    computeInDeg ns g = Linear.foldl incChildren g (map Ur ns)

    -- Increment in-degree of all neighbors of the given node. A node that is
    -- missing from the map leaves the map unchanged.
    incChildren :: InDegGraph %1 -> Ur Node %1 -> InDegGraph
    incChildren g (Ur node) =
      case HMap.lookup node g of
        (Ur Nothing, g') -> g'
        (Ur (Just (nbrs, _)), g') -> incNodes (move nbrs) g'
      where
        -- Increment the in-degree of every node in the list.
        incNodes :: Ur [Node] %1 -> InDegGraph %1 -> InDegGraph
        incNodes (Ur ns) h = Linear.foldl incNode h (map Ur ns)

        -- Increment the in-degree of a single node; nodes that are not keys
        -- of the map are ignored.
        incNode :: InDegGraph %1 -> Ur Node %1 -> InDegGraph
        incNode h (Ur n) =
          case HMap.lookup n h of
            (Ur Nothing, h') -> h'
            (Ur (Just (nbrs, d)), h') ->
              HMap.insert n (nbrs, d + 1) h'

-- HMap.alter dag (\(Just (n,d)) -> Just (n,d+1)) node

-- | @pluckSources sources postOrdSoFar dag@
--
-- Repeatedly take the first pending source, cons it onto the accumulated
-- order, decrement the in-degree of each of its neighbours, and queue any
-- neighbour whose in-degree thereby became 0 ahead of the remaining sources.
-- When no sources are left the graph is consumed and the accumulated order
-- is returned.
--
-- A source that is not a key of the map is still added to the order; it
-- simply contributes no neighbours.
pluckSources :: [Node] -> [Node] -> InDegGraph %1 -> Ur [Node]
pluckSources [] postOrd dag = lseq dag (move postOrd)
pluckSources (s : ss) postOrd dag =
  case HMap.lookup s dag of
    (Ur Nothing, dag') -> pluckSources ss (s : postOrd) dag'
    (Ur (Just (nbrs, _)), dag') ->
      case walk nbrs dag' of
        (dag'', Ur newSrcs) ->
          pluckSources (newSrcs ++ ss) (s : postOrd) dag''
  where
    -- decrement degree of children, save newly made sources
    walk :: [Node] -> InDegGraph %1 -> (InDegGraph, Ur [Node])
    walk children g =
      second (Data.fmap catMaybes) (mapAccum decDegree children g)

    -- Decrement the degree of a node, save it if it is now a source
    decDegree :: Node -> InDegGraph %1 -> (InDegGraph, Ur (Maybe Node))
    decDegree node g =
      case HMap.lookup node g of
        (Ur Nothing, g') -> (g', Ur Nothing)
        (Ur (Just (nbrs, d)), g') ->
          checkSource node (HMap.insert node (nbrs, d - 1) g')

-- | Given a list of nodes, determine which of them are sources, i.e. which
-- ones 'checkSource' reports as having in-degree 0 in the graph. The graph
-- is threaded through linearly and handed back alongside the sources, which
-- keep the relative order they had in the input list.
findSources :: [Node] -> InDegGraph %1 -> (InDegGraph, Ur [Node])
findSources nodes dag =
  second (Data.fmap catMaybes) (mapAccum checkSource nodes dag)

-- | Check if a node is a source, and if so return it.
--
-- A node counts as a source exactly when it is a key of the map and its
-- stored in-degree is 0. Nodes that are absent from the map, or whose
-- in-degree is non-zero, yield 'Nothing'. The map is returned unchanged.
checkSource :: Node -> InDegGraph %1 -> (InDegGraph, Ur (Maybe Node))
checkSource node dag =
  case HMap.lookup node dag of
    (Ur (Just (_, 0)), dag') -> (dag', Ur (Just node))
    -- Either the node is unknown or it still has incoming edges.
    (Ur _, dag') -> (dag', Ur Nothing)

-- | Map over a list while threading a linear accumulator through every call.
--
-- The accumulator is passed through the tail of the list first and reaches
-- the head last, so @f@ is applied to the elements from right to left. The
-- collected results nevertheless appear in the same order as the input
-- elements.
mapAccum ::
  (a -> b %1 -> (b, Ur c)) -> [a] -> b %1 -> (b, Ur [c])
mapAccum _ [] b = (b, Ur [])
mapAccum f (x : xs) b =
  case mapAccum f xs b of
    -- Run @f@ on the head with the accumulator produced by the tail, then
    -- cons its result onto the results already gathered for the tail.
    (b', Ur cs) -> second (Data.fmap (: cs)) (f x b')