{-# language ConstraintKinds #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.SRTree.Random 
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  ConstraintKinds
--
-- Functions to generate random trees and nodes.
--
-----------------------------------------------------------------------------
module Data.SRTree.Random 
         ( HasVars
         , HasVals
         , HasFuns
         , HasEverything
         , FullParams(..)
         , RndTree
         , Rng(..)
         , randomVar
         , randomConst
         , randomPow
         , randomFunction
         , randomNode
         , randomNonTerminal
         , randomRange
         , randomTreeTemplate
         , randomTree
         , randomTreeBalanced
         , toss
         , tossBiased
         , randomVal
         , randomVec
         , randomFrom
         )
         where

import Control.Monad.Reader (ReaderT, asks, runReaderT)
import Control.Monad.State.Strict ( MonadState(state), MonadTrans(lift), StateT )
import Data.Maybe (fromJust)
import Data.SRTree.Internal
import System.Random (Random (random, randomR), StdGen, mkStdGen)
import Data.Massiv.Array as MA hiding (forM_, forM, P)
import Data.SRTree.Eval
import Control.Monad


-- * Parameter capabilities
--
-- Each class below marks one piece of information that a parameter type
-- @p@ can supply to the random generators of this module.

-- | Parameter types that provide the variable indices to sample from.
class HasVars p where
  _vars :: p -> [Int]

-- | Parameter types that provide the (lower, upper) range of constant values.
class HasVals p where
  _range :: p -> (Double, Double)

-- | Parameter types that provide the (lower, upper) range of integral exponents.
class HasExps p where
  _exponents :: p -> (Int, Int)

-- | Parameter types that provide the functions allowed in unary nodes.
class HasFuns p where
  _funs :: p -> [Function]

-- | Constraint synonym bundling every parameter capability: variable
-- indices, constant range, exponent range and allowed functions.
type HasEverything p =
  ( HasVars p
  , HasVals p
  , HasExps p
  , HasFuns p
  )

-- | Parameter structure that supports every capability.
data FullParams
  = P
      [Int]            -- ^ variable indices, returned by '_vars'
      (Double, Double) -- ^ range of constants, returned by '_range'
      (Int, Int)       -- ^ range of integral exponents, returned by '_exponents'
      [Function]       -- ^ allowed functions, returned by '_funs'

instance HasVars FullParams where
  -- the variable indices are the first field of 'P'
  _vars params = case params of
    P ixs _ _ _ -> ixs
instance HasVals FullParams where
  -- the constant range is the second field of 'P'
  _range params = case params of
    P _ bounds _ _ -> bounds
instance HasExps FullParams where
  -- the exponent range is the third field of 'P'
  _exponents params = case params of
    P _ _ expBounds _ -> expBounds
instance HasFuns FullParams where
  -- the allowed functions are the fourth field of 'P'
  _funs params = case params of
    P _ _ _ functions -> functions

-- | Random-number monad: a 'StateT' that threads a 'StdGen' over the base
-- monad @m@.
--
-- The synonym is eta-reduced so that @Rng m@ can also be used unsaturated
-- (for example as the inner monad of another transformer); a saturated
-- @Rng m a@ expands to @StateT StdGen m a@.
type Rng m = StateT StdGen m

-- | Samples a 'Bool' (between 'False' and 'True') with 'random',
-- advancing the generator.
toss :: Monad m => Rng m Bool
toss = state random
{-# INLINE toss #-}

-- | Biased coin: samples a 'Double' with 'random' and returns whether it is
-- strictly below @p@. Assuming 'random' draws uniformly from [0, 1) for
-- 'Double', this yields 'True' with probability @p@.
tossBiased :: Monad m => Double -> Rng m Bool
tossBiased p = (< p) <$> state random

-- | Samples a 'Double' with its 'Random' instance ('random'), advancing the
-- generator.
randomVal :: Monad m => Rng m Double
randomVal = state random

-- | Returns a random element of a list.
--
-- An index is drawn with 'randomRange' over @(0, length xs - 1)@ and the
-- element at that index is returned. The list must be non-empty: for an
-- empty list the returned element is an 'error' with a descriptive
-- message, raised when that element is forced.
randomFrom :: Monad m => [a] -> Rng m a
randomFrom funs = do
  n <- randomRange (0, length funs - 1)
  pure $ case funs of
    -- without this branch an empty list fails deep inside '!!' with an
    -- unhelpful "index too large" / "negative index" message
    [] -> error "Data.SRTree.Random.randomFrom: empty list"
    _  -> funs !! n
{-# INLINE randomFrom #-}

-- | Samples a value inside the given (lower, upper) range with 'randomR',
-- advancing the generator.
randomRange :: (Ord val, Random val, Monad m) => (val, val) -> Rng m val
randomRange = state . randomR
{-# INLINE randomRange #-}

-- | Replaces the child of a unary ('Uni') node with the given tree.
-- Returns 'Nothing' when the first argument is not a 'Uni' node.
replaceChild :: Fix SRTree -> Fix SRTree -> Maybe (Fix SRTree)
replaceChild parent child =
  case parent of
    Fix (Uni fun _) -> Just (Fix (Uni fun child))
    _               -> Nothing
{-# INLINE replaceChild #-}

-- | Replaces both children of a binary ('Bin') node with the given left and
-- right trees. Returns 'Nothing' when the first argument is not a 'Bin' node.
replaceFixChildren :: Fix SRTree -> Fix SRTree -> Fix SRTree -> Maybe (Fix SRTree)
replaceFixChildren parent left right =
  case parent of
    Fix (Bin op _ _) -> Just (Fix (Bin op left right))
    _                -> Nothing
{-# INLINE replaceFixChildren #-}

-- | 'RndTree' is the monad stack used by the tree generators. It reads the
-- generator parameters @p@ (see 'HasVars', 'HasVals', 'HasFuns' and the
-- other capability classes), threads a 'StdGen' through 'StateT' over the
-- base monad @m@, and produces a @'Fix' 'SRTree'@.
type RndTree m p = ReaderT p (StateT StdGen m) (Fix SRTree)

-- | Returns a variable node whose index is picked at random from '_vars'.
-- The parameter @p@ must have the 'HasVars' property.
randomVar :: Monad m => HasVars p => RndTree m p
randomVar = do
  ixs <- asks _vars
  ix  <- lift (randomFrom ixs)
  pure (Fix (Var ix))

-- | Returns a constant node whose value is sampled with 'randomRange' from
-- the '_range' of the parameters. The parameter @p@ must have the
-- 'HasVals' property.
randomConst :: (HasVals p, Monad m) => RndTree m p
randomConst = do
  bounds <- asks _range
  lift (Fix . Const <$> randomRange bounds)

-- | Returns a power node (@'Bin' 'Power'@) with a placeholder @0@ base and an
-- integer exponent, sampled from '_exponents' and stored as a 'Const'.
-- The parameter @p@ must have the 'HasExps' property.
--
-- NOTE(review): this is a binary node, so callers that refill both children
-- (e.g. through 'replaceFixChildren') also overwrite the sampled exponent —
-- confirm this is intended.
randomPow :: (HasExps p, Monad m) => RndTree m p
randomPow = do
  bounds <- asks _exponents
  k      <- lift (randomRange bounds)
  pure . Fix $ Bin Power 0 (Fix (Const (fromIntegral k)))

-- | Returns a unary node ('Uni') that applies a function picked at random
-- from '_funs' to a placeholder child @0@. The parameter @p@ must have the
-- 'HasFuns' property.
randomFunction :: (HasFuns p, Monad m) => RndTree m p
randomFunction = do
  fs  <- asks _funs
  fun <- lift (randomFrom fs)
  pure (Fix (Uni fun 0))

-- | Returns a random node, the parameter @p@ must have every property.
--
-- One of nine generators is picked by a random index: a variable, a
-- constant, a unary function, an integer power, or a binary 'Add', 'Sub',
-- 'Mul', 'Div' or 'Power' node whose children are placeholder @0@s.
randomNode :: (HasEverything p, Monad m) => RndTree m p
randomNode = join . lift $ randomFrom
  [ randomVar
  , randomConst
  , randomFunction
  , randomPow
  , binary Add
  , binary Sub
  , binary Mul
  , binary Div
  , binary Power
  ]
  where
    -- binary operator node whose children are still placeholders
    binary op = pure (Fix (Bin op 0 0))

-- | Returns a random non-terminal node, the parameter @p@ must have every
-- property.
--
-- One of seven generators is picked by a random index: a unary function,
-- an integer power, or a binary 'Add', 'Sub', 'Mul', 'Div' or 'Power' node
-- whose children are placeholder @0@s.
randomNonTerminal :: (HasEverything p, Monad m) => RndTree m p
randomNonTerminal = join . lift $ randomFrom
  [ randomFunction
  , randomPow
  , binary Add
  , binary Sub
  , binary Mul
  , binary Div
  , binary Power
  ]
  where
    -- binary operator node whose children are still placeholders
    binary op = pure (Fix (Bin op 0 0))
    
-- | Returns a random tree with a limited budget, the parameter @p@ must have
-- every property.
--
-- With a budget of zero or less a leaf is returned: a variable or a
-- constant, chosen by a coin toss. Otherwise a node is drawn with
-- 'randomNode' and its children are generated recursively. A unary node
-- passes @budget - 1@ to its child, and a binary node gives
-- @budget \`div\` 2@ to each child. Since 'randomNode' can also return
-- leaves, generation may stop before the budget is used up.
--
-- >>> let treeGen = runReaderT (randomTreeTemplate 12) (P [0,1] (-10, 10) (2, 3) [Log, Exp])
-- >>> tree <- evalStateT treeGen (mkStdGen 52)
-- >>> showExpr tree
-- "(-2.7631152121655838 / Exp((x0 / ((x0 * -7.681722660704317) - Log(3.378309080134594)))))"
randomTreeTemplate :: (HasEverything p, Monad m) => Int -> RndTree m p
randomTreeTemplate budget
  -- '<= 0' rather than matching 0 only: a negative budget never reaches 0
  -- (-1 `div` 2 == -1), so the recursion would only stop when leaves
  -- happen to be drawn.
  | budget <= 0 = do
      coin <- lift toss
      if coin then randomVar else randomConst
  | otherwise = do
      node <- randomNode
      -- the replace* helpers only fail for a node of the wrong shape, which
      -- the arity check below rules out
      fromJust <$> case arity node of
        0 -> pure (Just node)
        1 -> replaceChild node <$> randomTreeTemplate (budget - 1)
        2 -> replaceFixChildren node
               <$> randomTreeTemplate (budget `div` 2)
               <*> randomTreeTemplate (budget `div` 2)
    
-- | Returns a random tree with approximately @n@ nodes, the parameter @p@
-- must have every property.
--
-- For @n <= 1@ a leaf is returned: a variable or a constant, chosen by a
-- coin toss. Otherwise the root is always a non-terminal from
-- 'randomNonTerminal'. A unary root passes @n - 1@ to its child, and a
-- binary root gives @n \`div\` 2@ to each child.
--
-- >>> let treeGen = runReaderT (randomTreeBalanced 10) (P [0,1] (-10, 10) (2, 3) [Log, Exp])
-- >>> tree <- evalStateT treeGen (mkStdGen 42)
-- >>> showExpr tree
-- "Exp(Log((((7.784360517385774 * x0) - (3.6412224491658223 ^ x1)) ^ ((x0 ^ -4.09764995657091) + Log(-7.710216839988497)))))"
randomTreeBalanced :: (HasEverything p, Monad m) => Int -> RndTree m p
randomTreeBalanced n
  | n <= 1 = lift toss >>= \coin -> if coin then randomVar else randomConst
  | otherwise = do
      root <- randomNonTerminal
      let half = n `div` 2
      -- 'randomNonTerminal' only yields unary or binary nodes
      fmap fromJust $ case arity root of
        1 -> replaceChild root <$> randomTreeBalanced (n - 1)
        2 -> replaceFixChildren root
               <$> randomTreeBalanced half
               <*> randomTreeBalanced half


-- | Builds a 'PVector' of @n@ values, each sampled with 'randomRange' from
-- the range (-1, 1), using the computation mode 'compMode'.
randomVec :: Monad m => Int -> Rng m PVector
randomVec n = do
  xs <- replicateM n (randomRange (-1, 1))
  pure (MA.fromList compMode xs)

-- | Generates a random tree bounded by depth and size. Terminals come from
-- the caller-supplied generator @genTerm@, non-terminals from @genNonTerm@.
--
-- * When @maxDepth <= 1@ or @maxSize <= 2@, a terminal is produced.
-- * Otherwise a non-terminal is forced when @minDepth >= 0@, or when
--   @grow@ is 'False' and @maxDepth > 2@.
-- * In the remaining cases a coin toss chooses between a terminal and a
--   non-terminal.
--
-- Each recursive call lowers @minDepth@ and @maxDepth@ by one.
--
-- * A unary node gives its child @maxSize - 1@.
-- * A binary node gives its left child @maxSize - 2@ when @grow@ is
--   'True', or @maxSize \`div\` 2@ otherwise.
-- * The right child gets what remains: @maxSize - 1 - countNodes l@.
--
-- @genNonTerm@ must return a 'Uni' or 'Bin' node. Its children are ignored
-- and regenerated here.
--
-- NOTE(review): @minDepth >= 0@ still forces a non-terminal when
-- @minDepth == 0@; confirm whether @> 0@ was intended.
randomTree :: Monad m => Int -> Int -> Int -> Rng m (Fix SRTree) -> Rng m (SRTree ()) -> Bool -> Rng  m (Fix SRTree)
randomTree minDepth maxDepth maxSize genTerm genNonTerm grow
  | noSpaceLeft = genTerm
  | needNonTerm = genRecursion
  | otherwise   = toss >>= \stop -> if stop then genTerm else genRecursion
  where
    noSpaceLeft = maxDepth <= 1 || maxSize <= 2
    needNonTerm = minDepth >= 0 || (maxDepth > 2 && not grow) -- && maxSize > 2

    -- recurse one level deeper with the given size budget
    subtree size = randomTree (minDepth - 1) (maxDepth - 1) size genTerm genNonTerm grow

    genRecursion = do
      node <- genNonTerm
      case node of
        Uni f _ -> Fix . Uni f <$> subtree (maxSize - 1)
        Bin op _ _ -> do
          l <- subtree (if grow then maxSize - 2 else maxSize `div` 2)
          r <- subtree (maxSize - 1 - countNodes l)
          pure (Fix (Bin op l r))
{-# INLINE randomTree #-}