-- SPDX-FileCopyrightText: 2025 Alex Ionescu
-- SPDX-License-Identifier: MPL-2.0

{-# LANGUAGE AllowAmbiguousTypes   #-}
{-# LANGUAGE BlockArguments        #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE ImpredicativeTypes    #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE TypeFamilies          #-}

-- | This module defines multiply-located values and combinators for manipulating them.
module Choreography.Location.Multi where

import GHC.TypeLits(KnownSymbol)
import Data.Proxy(Proxy(..))
import Data.SOP.BasicFunctors(K(..), mapKK, unK)
import Data.SOP.Classes(HSequence(..), hmap)
import Data.SOP.Constraint(All)
import Data.SOP.NP(NP(..), hd, tl)

import Choreography.Choreo
import Choreography.Location

-- | The type of multiply-located values.
--
-- An `a @@ ls` is an `NP` product indexed by the type-level list `ls`; its
-- component for each location `l` in `ls` is a singly-located value `a @ l`.
type a @@ ls = NP ((@) a) ls

-- | "scatter": Send a located value from a single location to many others.
--
-- Takes the sender's proxy paired with the value located there, and the list
-- of receiver proxies. The value is sent to each receiver in turn with `~>`,
-- and the received copies are collected into one multiply-located value.
(~>*) :: (KnownSymbol l, All KnownSymbol ls, Show a, Read a) => (Proxy l, a @ l) -> NP Proxy ls -> Choreo m (a @@ ls)
(l, a) ~>* ls =
  -- One point-to-point send per receiver; the `KnownSymbol` evidence for each
  -- receiver is supplied by the constrained traversal.
  hctraverse' (Proxy @KnownSymbol) ((l, a) ~>) ls

infix 4 ~>*

-- | "gather": Send a multiply-located value from many locations to one.
--
-- Every component of the multiply-located value is sent to the receiver @l@
-- with `~>`. The receiver then unwraps all of the received values in a single
-- `locally` step, yielding a product of plain values (one @K a@ per sender)
-- located at @l@.
(*~>) :: (KnownSymbol l, All KnownSymbol ls, Applicative m, Show a, Read a) => a @@ ls -> Proxy l -> Choreo m (NP (K a) ls @ l)
as *~> l = do
  -- Send each component to @l@; `K` records the sender index while the
  -- payload is now located at @l@.
  as' <- hctraverse' (Proxy @KnownSymbol) (\a -> K <$> ((Proxy, a) ~> l)) as
  -- At @l@, strip the location wrapper from every received value.
  locally l \unwrap -> pure $ hmap (mapKK unwrap) as'

infix 4 *~>

-- | Send a list of values from one location to many others "pointwise" (i.e. each target gets one value).
--
-- The sender @l@ holds a product with one value per receiver. For each
-- receiver, in list order, the sender extracts the head of the remaining
-- product and the tail in two separate `locally` steps, sends the head to
-- that receiver with `~>`, and recurses on the tail for the rest.
(~>.) :: (KnownSymbol l, All KnownSymbol ls, Applicative m, Show a, Read a) => (Proxy l, NP (K a) ls @ l) -> NP Proxy ls -> Choreo m (a @@ ls)
_       ~>. Nil = pure Nil
(l, as) ~>. (l' :* ls') = do
  -- Value destined for the first remaining receiver.
  a <- locally l \unwrap -> pure $ unK $ hd $ unwrap as
  -- Values destined for the remaining receivers, still located at @l@.
  as' <- locally l \unwrap -> pure $ tl $ unwrap as
  (:*) <$> ((l, a) ~> l') <*> ((l, as') ~>. ls')

infix 4 ~>.

-- | Multiply-located version of `Unwrap`.
--
-- A function that, for any value type, extracts a plain value from a value
-- located at all of `ls`.
type Unwraps ls = forall a. a @@ ls -> a

-- | The type of a (multi)local computation.
--
-- The computation is polymorphic in the location `l` it runs at: it receives
-- a proxy for that location together with an `Unwraps` for `ls`, and produces
-- its result in the local monad `m`.
type LocalComp ls m a = forall l. KnownSymbol l => Proxy l -> Unwraps ls -> m a

-- | Multiply-located version of `locally`.
--
-- Runs the local computation @f@ once at every location in @ls@, in list
-- order and with one `locally` step per location, and collects the results
-- into a multiply-located value. The `Unwraps` handed to @f@ at a given
-- location projects out that location's own component of its argument and
-- unwraps it with that location's `Unwrap`.
multilocally :: forall ls m a. All KnownSymbol ls => NP Proxy ls -> LocalComp ls m a -> Choreo m (a @@ ls)
multilocally ls f = go ls id
  where
    -- Walks the list of locations. @pick@ narrows a value located at the full
    -- list @ls@ down to the suffix @ls'@ still to be visited, so its head is
    -- always the component belonging to the current location.
    go :: forall ls'. All KnownSymbol ls' => NP Proxy ls' -> (forall a. a @@ ls -> a @@ ls') -> Choreo m (a @@ ls')
    go Nil _pick = pure Nil
    go (l :* rest) pick = do
      a <- locally l \unwrap -> f l $ unwrap . hd . pick
      (a :*) <$> go rest (tl . pick)

-- | A variant of `~>*` that sends the result of a local computation.
--
-- First runs the computation @f@ at the sender @l@ with `locally`, then
-- scatters the resulting located value to every receiver in @ls@ with `~>*`.
(~~>*) :: (KnownSymbol l, All KnownSymbol ls, Show a, Read a) => (Proxy l, Unwrap l -> m a) -> NP Proxy ls -> Choreo m (a @@ ls)
(l, f) ~~>* ls = do
  a <- locally l f
  (l, a) ~>* ls

infix 4 ~~>*

-- | A variant of `*~>` that sends the result of a local computation.
--
-- First runs the computation @f@ at every sender in @ls@ with `multilocally`,
-- then gathers the resulting multiply-located value at the receiver @l@ with
-- `*~>`.
(*~~>) :: (KnownSymbol l, All KnownSymbol ls, Applicative m, Show a, Read a) => (NP Proxy ls, LocalComp ls m a) -> Proxy l -> Choreo m (NP (K a) ls @ l)
(ls, f) *~~> l = do
  a <- multilocally ls f
  a *~> l

infix 4 *~~>