{-# LANGUAGE OverloadedStrings #-}

module DataFrame.Operations.Join where

import qualified Data.Map.Strict as M
import qualified Data.Text as T
import qualified Data.Vector as VB
import qualified Data.Vector.Unboxed as VU
import DataFrame.Internal.Column as D
import DataFrame.Internal.DataFrame as D
import DataFrame.Operations.Aggregation as D
import DataFrame.Operations.Core as D

-- | Equivalent to SQL join types.
--
-- Only 'INNER' is currently implemented by 'join'; the other constructors are
-- accepted by its type but make it fail with an error.
data JoinType
    = -- | Keep only rows whose key appears in both dataframes (SQL @INNER JOIN@).
      INNER
    | -- | SQL @LEFT JOIN@.
      LEFT
    | -- | SQL @RIGHT JOIN@.
      RIGHT
    | -- | SQL @FULL OUTER JOIN@.
      FULL_OUTER
    deriving (Eq, Show)

{- | Join two dataframes using SQL join semantics.

The first dataframe argument is the right-hand side of the join and the second
is the left-hand side, so the left dataframe comes last when chaining.

Only 'INNER' is implemented for now. It delegates to 'innerJoin'. Every other
'JoinType' evaluates to @error "UNIMPLEMENTED"@.
-}
join ::
    JoinType ->
    -- | Names of the key columns to join on.
    [T.Text] ->
    -- | Right-hand side dataframe.
    DataFrame ->
    -- | Left-hand side dataframe.
    DataFrame ->
    DataFrame
join joinType keys rhs = case joinType of
    INNER -> innerJoin keys rhs
    -- LEFT, RIGHT and FULL_OUTER have no implementation yet.
    _ -> error "UNIMPLEMENTED"

{- | Inner join of two dataframes on the key columns @cs@.

Note: for chaining, the left dataframe is actually on the right side; the
arguments are @innerJoin cs right left@.

How rows are matched:

* In each dataframe, the key columns are those whose names appear in @cs@.
  Names missing from a dataframe are silently skipped. The key columns are
  taken in ascending column-name order.
* Each row is reduced to an 'Int' by 'D.mkRowRep' over those columns. Rows with
  equal representations are treated as having equal keys.
  NOTE(review): if 'D.mkRowRep' is a hash, collisions would produce spurious
  matches; confirm.
* For every representation present on both sides, every matching left row is
  paired with every matching right row. A key with @n@ left and @m@ right
  matches contributes @n * m@ output rows.
* Output rows are ordered by representation, then by left row position, then
  by right row position.

The result is the left dataframe with its rows gathered at the matched
positions. Every non-key column of the right dataframe is added to it, gathered
the same way. A right column whose name already exists on the left is added
as @Right_<name>@.
-}
innerJoin :: [T.Text] -> DataFrame -> DataFrame -> DataFrame
innerJoin cs right left =
    D.fold addRightColumn (D.columnNames right) expandedLeft
  where
    -- Row representation -> ascending row positions with that representation.
    keyGroups :: DataFrame -> M.Map Int [Int]
    keyGroups df =
        let keyColumnIndices =
                M.elems $ M.filterWithKey (\k _ -> k `elem` cs) (D.columnIndices df)
            rowReps =
                VU.generate (fst (D.dimensions df)) (D.mkRowRep keyColumnIndices df)
         in -- foldr visits rows right-to-left and prepends, so lists end up ascending.
            VU.foldr
                (\(i, v) acc -> M.insertWith (++) v [i] acc)
                M.empty
                (VU.indexed rowReps)

    -- Keys present on both sides, each with (left positions, right positions).
    -- 'M.intersectionWith' drops keys found on only one side: inner-join semantics.
    matched :: [(VU.Vector Int, VU.Vector Int)]
    matched =
        M.elems $
            M.intersectionWith
                (\ls rs -> (VU.fromList ls, VU.fromList rs))
                (keyGroups left)
                (keyGroups right)

    -- Cartesian product per key. Each left index is repeated once per right
    -- match, while the whole right group is cycled once per left match. Output
    -- position p within a key therefore pairs l[p `div` m] with r[p `mod` m],
    -- where m is the number of right matches. (Cycling both sides instead would
    -- duplicate some pairs and drop others whenever n and m share a factor.)
    expandedLeftIndices :: VU.Vector Int
    expandedLeftIndices =
        VU.concat [VU.concatMap (VU.replicate (VU.length r)) l | (l, r) <- matched]

    expandedRightIndices :: VU.Vector Int
    expandedRightIndices =
        VU.concat [VU.concat (replicate (VU.length l) r) | (l, r) <- matched]

    expandedLeft :: DataFrame
    expandedLeft =
        left
            { columns = VB.map (D.atIndicesStable expandedLeftIndices) (D.columns left)
            , dataframeDimensions =
                (VU.length expandedLeftIndices, snd (D.dataframeDimensions left))
            }

    expandedRight :: DataFrame
    expandedRight =
        right
            { columns = VB.map (D.atIndicesStable expandedRightIndices) (D.columns right)
            , dataframeDimensions =
                (VU.length expandedRightIndices, snd (D.dataframeDimensions right))
            }

    leftColumns :: [T.Text]
    leftColumns = D.columnNames left

    -- Copy one right-hand column into the accumulated result. Key columns are
    -- skipped (the left copy is kept), and clashing names get a "Right_" prefix.
    addRightColumn :: T.Text -> DataFrame -> DataFrame
    addRightColumn name df
        | name `elem` cs = df
        | name `elem` leftColumns = insertAs ("Right_" <> name)
        | otherwise = insertAs name
      where
        insertAs newName =
            maybe df (\col -> D.insertColumn newName col df) (D.getColumn name expandedRight)