{-# LANGUAGE AllowAmbiguousTypes #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}

module DataFrame.Typed.Aggregate (
    -- * Typed groupBy
    groupBy,

    -- * Typed aggregation builder (Option B)
    agg,
    aggNil,

    -- * Running aggregations
    aggregate,

    -- * Escape hatch
    aggregateUntyped,
) where

import Data.Proxy (Proxy (..))
import qualified Data.Text as T
import GHC.TypeLits (KnownSymbol, Symbol, symbolVal)

import DataFrame.Internal.Column (Columnable)
import qualified DataFrame.Internal.DataFrame as D
import DataFrame.Internal.Expression (NamedExpr)
import qualified DataFrame.Operations.Aggregation as DA

import DataFrame.Typed.Freeze (unsafeFreeze)
import DataFrame.Typed.Schema
import DataFrame.Typed.Types

{- | Group a typed DataFrame by one or more key columns.

The grouping keys are supplied as a type-level list of column names.
'AssertAllPresent' ties every key to the schema @cols@, and the names are
reflected to runtime 'T.Text' values (via 'symbolVals') before being
handed to the untyped 'DA.groupBy'.

@
grouped = groupBy \@'[\"department\"] employees
@
-}
groupBy ::
    forall (keys :: [Symbol]) cols.
    (AllKnownSymbol keys, AssertAllPresent keys cols) =>
    TypedDataFrame cols -> TypedGrouped keys cols
groupBy (TDF df) = TGD (DA.groupBy (symbolVals @keys) df)

-- | The empty aggregation builder: no output columns beyond the group keys.
--
-- Every chain of 'agg' calls is terminated by 'aggNil'.
aggNil :: TAgg keys cols '[]
aggNil = TAggNil

{- | Add one aggregation to the builder.

Each call pushes a @Column name a@ onto the front of the builder's
aggregation list @aggs@ and records the runtime 'NamedExpr' under the
column name @name@. Because 'aggregate' reverses @aggs@, the aggregation
applied closest to 'aggNil' ends up first among the output columns.

The expression's type @TExpr cols a@ shares @cols@ with the builder, so it
is checked against the source schema at compile time.

@
agg \@\"total_sales\" (tsum (col \@\"salary\"))
  $ agg \@\"avg_price\" (tmean (col \@\"price\"))
  $ aggNil
@
-}
agg ::
    forall name a keys cols aggs.
    ( KnownSymbol name
    , Columnable a
    ) =>
    TExpr cols a -> TAgg keys cols aggs -> TAgg keys cols (Column name a ': aggs)
agg = TAggCons colName
  where
    -- Runtime column name, reflected from the type-level @name@
    -- (in scope here via ScopedTypeVariables).
    colName :: T.Text
    colName = T.pack (symbolVal (Proxy @name))

{- | Run a typed aggregation.

Result schema = grouping key columns ++ aggregated columns. The builder
list @aggs@ is reversed, so the aggregation applied closest to 'aggNil'
comes first among the aggregated columns.

@
result = aggregate
    (agg \@\"total\" (tsum (col \@\"salary\")) $ agg \@\"count\" (tcount (col \@\"salary\")) $ aggNil)
    (groupBy \@'[\"dept\"] employees)
-- result :: TypedDataFrame '[Column \"dept\" Text, Column \"count\" Int, Column \"total\" Double]
@
-}
aggregate ::
    forall keys cols aggs.
    TAgg keys cols aggs ->
    TypedGrouped keys cols ->
    TypedDataFrame (Append (GroupKeyColumns keys cols) (Reverse aggs))
aggregate tagg (TGD gdf) =
    -- 'unsafeFreeze' accepts any target schema, so the column layout that
    -- 'DA.aggregate' produces is assumed (not checked here) to match the
    -- type-level schema computed in the signature.
    unsafeFreeze (DA.aggregate (taggToNamedExprs tagg) gdf)

-- | Escape hatch: run an untyped aggregation and return a raw 'D.DataFrame'.
--
-- The typed grouping is unwrapped and the expressions are passed straight
-- to 'DA.aggregate'; no schema is tracked for the result.
aggregateUntyped :: [NamedExpr] -> TypedGrouped keys cols -> D.DataFrame
aggregateUntyped exprs (TGD gdf) = DA.aggregate exprs gdf