{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE ImpredicativeTypes #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE LinearTypes #-}
{-# LANGUAGE PartialTypeSignatures #-}
{-# LANGUAGE QualifiedDo #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ViewPatterns #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# OPTIONS_GHC -Wno-name-shadowing #-}
{-# OPTIONS_GHC -Wno-partial-type-signatures #-}

module Data.Unique.Linear (
  UniqueSource,
  new,
  sample,
  split,
  splitN,
  splitV,
  split3,
  split4,
  split5,
) where

import Control.Monad.Borrow.Pure.Lifetime.Token (Linearly)
import Data.Proxy (Proxy (Proxy))
import Data.V.Linear.Internal hiding (consume)
import Data.Vector qualified as V
import GHC.TypeNats (KnownNat, natVal)
import Prelude.Linear
import Unsafe.Linear qualified as Unsafe
import Prelude qualified as NL

-- | A linear supply of distinct 'Int's.
--
-- @UniqueSource x a b@ stands for the values @a * k + b@ for every index
-- @k >= x@. 'sample' returns the value at index @x@ and advances @x@ by one.
-- The split functions partition the remaining indices by residue class.
-- Sources obtained by splitting therefore never hand out the same value, as
-- long as the multiplier (scaled up on every split) does not overflow 'Int'.
data UniqueSource where
  -- | Seed @x@ (next index), multiplier @a@ (stride), constant @b@ (offset).
  UniqueSource :: !Int %1 -> !Int %1 -> !Int %1 -> UniqueSource
  -- Ordinary structural (non-linear) instances from base.
  deriving stock (Show, NL.Eq, NL.Ord)

-- | Discarding a source simply discards each of its three 'Int' fields.
instance Consumable UniqueSource where
  consume (UniqueSource seed mult offset) =
    seed `lseq` (mult `lseq` consume offset)
  {-# INLINE consume #-}

-- | Start a fresh 'UniqueSource'. Successive samples are @0, 1, 2, ...@.
--
-- The 'Linearly' token is consumed.
new :: Linearly %1 -> UniqueSource
new token = token `lseq` start
  where
    -- Seed 0, multiplier 1, constant 0: the value at index k is k itself.
    start = UniqueSource 0 1 0

-- | Draw the next 'Int' and return the advanced source.
--
-- For @UniqueSource x a b@ the result is @x * a + b@. The returned source
-- has its seed moved on to index @x + 1@; multiplier and constant are kept.
sample :: UniqueSource %1 -> (Int, UniqueSource)
sample = Unsafe.toLinear draw
  where
    -- The fields are each used more than once, so this helper is non-linear.
    -- 'Unsafe.toLinear' is sound here because the input source itself is
    -- never reused after being taken apart.
    draw (UniqueSource seed mult offset) =
      (seed * mult + offset, UniqueSource (seed + 1) mult offset)

{- | Split a 'UniqueSource' into two sources whose ranges do not overlap.

The parent's remaining indices are divided by parity, and both children get
twice the parent's stride. The first component continues from the parent's
next index, so its first sample equals what the parent would have produced
next. The second component takes the indices in between.

See also 'splitN' and 'splitV'.
-}
split :: UniqueSource %1 -> (UniqueSource, UniqueSource)
split = Unsafe.toLinear halve
  where
    halve (UniqueSource seed mult offset) =
      case seed `quotRem` 2 of
        -- Even seed: the even indices resume at @seed@, the odd ones at
        -- @seed + 1@; both live at child index @half@.
        (half, 0) ->
          ( UniqueSource half stride offset
          , UniqueSource half stride shifted
          )
        -- Odd seed: the odd indices resume at @seed@ (child index @half@),
        -- the even ones at @seed + 1@ (child index @half + 1@).
        (half, _) ->
          ( UniqueSource half stride shifted
          , UniqueSource (half + 1) stride offset
          )
      where
        stride = mult * 2
        -- Constant for the odd-index subsequence: a * (2j + 1) + b.
        shifted = mult + offset

{- | Split a 'UniqueSource' into @n@ sources whose ranges do not overlap.

The parent's remaining indices are divided by residue modulo @n@, and every
child's stride is the parent's multiplied by @n@. Element @j@ of the result
continues from the parent's index @seed + j@. Its first sample is therefore
the value the parent would have produced on its @(j + 1)@-th 'sample'.

@n@ is expected to be positive. For @n = 0@ the result is an empty vector
and the source is discarded.
-}
splitN :: Int -> UniqueSource %1 -> V.Vector UniqueSource
{-# INLINE splitN #-}
splitN n = Unsafe.toLinear \(UniqueSource seed mult offset) ->
  let (lap, phase) = seed `quotRem` n
      stride = mult * n
      -- Child j resumes at parent index seed + j = n * (lap + carry) + residue,
      -- i.e. child index (lap + carry) within residue class `residue`.
      child j =
        let (carry, residue) = (j + phase) `quotRem` n
            !start = lap + carry
         in UniqueSource start stride (offset + mult * residue)
   in V.generate n child

-- | Like 'splitN', but the number of pieces @n@ is taken from the result
-- type. The result is @n@ non-overlapping sources in a length-indexed
-- vector.
splitV :: forall n. (KnownNat n) => UniqueSource %1 -> V n UniqueSource
{-# INLINE splitV #-}
splitV = V . splitN width
  where
    -- Reflect the type-level length down to a runtime 'Int'.
    width = fromIntegral (natVal (Proxy :: Proxy n))

-- | Split a 'UniqueSource' into three non-overlapping sources (via 'splitV').
--
-- The argument's multiplicity @m@ is polymorphic. With @m@ linear it can
-- consume a linearly owned source, like 'split' and 'splitV'. It can still be
-- used wherever an ordinary (unrestricted) function is expected.
split3 :: UniqueSource %m -> (UniqueSource, UniqueSource, UniqueSource)
split3 = elim (,,) . splitV

-- | Split a 'UniqueSource' into four non-overlapping sources (via 'splitV').
--
-- The argument's multiplicity @m@ is polymorphic. With @m@ linear it can
-- consume a linearly owned source, like 'split' and 'splitV'. It can still be
-- used wherever an ordinary (unrestricted) function is expected.
split4 :: UniqueSource %m -> (UniqueSource, UniqueSource, UniqueSource, UniqueSource)
split4 = elim (,,,) . splitV

-- | Split a 'UniqueSource' into five non-overlapping sources (via 'splitV').
--
-- The argument's multiplicity @m@ is polymorphic. With @m@ linear it can
-- consume a linearly owned source, like 'split' and 'splitV'. It can still be
-- used wherever an ordinary (unrestricted) function is expected.
split5 ::
  UniqueSource %m ->
  (UniqueSource, UniqueSource, UniqueSource, UniqueSource, UniqueSource)
split5 = elim (,,,,) . splitV