{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}

module DataFrame.Typed.Join (
    -- * Typed joins
    innerJoin,
    leftJoin,
    rightJoin,
    fullOuterJoin,
) where

import GHC.TypeLits (Symbol)

import qualified DataFrame.Operations.Join as DJ

import DataFrame.Typed.Freeze (unsafeFreeze)
import DataFrame.Typed.Schema
import DataFrame.Typed.Types (TypedDataFrame (..))

-- | Typed inner join on one or more key columns.
--
-- The join keys are given as a type-level list through a type application.
-- @keys@ is the first quantified variable and does not occur in the
-- argument types, so it must be supplied explicitly:
--
-- > innerJoin @'["id"] leftFrame rightFrame
--
-- The result schema is 'InnerJoinSchema' computed from @keys@, @left@ and
-- @right@; unlike the outer-join schemas in this module, it wraps no column
-- in 'Maybe'.
innerJoin ::
    forall (keys :: [Symbol]) left right.
    (AllKnownSymbol keys) =>
    TypedDataFrame left ->
    TypedDataFrame right ->
    TypedDataFrame (InnerJoinSchema keys left right)
innerJoin (TDF l) (TDF r) =
    -- 'unsafeFreeze' re-tags the untyped result with the computed schema;
    -- its type (DataFrame -> TypedDataFrame cols) does not relate the
    -- schema to the actual columns, so the two agree only if
    -- 'DJ.innerJoin' builds the columns the schema describes.
    --
    -- NOTE(review): @r@ is passed before @l@ here, while 'leftJoin' and
    -- 'rightJoin' pass @l@ first. Confirm against the parameter order of
    -- 'DJ.innerJoin' so the runtime column layout matches the schema.
    unsafeFreeze (DJ.innerJoin keyNames r l)
  where
    -- Runtime column names reflected from the type-level key list.
    keyNames = symbolVals @keys

-- | Typed left join on one or more key columns.
--
-- The join keys are given as a type-level list through a type application,
-- since @keys@ does not occur in the argument types:
--
-- > leftJoin @'["id"] leftFrame rightFrame
--
-- In the result schema 'LeftJoinSchema', the key columns come from
-- @left@'s schema and the columns contributed only by @right@ are wrapped
-- in 'Maybe', because a row of @left@ may have no match in @right@.
leftJoin ::
    forall (keys :: [Symbol]) left right.
    (AllKnownSymbol keys) =>
    TypedDataFrame left ->
    TypedDataFrame right ->
    TypedDataFrame (LeftJoinSchema keys left right)
leftJoin (TDF l) (TDF r) =
    -- 'unsafeFreeze' re-tags the untyped result with the computed schema;
    -- its type (DataFrame -> TypedDataFrame cols) does not relate the
    -- schema to the actual columns, so the two agree only if
    -- 'DJ.leftJoin' builds the columns the schema describes.
    --
    -- NOTE(review): @l@ is passed before @r@ here, whereas 'innerJoin' and
    -- 'fullOuterJoin' pass @r@ first. For a left join the order presumably
    -- decides which frame keeps all of its rows, so confirm it against the
    -- parameter order of 'DJ.leftJoin'.
    unsafeFreeze (DJ.leftJoin keyNames l r)
  where
    -- Runtime column names reflected from the type-level key list.
    keyNames = symbolVals @keys

-- | Typed right join on one or more key columns.
--
-- The join keys are given as a type-level list through a type application,
-- since @keys@ does not occur in the argument types:
--
-- > rightJoin @'["id"] leftFrame rightFrame
--
-- In the result schema 'RightJoinSchema', the key columns come from
-- @right@'s schema and the columns contributed only by @left@ are wrapped
-- in 'Maybe', because a row of @right@ may have no match in @left@.
rightJoin ::
    forall (keys :: [Symbol]) left right.
    (AllKnownSymbol keys) =>
    TypedDataFrame left ->
    TypedDataFrame right ->
    TypedDataFrame (RightJoinSchema keys left right)
rightJoin (TDF l) (TDF r) =
    -- 'unsafeFreeze' re-tags the untyped result with the computed schema;
    -- its type (DataFrame -> TypedDataFrame cols) does not relate the
    -- schema to the actual columns, so the two agree only if
    -- 'DJ.rightJoin' builds the columns the schema describes.
    --
    -- NOTE(review): @l@ is passed before @r@ here, whereas 'innerJoin' and
    -- 'fullOuterJoin' pass @r@ first. For a right join the order presumably
    -- decides which frame keeps all of its rows, so confirm it against the
    -- parameter order of 'DJ.rightJoin'.
    unsafeFreeze (DJ.rightJoin keyNames l r)
  where
    -- Runtime column names reflected from the type-level key list.
    keyNames = symbolVals @keys

-- | Typed full outer join on one or more key columns.
--
-- The join keys are given as a type-level list through a type application,
-- since @keys@ does not occur in the argument types:
--
-- > fullOuterJoin @'["id"] leftFrame rightFrame
--
-- In the result schema 'FullOuterJoinSchema', the key columns and the
-- columns contributed only by either side are wrapped in 'Maybe', because
-- an output row may originate from just one of the two frames.
fullOuterJoin ::
    forall (keys :: [Symbol]) left right.
    (AllKnownSymbol keys) =>
    TypedDataFrame left ->
    TypedDataFrame right ->
    TypedDataFrame (FullOuterJoinSchema keys left right)
fullOuterJoin (TDF l) (TDF r) =
    -- 'unsafeFreeze' re-tags the untyped result with the computed schema;
    -- its type (DataFrame -> TypedDataFrame cols) does not relate the
    -- schema to the actual columns, so the two agree only if
    -- 'DJ.fullOuterJoin' builds the columns the schema describes.
    --
    -- NOTE(review): @r@ is passed before @l@ here, while 'leftJoin' and
    -- 'rightJoin' pass @l@ first. Confirm against the parameter order of
    -- 'DJ.fullOuterJoin' so the runtime column layout matches the schema.
    unsafeFreeze (DJ.fullOuterJoin keyNames r l)
  where
    -- Runtime column names reflected from the type-level key list.
    keyNames = symbolVals @keys