{-# language ImportQualifiedPost #-}
{-# language ViewPatterns #-}
{-# language OverloadedStrings #-}
-----------------------------------------------------------------------------
-- |
-- Module      :  Data.SRTree.Datasets
-- Copyright   :  (c) Fabricio Olivetti 2021 - 2024
-- License     :  BSD3
-- Maintainer  :  fabricio.olivetti@gmail.com
-- Stability   :  experimental
-- Portability :  FlexibleInstances, DeriveFunctor, ScopedTypeVariables, ConstraintKinds
--
-- Utility library to handle regression datasets.
-- This module exports `loadDataset`, `loadTrainingOnly`, `getX`, `splitData`
-- and the `DataSet` type.
--
-----------------------------------------------------------------------------
module Data.SRTree.Datasets ( loadDataset, loadTrainingOnly, getX, splitData, DataSet(..) )
    where

import Codec.Compression.GZip (decompress)
import Data.ByteString.Char8 qualified as B
import Data.ByteString.Lazy qualified as BS
import Data.List (delete, find, intercalate)
import Data.Massiv.Array
  ( Array,
    Comp (Seq, Par),
    Ix2 ((:.)),
    S (..),
    Sz (Sz1),
    (<!),
  )
import Data.Massiv.Array qualified as M
import Data.Maybe (fromJust)
import Data.SRTree.Eval (PVector, SRMatrix, compMode)
import Data.Vector qualified as V
import System.FilePath (takeExtension)
import Text.Read (readMaybe)
import Data.Massiv.Array as MA hiding (forM_, forM, map, take, tail, zip, replicate, all, read)
import Control.Monad.State.Strict
import System.Random
import List.Shuffle ( shuffle )


-- | A dataset is a triple @(X, y, y_error)@: the input matrix, the target
-- vector and an optional vector for the target error.
type DataSet = (SRMatrix, PVector, Maybe PVector)

-- | Converts a list of rows, each a list of `ByteString` cells, into a
-- matrix of `Double`.
--
-- Every cell must be a valid `Double` literal. A cell that cannot be parsed
-- aborts with an error naming the offending cell, rather than the opaque
-- @Prelude.read: no parse@.
loadMtx :: [[B.ByteString]] -> Array S Ix2 Double
loadMtx = M.fromLists' compMode . map (map parseCell)
  where
    parseCell :: B.ByteString -> Double
    parseCell cell =
      case readMaybe (B.unpack cell) of
        Just v  -> v
        Nothing -> error $ "CSV parsing error: could not parse "
                         <> show (B.unpack cell) <> " as a number."
{-# INLINE loadMtx #-}

-- | Returns `True` when the extension of the file name is exactly @.gz@.
isGZip :: FilePath -> Bool
isGZip path = takeExtension path == ".gz"
{-# INLINE isGZip #-}

-- | Detects the separator automatically by checking, for each supported
--   separator in turn, whether splitting every (stripped) row with it
--   yields the same number of columns in every row and at least two columns.
--   The first separator that satisfies this is returned.
--
--   Raises an error if no supported separator fits. An empty list of rows
--   trivially satisfies the check, so the first candidate (a space) is
--   returned in that case.
--
--  >>> detectSep ["x1,x2,x3,x4"] 
-- ','
detectSep :: [B.ByteString] -> Char
detectSep rows =
  case find splitsConsistently candidates of
    Just sep -> sep
    Nothing  -> error $ "CSV parsing error: unsupported separator. Supported separators are "
                      <> intercalate "," (map show candidates)
  where
    -- candidate separators, in order of preference
    candidates :: [Char]
    candidates = [' ', '\t', '|', ':', ';', ',']

    -- surrounding whitespace must not count as an extra (empty) column
    stripped :: [B.ByteString]
    stripped = map B.strip rows

    -- consistency check: all rows have the same number of columns when
    -- split by this separator, and that number is not 1 (a count of 1
    -- means the separator does not occur in the first row)
    splitsConsistently :: Char -> Bool
    splitsConsistently sep =
      case map (length . B.split sep) stripped of
        []       -> True
        (n : ns) -> n /= 1 && all (== n) ns
{-# INLINE detectSep #-}

-- | Reads a file and returns a list of list of `ByteString`
-- corresponding to each element of the matrix.
-- The first row can be a header.
--
-- Files whose extension is @.gz@ are decompressed first. The separator is
-- detected from the first 100 non-empty rows. Rows are stripped of
-- surrounding whitespace before being split, which is the same
-- normalisation `detectSep` applies: a trailing @\\r@ or separator-like
-- whitespace at the edges therefore cannot create phantom empty fields, and
-- whitespace-only lines are discarded together with the empty ones.
readFileToLines :: FilePath -> IO [[B.ByteString]]
readFileToLines filename = do
  rows <- toRows . BS.toStrict . gunzip <$> BS.readFile filename
  -- use only first 100 rows to detect separator
  let sep = detectSep (take 100 rows)
  -- every row is non-empty here, so splitting never produces an empty list
  pure $ map (B.split sep) rows
  where
    -- decompress only when the file name says the content is gzipped
    gunzip :: BS.ByteString -> BS.ByteString
    gunzip = if isGZip filename then decompress else id

    -- split into lines, normalise the edges and drop blank lines
    toRows :: B.ByteString -> [B.ByteString]
    toRows = filter (not . B.null) . map B.strip . B.split '\n'
{-# INLINE readFileToLines #-}

-- | Splits the parameters from the filename.
-- The expected format is *filename.ext:p1:p2:p3:p4:p5* where
-- p1 and p2 are the starting and end rows for the training data,
-- p3 is the target column (a header name or an index),
-- p4 is a comma separated list of columns (either index or name) to be used
-- as input variables, and p5 is the column holding the target error.
--
-- The result always carries exactly six parameters: missing ones are filled
-- with empty strings and surplus ones are dropped.
--
-- The file name part is returned unchanged (it is never converted to a
-- `ByteString`, so characters outside Latin-1 are preserved), and an input
-- without any file name yields an empty file name instead of a pattern
-- match failure.
splitFileNameParams :: FilePath -> (FilePath, [B.ByteString])
splitFileNameParams filename = (fname, take 6 (given <> replicate 6 B.empty))
  where
    -- everything before the first ':' is the file name
    (fname, rest) = break (== ':') filename

    given :: [B.ByteString]
    given = map B.pack (fields rest)

    -- splits a ":p1:p2:..." suffix into ["p1", "p2", ...],
    -- keeping empty fields so that positions are preserved
    fields :: String -> [String]
    fields []       = []
    fields (_ : cs) = let (p, cs') = break (== ':') cs
                       in p : fields cs'
{-# inline splitFileNameParams #-}

-- | Tries to parse a string into an `Int`.
-- Returns @Right@ with the number on success and @Left@ with the
-- original string otherwise.
parseVal :: String -> Either String Int
parseVal xs = maybe (Left xs) Right (readMaybe xs)
{-# inline parseVal #-}

-- | Given a map between column names and indices, the target column,
-- the comma separated list of input columns and the target error column,
-- returns the indices of the input columns, of the target and of the
-- target error.
--
-- Each column can be given either by name or by index. An unknown name or
-- an index outside @[0, number of columns)@ raises an error.
-- The index of the target error is @-1@ when that column is omitted.
getColumns :: [(B.ByteString, Int)] -> B.ByteString -> B.ByteString -> B.ByteString -> ([Int], Int, Int)
getColumns headerMap target columns target_error = (ixs, iy, iy_error)
  where
    n_cols :: Int
    n_cols = length headerMap

    getIx :: String -> Int
    getIx c =
      case parseVal c of
        -- if the column is a name, retrieve the index
        Left name ->
          case find ((== B.pack name) . fst) headerMap of
            Nothing     -> error $ "Column name " <> name <> " does not exist."
            Just (_, i) -> i
        -- if it is an int, check if it is within range
        Right v
          | v >= 0 && v < n_cols -> v
          | otherwise            -> error $ "Column index " <> show v <> " out of range."

    -- if the input columns are omitted, use every column except for the
    -- target and, when one was given, the target error column (it describes
    -- the target and must not be fed to the model as a feature)
    ixs :: [Int]
    ixs
      | B.null columns = filter (\i -> i /= iy && i /= iy_error) [0 .. n_cols - 1]
      | otherwise      = map (getIx . B.unpack) (B.split ',' columns)

    -- if the target column is omitted, use the last one
    iy :: Int
    iy
      | B.null target = n_cols - 1
      | otherwise     = getIx (B.unpack target)

    -- if the target error column is omitted, signal it with -1
    iy_error :: Int
    iy_error
      | B.null target_error = -1
      | otherwise           = getIx (B.unpack target_error)
{-# inline getColumns #-}

-- | Given the textual start and end rows and the number of rows, returns
-- the pair @(start, end)@ of row indices delimiting the training data.
--
-- An empty start defaults to @0@ and an empty end to @nRows - 1@.
-- Raises an error when a value is not an integer, lies outside
-- @[0, nRows)@, or when the start is not strictly smaller than the end.
getRows :: B.ByteString -> B.ByteString -> Int -> (Int, Int)
getRows start end nRows
  | st_ix >= end_ix = error $ "Invalid range: " <> show startStr <> ":" <> show endStr <> "."
  | otherwise       = (st_ix, end_ix)
  where
    startStr, endStr :: String
    startStr = B.unpack start
    endStr   = B.unpack end

    st_ix, end_ix :: Int
    st_ix  = rowIndex "starting" 0 startStr
    end_ix = rowIndex "end" (nRows - 1) endStr

    -- parses one bound: falls back to the default when the text is empty,
    -- otherwise requires an integer within [0, nRows)
    rowIndex :: String -> Int -> String -> Int
    rowIndex label def str
      | null str  = def
      | otherwise =
          case readMaybe str of
            Nothing -> error $ "Invalid " <> label <> " row " <> str <> "."
            Just x
              | x < 0 || x >= nRows -> error $ "Invalid " <> label <> " row " <> show x <> "."
              | otherwise           -> x
{-# inline getRows #-}

-- | `loadDataset` loads a dataset with a filename in the format:
--   filename.ext:start_row:end_row:target:features:y_err
--
--   It returns
--   @((X_train, y_train, X_val, y_val), (y_err_train, y_err_val), varnames, target name)@
--   where varnames is a comma separated list of the names of the input
--   variables and target name is the name of the target.
--
-- where
--
-- **start_row:end_row** is the range of the training rows (default 0:nrows-1).
--   every other row not included in this range will be used as validation
-- **target** is either the name of the column (if the datafile has headers) or the index
-- of the target variable
-- **features** is a comma separated list of column names or indices to be used as
-- input variables of the regression model.
-- **y_err** is either the name or the index of the column holding the
-- error of the target; when omitted both error vectors are `Nothing`.
--
-- The second argument tells whether the first row of the file is a header.
loadDataset :: FilePath -> Bool -> IO ((SRMatrix, PVector, SRMatrix, PVector), (Maybe PVector, Maybe PVector), String, String)
loadDataset filename hasHeader = do
  rows <- readFileToLines fname
  pure (processData rows params hasHeader)
  where
    -- separate the real file name from the ':'-delimited parameters
    (fname, params) = splitFileNameParams filename

-- support function that does everything for loadDataset
processData :: [[B.ByteString]] -> [B.ByteString] -> Bool -> ((SRMatrix, PVector, SRMatrix, PVector), (Maybe PVector, Maybe PVector), String, String)
processData :: [[ByteString]]
-> [ByteString]
-> Bool
-> ((Array S Ix2 Double, PVector, Array S Ix2 Double, PVector),
    (Maybe PVector, Maybe PVector), [Char], [Char])
processData [[ByteString]]
csv [ByteString]
params Bool
hasHeader = ((Array S Ix2 Double
x_train, PVector
y_train, Array S Ix2 Double
x_val, PVector
y_val) , (Maybe PVector
y_err_train, Maybe PVector
y_err_val), [Char]
varnames, [Char]
targetname)
  where
    ncols :: Ix1
ncols             = [ByteString] -> Ix1
forall a. [a] -> Ix1
forall (t :: * -> *) a. Foldable t => t a -> Ix1
length ([ByteString] -> Ix1) -> [ByteString] -> Ix1
forall a b. (a -> b) -> a -> b
$ [[ByteString]] -> [ByteString]
forall a. HasCallStack => [a] -> a
head [[ByteString]]
csv
    nrows :: Ix1
nrows             = [[ByteString]] -> Ix1
forall a. [a] -> Ix1
forall (t :: * -> *) a. Foldable t => t a -> Ix1
length [[ByteString]]
csv Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
- Bool -> Ix1
forall a. Enum a => a -> Ix1
fromEnum Bool
hasHeader
    ([(ByteString, Ix1)]
header, [[ByteString]]
content) = if Bool
hasHeader
                           then ([ByteString] -> [Ix1] -> [(ByteString, Ix1)]
forall a b. [a] -> [b] -> [(a, b)]
zip ((ByteString -> ByteString) -> [ByteString] -> [ByteString]
forall a b. (a -> b) -> [a] -> [b]
map ByteString -> ByteString
B.strip ([ByteString] -> [ByteString]) -> [ByteString] -> [ByteString]
forall a b. (a -> b) -> a -> b
$ [[ByteString]] -> [ByteString]
forall a. HasCallStack => [a] -> a
head [[ByteString]]
csv) [Ix1
0..], [[ByteString]] -> [[ByteString]]
forall a. HasCallStack => [a] -> [a]
tail [[ByteString]]
csv)
                           else ((Ix1 -> (ByteString, Ix1)) -> [Ix1] -> [(ByteString, Ix1)]
forall a b. (a -> b) -> [a] -> [b]
map (\Ix1
i -> ([Char] -> ByteString
B.pack (Char
'x' Char -> [Char] -> [Char]
forall a. a -> [a] -> [a]
: Ix1 -> [Char]
forall a. Show a => a -> [Char]
show Ix1
i), Ix1
i)) [Ix1
0 .. Ix1
ncolsIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
-Ix1
1], [[ByteString]]
csv)
    varnames :: [Char]
varnames          = [Char] -> [[Char]] -> [Char]
forall a. [a] -> [[a]] -> [a]
intercalate [Char]
"," [ByteString -> [Char]
B.unpack ByteString
v | Ix1
c <- [Ix1]
ixs
                                        , let v :: ByteString
v = (ByteString, Ix1) -> ByteString
forall a b. (a, b) -> a
fst ((ByteString, Ix1) -> ByteString)
-> (Maybe (ByteString, Ix1) -> (ByteString, Ix1))
-> Maybe (ByteString, Ix1)
-> ByteString
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Maybe (ByteString, Ix1) -> (ByteString, Ix1)
forall a. HasCallStack => Maybe a -> a
fromJust (Maybe (ByteString, Ix1) -> ByteString)
-> Maybe (ByteString, Ix1) -> ByteString
forall a b. (a -> b) -> a -> b
$ ((ByteString, Ix1) -> Bool)
-> [(ByteString, Ix1)] -> Maybe (ByteString, Ix1)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((Ix1 -> Ix1 -> Bool
forall a. Eq a => a -> a -> Bool
==Ix1
c)(Ix1 -> Bool)
-> ((ByteString, Ix1) -> Ix1) -> (ByteString, Ix1) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(ByteString, Ix1) -> Ix1
forall a b. (a, b) -> b
snd) [(ByteString, Ix1)]
header
                                        ]
    targetname :: [Char]
targetname        = if Bool
hasHeader then (ByteString -> [Char]
B.unpack (ByteString -> [Char])
-> ([(ByteString, Ix1)] -> ByteString)
-> [(ByteString, Ix1)]
-> [Char]
forall b c a. (b -> c) -> (a -> b) -> a -> c
. (ByteString, Ix1) -> ByteString
forall a b. (a, b) -> a
fst ((ByteString, Ix1) -> ByteString)
-> ([(ByteString, Ix1)] -> (ByteString, Ix1))
-> [(ByteString, Ix1)]
-> ByteString
forall b c a. (b -> c) -> (a -> b) -> a -> c
. Maybe (ByteString, Ix1) -> (ByteString, Ix1)
forall a. HasCallStack => Maybe a -> a
fromJust (Maybe (ByteString, Ix1) -> (ByteString, Ix1))
-> ([(ByteString, Ix1)] -> Maybe (ByteString, Ix1))
-> [(ByteString, Ix1)]
-> (ByteString, Ix1)
forall b c a. (b -> c) -> (a -> b) -> a -> c
. ((ByteString, Ix1) -> Bool)
-> [(ByteString, Ix1)] -> Maybe (ByteString, Ix1)
forall (t :: * -> *) a. Foldable t => (a -> Bool) -> t a -> Maybe a
find ((Ix1 -> Ix1 -> Bool
forall a. Eq a => a -> a -> Bool
==Ix1
iy)(Ix1 -> Bool)
-> ((ByteString, Ix1) -> Ix1) -> (ByteString, Ix1) -> Bool
forall b c a. (b -> c) -> (a -> b) -> a -> c
.(ByteString, Ix1) -> Ix1
forall a b. (a, b) -> b
snd) ([(ByteString, Ix1)] -> [Char]) -> [(ByteString, Ix1)] -> [Char]
forall a b. (a -> b) -> a -> b
$ [(ByteString, Ix1)]
header) else [Char]
"y"
    -- get rows and SRMatrix indices
    (Ix1
st, Ix1
end)                  = ByteString -> ByteString -> Ix1 -> (Ix1, Ix1)
getRows ([ByteString]
params [ByteString] -> Ix1 -> ByteString
forall a. HasCallStack => [a] -> Ix1 -> a
!! Ix1
0) ([ByteString]
params [ByteString] -> Ix1 -> ByteString
forall a. HasCallStack => [a] -> Ix1 -> a
!! Ix1
1) Ix1
nrows
    ([Ix1]
ixs, Ix1
iy, Ix1
iy_err) = [(ByteString, Ix1)]
-> ByteString -> ByteString -> ByteString -> ([Ix1], Ix1, Ix1)
getColumns [(ByteString, Ix1)]
header ([ByteString]
params [ByteString] -> Ix1 -> ByteString
forall a. HasCallStack => [a] -> Ix1 -> a
!! Ix1
2) ([ByteString]
params [ByteString] -> Ix1 -> ByteString
forall a. HasCallStack => [a] -> Ix1 -> a
!! Ix1
3) ([ByteString]
params [ByteString] -> Ix1 -> ByteString
forall a. HasCallStack => [a] -> Ix1 -> a
!! Ix1
4)

    -- load data and split sets
    datum :: Array S Ix2 Double
datum   = [[ByteString]] -> Array S Ix2 Double
loadMtx [[ByteString]]
content
    p :: Ix1
p       = [Ix1] -> Ix1
forall a. [a] -> Ix1
forall (t :: * -> *) a. Foldable t => t a -> Ix1
length [Ix1]
ixs

    x :: Array S Ix2 Double
x       = S -> Array DL Ix2 Double -> Array S Ix2 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array DL Ix2 Double -> Array S Ix2 Double)
-> Array DL Ix2 Double -> Array S Ix2 Double
forall a b. (a -> b) -> a -> b
$ Either SomeException (Array DL Ix2 Double) -> Array DL Ix2 Double
forall a. HasCallStack => Either SomeException a -> a
M.throwEither (Either SomeException (Array DL Ix2 Double) -> Array DL Ix2 Double)
-> Either SomeException (Array DL Ix2 Double)
-> Array DL Ix2 Double
forall a b. (a -> b) -> a -> b
$ [Array D (Lower Ix2) Double]
-> Either SomeException (Array DL Ix2 Double)
forall r ix e (f :: * -> *) (m :: * -> *).
(Foldable f, MonadThrow m, Index (Lower ix), Source r e,
 Index ix) =>
f (Array r (Lower ix) e) -> m (Array DL ix e)
M.stackInnerSlicesM ([Array D (Lower Ix2) Double]
 -> Either SomeException (Array DL Ix2 Double))
-> [Array D (Lower Ix2) Double]
-> Either SomeException (Array DL Ix2 Double)
forall a b. (a -> b) -> a -> b
$ (Ix1 -> Array D (Lower Ix2) Double)
-> [Ix1] -> [Array D (Lower Ix2) Double]
forall a b. (a -> b) -> [a] -> [b]
map (Array S Ix2 Double
datum Array S Ix2 Double -> Ix1 -> Array D (Lower Ix2) Double
forall r ix e.
(HasCallStack, Index ix, Source r e) =>
Array r ix e -> Ix1 -> Array D (Lower ix) e
<!) [Ix1]
ixs
    y :: Array D (Lower Ix2) Double
y       = Array S Ix2 Double
datum Array S Ix2 Double -> Ix1 -> Array D (Lower Ix2) Double
forall r ix e.
(HasCallStack, Index ix, Source r e) =>
Array r ix e -> Ix1 -> Array D (Lower ix) e
<! Ix1
iy
    y_err :: Array D (Lower Ix2) Double
y_err   = Array S Ix2 Double
datum Array S Ix2 Double -> Ix1 -> Array D (Lower Ix2) Double
forall r ix e.
(HasCallStack, Index ix, Source r e) =>
Array r ix e -> Ix1 -> Array D (Lower ix) e
<! Ix1
iy_err

    x_train :: Array S Ix2 Double
x_train = S -> Array D Ix2 Double -> Array S Ix2 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix2 Double -> Array S Ix2 Double)
-> Array D Ix2 Double -> Array S Ix2 Double
forall a b. (a -> b) -> a -> b
$ Ix2 -> Ix2 -> Array S Ix2 Double -> Array D Ix2 Double
forall r ix e.
(HasCallStack, Index ix, Source r e) =>
ix -> ix -> Array r ix e -> Array D ix e
M.extractFromTo' (Ix1
st Ix1 -> Ix1 -> Ix2
:. Ix1
0) (Ix1
endIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+Ix1
1 Ix1 -> Ix1 -> Ix2
:. Ix1
p) Array S Ix2 Double
x
    y_train :: PVector
y_train = S -> Array D Ix1 Double -> PVector
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> PVector) -> Array D Ix1 Double -> PVector
forall a b. (a -> b) -> a -> b
$ Ix1 -> Ix1 -> Array D Ix1 Double -> Array D Ix1 Double
forall r ix e.
(HasCallStack, Index ix, Source r e) =>
ix -> ix -> Array r ix e -> Array D ix e
M.extractFromTo' Ix1
st (Ix1
endIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+Ix1
1) Array D Ix1 Double
y 
    x_val :: Array S Ix2 Double
x_val   = S -> Array DL Ix2 Double -> Array S Ix2 Double
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array DL Ix2 Double -> Array S Ix2 Double)
-> Array DL Ix2 Double -> Array S Ix2 Double
forall a b. (a -> b) -> a -> b
$ Either SomeException (Array DL Ix2 Double) -> Array DL Ix2 Double
forall a. HasCallStack => Either SomeException a -> a
M.throwEither (Either SomeException (Array DL Ix2 Double) -> Array DL Ix2 Double)
-> Either SomeException (Array DL Ix2 Double)
-> Array DL Ix2 Double
forall a b. (a -> b) -> a -> b
$ Ix1
-> Sz Ix1
-> Array S Ix2 Double
-> Either SomeException (Array DL Ix2 Double)
forall r ix e (m :: * -> *).
(MonadThrow m, Index ix, Index (Lower ix), Source r e) =>
Ix1 -> Sz Ix1 -> Array r ix e -> m (Array DL ix e)
M.deleteRowsM Ix1
st (Ix1 -> Sz Ix1
Sz1 (Ix1 -> Sz Ix1) -> Ix1 -> Sz Ix1
forall a b. (a -> b) -> a -> b
$ Ix1
end Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
- Ix1
st Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+ Ix1
1) Array S Ix2 Double
x
    y_val :: PVector
y_val   = S -> Array DL Ix1 Double -> PVector
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array DL Ix1 Double -> PVector) -> Array DL Ix1 Double -> PVector
forall a b. (a -> b) -> a -> b
$ Either SomeException (Array DL Ix1 Double) -> Array DL Ix1 Double
forall a. HasCallStack => Either SomeException a -> a
M.throwEither (Either SomeException (Array DL Ix1 Double) -> Array DL Ix1 Double)
-> Either SomeException (Array DL Ix1 Double)
-> Array DL Ix1 Double
forall a b. (a -> b) -> a -> b
$ Ix1
-> Sz Ix1
-> Array D Ix1 Double
-> Either SomeException (Array DL Ix1 Double)
forall r ix e (m :: * -> *).
(MonadThrow m, Index ix, Source r e) =>
Ix1 -> Sz Ix1 -> Array r ix e -> m (Array DL ix e)
M.deleteColumnsM Ix1
st (Ix1 -> Sz Ix1
Sz1 (Ix1 -> Sz Ix1) -> Ix1 -> Sz Ix1
forall a b. (a -> b) -> a -> b
$ Ix1
end Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
- Ix1
st Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+ Ix1
1) Array D Ix1 Double
y

    y_err_train :: Maybe PVector
y_err_train = if Ix1
iy_err Ix1 -> Ix1 -> Bool
forall a. Eq a => a -> a -> Bool
== -Ix1
1 then Maybe PVector
forall a. Maybe a
Nothing else PVector -> Maybe PVector
forall a. a -> Maybe a
Just (PVector -> Maybe PVector) -> PVector -> Maybe PVector
forall a b. (a -> b) -> a -> b
$ S -> Array D Ix1 Double -> PVector
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array D Ix1 Double -> PVector) -> Array D Ix1 Double -> PVector
forall a b. (a -> b) -> a -> b
$ Ix1 -> Ix1 -> Array D Ix1 Double -> Array D Ix1 Double
forall r ix e.
(HasCallStack, Index ix, Source r e) =>
ix -> ix -> Array r ix e -> Array D ix e
M.extractFromTo' Ix1
st (Ix1
endIx1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+Ix1
1) Array D Ix1 Double
y_err
    y_err_val :: Maybe PVector
y_err_val   = if Ix1
iy_err Ix1 -> Ix1 -> Bool
forall a. Eq a => a -> a -> Bool
== -Ix1
1 then Maybe PVector
forall a. Maybe a
Nothing else PVector -> Maybe PVector
forall a. a -> Maybe a
Just (PVector -> Maybe PVector) -> PVector -> Maybe PVector
forall a b. (a -> b) -> a -> b
$ S -> Array DL Ix1 Double -> PVector
forall r e r' ix.
(Manifest r e, Load r' ix e) =>
r -> Array r' ix e -> Array r ix e
M.computeAs S
S (Array DL Ix1 Double -> PVector) -> Array DL Ix1 Double -> PVector
forall a b. (a -> b) -> a -> b
$ Either SomeException (Array DL Ix1 Double) -> Array DL Ix1 Double
forall a. HasCallStack => Either SomeException a -> a
M.throwEither (Either SomeException (Array DL Ix1 Double) -> Array DL Ix1 Double)
-> Either SomeException (Array DL Ix1 Double)
-> Array DL Ix1 Double
forall a b. (a -> b) -> a -> b
$ Ix1
-> Sz Ix1
-> Array D Ix1 Double
-> Either SomeException (Array DL Ix1 Double)
forall r ix e (m :: * -> *).
(MonadThrow m, Index ix, Source r e) =>
Ix1 -> Sz Ix1 -> Array r ix e -> m (Array DL ix e)
M.deleteColumnsM Ix1
st (Ix1 -> Sz Ix1
Sz1 (Ix1 -> Sz Ix1) -> Ix1 -> Sz Ix1
forall a b. (a -> b) -> a -> b
$ Ix1
end Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
- Ix1
st Ix1 -> Ix1 -> Ix1
forall a. Num a => a -> a -> a
+ Ix1
1) Array D Ix1 Double
y_err
{-# inline processData #-}

-- | Split a list into consecutive chunks of at most @i@ elements.
--
-- Every chunk except possibly the last one has exactly @i@ elements and
-- the empty list yields no chunks at all.
--
-- Note: for @i <= 0@ and a non-empty input nothing is ever consumed
-- (@drop i l == l@), so the result is an infinite list of empty chunks.
chunksOf :: Int -> [e] -> [[e]]
chunksOf i = go
  where
    -- Peel off one chunk at a time until the input is exhausted.
    go []   = []
    go rest = Prelude.take i rest : go (Prelude.drop i rest)

-- | Randomly split a dataset into two parts.
--
-- The row indices are shuffled (threading the 'StdGen' through the state
-- monad) and cut into consecutive chunks of @k@ indices with 'chunksOf'.
-- The first index of every chunk is assigned to the second dataset of the
-- result, all remaining indices of the chunk to the first one. The
-- optional error vector is split with the very same indices.
--
-- For @k <= 1@ no split is performed and both components of the result
-- are the unchanged input (a non-positive @k@ would otherwise make
-- 'chunksOf' produce an endless stream of empty chunks).
splitData :: DataSet -> Int -> State StdGen (DataSet, DataSet)
splitData ds@(x, y, mYErr) k
  | k <= 1 = pure (ds, ds)
  | otherwise = do
      shuffled <- state (shuffle [0 .. sz - 1])
      let folds    = chunksOf k shuffled
          -- head of every chunk -> second dataset, tail -> first dataset;
          -- pattern matching keeps this total even for an empty chunk
          testIxs  = [ix | (ix : _) <- folds]
          trainIxs = [ix | (_ : rest) <- folds, ix <- rest]
      pure ( (takeRows trainIxs, takeElems trainIxs y, takeElems trainIxs <$> mYErr)
           , (takeRows testIxs, takeElems testIxs y, takeElems testIxs <$> mYErr)
           )
  where
    MA.Sz sz = MA.size y
    comp_y   = MA.getComp y

    -- Rows of @x@ in a boxed vector so that each lookup is O(1) instead
    -- of walking a list with '!!'.
    rowVec :: V.Vector [Double]
    rowVec = V.fromList (MA.toLists x :: [MA.ListItem MA.Ix2 Double])

    -- Build a matrix from the selected rows of @x@, in the given order.
    takeRows :: [Int] -> SRMatrix
    takeRows ixs = MA.fromLists' (MA.getComp x) [rowVec V.! ix | ix <- ixs]

    -- Build a vector from the selected elements, in the given order.
    -- The computation strategy of @y@ is used for every vector.
    takeElems :: [Int] -> PVector -> PVector
    takeElems ixs v = MA.fromList comp_y [v MA.! ix | ix <- ixs]

-- | Project a nested result tuple onto its training part: the first two
-- components of the leading 4-tuple together with the first component of
-- the pair that follows it. Everything else is discarded.
getTrain :: ((a, b1, c1, d1), (c2, b2), c3, d2) -> (a, b1, c2)
getTrain ((a, b, _, _), (c, _), _, _) = (a, b, c)

-- | First component of a 'DataSet': the matrix @X@.
getX :: DataSet -> SRMatrix
getX (a, _, _) = a

-- | Second component of a 'DataSet': the vector @y@.
getTarget :: DataSet -> PVector
getTarget (_, b, _) = b

-- | Third component of a 'DataSet': the optional @y_error@ vector.
getError :: DataSet -> Maybe PVector
getError (_, _, c) = c

-- | Run 'loadDataset' with the given arguments and keep only the training
-- part of its result as a 'DataSet' (see 'getTrain'): the first matrix,
-- the first vector and the first of the two optional error vectors.
-- Both arguments are forwarded to 'loadDataset' unchanged.
loadTrainingOnly :: FilePath -> Bool -> IO DataSet
loadTrainingOnly fname b = getTrain <$> loadDataset fname b