-- SPDX-FileCopyrightText: 2020 Tocqueville Group
--
-- SPDX-License-Identifier: LicenseRef-MIT-TQ

-- | Common implementations of entrypoints.
module Lorentz.Entrypoints.Impl
  ( -- * Ways to implement 'ParameterHasEntrypoints'
    EpdPlain
  , EpdRecursive
  , EpdDelegate
  , EpdWithRoot

  -- * Implementation details
  , PlainEntrypointsC
  , EPTree (..)
  , BuildEPTree
  ) where

import Data.Singletons.Prelude (SBool(SFalse, STrue))
import Data.Singletons.Prelude.Eq ((%==))
import Data.Vinyl.Core (Rec(..), (<+>))
import Data.Vinyl.Recursive (rmap)
import Fcf (Eval, Exp)
import qualified Fcf
import qualified GHC.Generics as G
import Util.TypeLits

import Lorentz.Value
import Michelson.Typed
import Michelson.Typed.Haskell.Value (GValueType)
import Michelson.Untyped (FieldAnn, mkAnnotationUnsafe, noAnn)
import Util.Fcf (type (<|>), Over2, TyEqSing)
import Util.Type

import Lorentz.Annotation
import Lorentz.Entrypoints.Core
import Lorentz.Entrypoints.Helpers

-- | Derivation strategy for 'ParameterHasEntrypoints' suited to contracts
-- whose parameter is a sum type exposing several entrypoints.
--
-- Every constructor yields an entrypoint of the same name, and the argument
-- type of that entrypoint is the type of the constructor's field (each
-- constructor is expected to carry exactly one field).
-- A constructor named 'Default' designates the default entrypoint.
data EpdPlain

instance PlainEntrypointsC EpdPlain param =>
         EntrypointsDerivation EpdPlain param where
  type EpdLookupEntrypoint EpdPlain param =
    PlainLookupEntrypointExt EpdPlain param
  type EpdAllEntrypoints EpdPlain param =
    PlainAllEntrypointsExt EpdPlain param

  epdDescs = plainEpdDescsExt @EpdPlain @param
  epdCall = plainEpdCallExt @EpdPlain @param
  -- The second component of the pair is left as 'noAnn'.
  epdNotes = (plainEpdNotesExt @EpdPlain @param, noAnn)

-- | Variant of 'EpdPlain' for parameters that are declared as several nested
-- datatypes.
--
-- Sum types are traversed recursively; the traversal stops at Michelson
-- primitives (like 'Natural') and at constructors whose number of fields
-- is different from one.
--
-- Only the leaves of the resulting 'Or' tree receive names, intermediate
-- nodes do not.
--
-- This way of derivation does not work when some entrypoint argument has
-- a custom 'IsoValue' instance. As a workaround, wrap such an argument into
-- some primitive (e.g. ':!').
data EpdRecursive

instance PlainEntrypointsC EpdRecursive param =>
         EntrypointsDerivation EpdRecursive param where
  type EpdLookupEntrypoint EpdRecursive param =
    PlainLookupEntrypointExt EpdRecursive param
  type EpdAllEntrypoints EpdRecursive param =
    PlainAllEntrypointsExt EpdRecursive param

  epdDescs = plainEpdDescsExt @EpdRecursive @param
  epdCall = plainEpdCallExt @EpdRecursive @param
  -- The second component of the pair is left as 'noAnn'.
  epdNotes = (plainEpdNotesExt @EpdRecursive @param, noAnn)

-- | Variant of 'EpdPlain' for parameters that are declared as several nested
-- datatypes.
--
-- Only the immediate sum type is traversed; for inner complex datatypes
-- another 'ParameterHasEntrypoints' instance is required. An inner type is
-- taken into account only when it is the sole field of its constructor.
-- Inner types should not declare a default entrypoint themselves, this is
-- enforced for the sake of better modularity.
-- Every top-level constructor is treated as an entrypoint, even when it
-- holds a complex datatype; in that case the entrypoint corresponds to
-- an intermediate node of the @or@ tree.
--
-- Compared to 'EpdRecursive', this gives more control over where and how
-- entrypoints are derived.
data EpdDelegate

instance PlainEntrypointsC EpdDelegate param =>
         EntrypointsDerivation EpdDelegate param where
  type EpdLookupEntrypoint EpdDelegate param =
    PlainLookupEntrypointExt EpdDelegate param
  type EpdAllEntrypoints EpdDelegate param =
    PlainAllEntrypointsExt EpdDelegate param

  epdDescs = plainEpdDescsExt @EpdDelegate @param
  epdCall = plainEpdCallExt @EpdDelegate @param
  -- The second component of the pair is left as 'noAnn'.
  epdNotes = (plainEpdNotesExt @EpdDelegate @param, noAnn)

-- | Wrapper over 'EpdPlain', 'EpdRecursive' and 'EpdDelegate' which
-- additionally allows specifying a root annotation for the parameter.
--
-- The root annotation is exposed as one more entrypoint whose argument is
-- the whole parameter; all other entrypoints come from the wrapped
-- derivation.
data EpdWithRoot (root :: Symbol) deriv

instance (KnownSymbol root, PlainEntrypointsC deriv param) =>
         EntrypointsDerivation (EpdWithRoot root deriv) param where
  type EpdAllEntrypoints (EpdWithRoot root deriv) param =
    '(root, param) ': PlainAllEntrypointsExt deriv param
  type EpdLookupEntrypoint (EpdWithRoot root deriv) param =
    Fcf.Case
      '[ Fcf.Is (TyEqSing root) ('Just param)
       , Fcf.Else (PlainLookupEntrypointExt deriv param)
       ]

  epdNotes =
    ( plainEpdNotesExt @deriv @param
    , mkAnnotationUnsafe (symbolValT' @root)
    )

  -- The root name is tried first; any other name is handed over to the
  -- wrapped derivation.
  epdCall label@(Label :: Label name) =
    case sing @root %== sing @name of
      SFalse -> plainEpdCallExt @deriv @param label
      STrue -> EpConstructed EplArgHere

  -- NOTE(review): the root entry receives a wrap-in step named after @root@
  -- although 'epdCall' answers 'EplArgHere' for it - confirm this is intended.
  epdDescs =
    addDescStep @root
      ( EpCallingDesc
          { epcdArg = Proxy
          , epcdEntrypoint = ctorNameToEp @root
          , epcdSteps = []
          }
        :& RNil
      )
    <+> plainEpdDescsExt @deriv @param

-- | All entrypoints of parameter @param@ under derivation @deriv@, computed
-- over the 'EPTree' built for that very parameter.
type PlainAllEntrypointsExt deriv param =
  AllEntrypoints deriv (BuildEPTree deriv param) param

-- | Entrypoint lookup for parameter @param@ under derivation @deriv@,
-- computed over the 'EPTree' built for that very parameter.
type PlainLookupEntrypointExt deriv param =
  LookupEntrypoint deriv (BuildEPTree deriv param) param

-- | Notes of the parameter for derivation @deriv@; a shortcut for
-- 'mkEntrypointsNotes' with the 'EPTree' obtained via 'BuildEPTree'.
plainEpdNotesExt
  :: forall deriv param.
     (PlainEntrypointsC deriv param, HasCallStack)
  => Notes (ToT param)
plainEpdNotesExt =
  mkEntrypointsNotes @deriv @(BuildEPTree deriv param) @param

-- | Entrypoint call construction for derivation @deriv@; a shortcut for
-- 'mkEpLiftSequence' with the 'EPTree' obtained via 'BuildEPTree'.
plainEpdCallExt
  :: forall deriv param name.
     (PlainEntrypointsC deriv param, ParameterScope (ToT param))
  => Label name
  -> EpConstructionRes
       (ToT param)
       (Eval (PlainLookupEntrypointExt deriv param name))
plainEpdCallExt =
  mkEpLiftSequence @deriv @(BuildEPTree deriv param) @param

-- | Entrypoint calling descriptions for derivation @deriv@; a shortcut for
-- 'mkEpDescs' with the 'EPTree' obtained via 'BuildEPTree'.
plainEpdDescsExt
  :: forall deriv param.
     PlainEntrypointsC deriv param
  => Rec EpCallingDesc (PlainAllEntrypointsExt deriv param)
plainEpdDescsExt =
  mkEpDescs @deriv @(BuildEPTree deriv param) @param

-- | Everything demanded from parameter @param@ in order to derive its
-- entrypoints under derivation @deriv@: a generic 'IsoValue' representation,
-- the generic traversal over the tree built for it, and 'RequireSumType'
-- (judging by the name, a check that the parameter is a sum type).
type PlainEntrypointsC deriv param =
  ( GenericIsoValue param
  , EntrypointsNotes deriv (BuildEPTree deriv param) param
  , RequireSumType param
  )

-- | Entrypoints tree: a skeleton of the 'TOr' tree which is later used to
-- tell apart constructors that are entrypoints themselves from constructors
-- that merely group a whole pack of entrypoints.
data EPTree
  = -- | Intermediate node, the traversal has to go deeper.
    EPNode EPTree EPTree
  | -- | Entrypoint argument has been reached.
    EPLeaf
  | -- | Complex part of the parameter has been reached, the way to process
    -- it has to be asked for separately.
    EPDelegate

-- | Compute the 'EPTree' of a parameter type from its generic representation,
-- following the rules of the given derivation @mode@.
type BuildEPTree mode param = GBuildEntrypointsTree mode (G.Rep param)

-- | Generic worker behind 'BuildEPTree': maps a generic representation to
-- the 'EPTree' skeleton according to the derivation @mode@.
--
-- This is a closed type family, so equations are tried top to bottom; the
-- relative order of the overlapping @G.C1@ equations below is significant.
type family GBuildEntrypointsTree (mode :: Type) (x :: Type -> Type)
             :: EPTree where
  -- Datatype metadata is transparent.
  GBuildEntrypointsTree mode (G.D1 _ x) =
    GBuildEntrypointsTree mode x
  -- A sum of constructors becomes an inner node with a subtree per side.
  GBuildEntrypointsTree mode (x G.:+: y) =
    'EPNode (GBuildEntrypointsTree mode x) (GBuildEntrypointsTree mode y)

  -- 'EpdPlain': every constructor is a leaf, whatever it contains.
  GBuildEntrypointsTree EpdPlain (G.C1 _ _) =
    'EPLeaf
  -- 'EpdRecursive': look inside the constructor.
  GBuildEntrypointsTree EpdRecursive (G.C1 _ x) =
    GBuildEntrypointsTree EpdRecursive x
  -- 'EpdDelegate': a constructor with exactly one field is delegated.
  -- This equation has to precede the catch-all 'EpdDelegate' one below.
  GBuildEntrypointsTree EpdDelegate (G.C1 _ (G.S1 _ (G.Rec0 _))) =
    'EPDelegate
  -- 'EpdDelegate': any other constructor is a leaf.
  GBuildEntrypointsTree EpdDelegate (G.C1 _ _) =
    'EPLeaf
  -- Selector metadata is transparent.
  GBuildEntrypointsTree mode (G.S1 _ x) =
    GBuildEntrypointsTree mode x
  -- Constructors without fields are not traversed further.
  GBuildEntrypointsTree _ G.U1 =
    'EPLeaf
  -- Constructors with several fields are not traversed further.
  GBuildEntrypointsTree _ (_ G.:*: _) =
    'EPLeaf
  -- A single field: stop at primitive values, otherwise continue with the
  -- tree of the field's own type.
  GBuildEntrypointsTree mode (G.Rec0 a) =
    If (IsPrimitiveValue a)
       'EPLeaf
       (BuildEPTree mode a)

-- | Constraint under which a sum type can be traversed in order to build
-- 'Notes' that report constructor names via field annotations.
type EntrypointsNotes mode tree param =
  (Generic param, GEntrypointsNotes mode tree (G.Rep param))

-- | Build the notes of the given parameter with proper field annotations.
--
-- The field annotation reported for the topmost level is dropped here.
mkEntrypointsNotes
  :: forall mode tree param.
      (EntrypointsNotes mode tree param, GenericIsoValue param, HasCallStack)
  => Notes (ToT param)
mkEntrypointsNotes =
  case gMkEntrypointsNotes @mode @tree @(G.Rep param) of
    (notes, _topLevelAnn) -> notes

-- | Build the way of lifting an entrypoint argument up to the full
-- parameter, for the entrypoint with the given name.
mkEpLiftSequence
  :: forall mode tree param name.
      ( EntrypointsNotes mode tree param, ParameterScope (ToT param)
      , GenericIsoValue param
      )
  => Label name
  -> EpConstructionRes
       (ToT param)
       (Eval (LookupEntrypoint mode tree param name))
mkEpLiftSequence =
  gMkEpLiftSequence @mode @tree @(G.Rep param)

-- | Build descriptions of how every entrypoint of the parameter is called.
mkEpDescs
  :: forall mode tree param.
      EntrypointsNotes mode tree param
  => Rec EpCallingDesc (AllEntrypoints mode tree param)
mkEpDescs =
  gMkDescs @mode @tree @(G.Rep param)

-- | Type-level list describing all entrypoints of the parameter, i.e. the
-- leaves of its 'Or' tree.
type AllEntrypoints mode tree param =
  GAllEntrypoints mode tree (G.Rep param)

-- | Type-level lookup of a single entrypoint of the parameter: given
-- an entrypoint name, evaluates to its argument type wrapped in 'Just',
-- or to 'Nothing' when there is no entrypoint with such name.
type LookupEntrypoint mode tree param =
  GLookupEntrypoint mode tree (G.Rep param)

-- | Generic traversal for 'EntrypointsNotes'.
--
-- Instances are selected by the derivation @mode@, by the 'EPTree' shape
-- @ep@ and by the generic representation @x@ being traversed.
class GEntrypointsNotes (mode :: Type) (ep :: EPTree) (x :: Type -> Type) where
  -- | Entrypoints (name and argument type) found at this level and below.
  type GAllEntrypoints mode ep x :: [(Symbol, Type)]
  -- | Lookup of an entrypoint by name: evaluates to the argument type in
  -- 'Just', or to 'Nothing' if this level knows no such entrypoint.
  type GLookupEntrypoint mode ep x :: Symbol -> Exp (Maybe Type)

  {- | Returns:
    1. Notes corresponding to this level;
    2. Field annotation for this level (and which should be used one level above).
  -}
  gMkEntrypointsNotes :: HasCallStack => (Notes (GValueType x), FieldAnn)

  -- | Tries to construct the way of lifting the argument of the named
  -- entrypoint up to the value of this level; the result type is tied to
  -- what 'GLookupEntrypoint' evaluates to for that name.
  gMkEpLiftSequence
    :: ParameterScope (GValueType x)
    => Label name
    -> EpConstructionRes (GValueType x) (Eval (GLookupEntrypoint mode ep x name))

  -- | Calling descriptions, one per item of 'GAllEntrypoints'.
  gMkDescs
    :: Rec EpCallingDesc (GAllEntrypoints mode ep x)

-- | Datatype metadata: everything is forwarded to the wrapped
-- representation.
instance GEntrypointsNotes mode ep rep =>
         GEntrypointsNotes mode ep (G.D1 meta rep) where
  type GLookupEntrypoint mode ep (G.D1 meta rep) =
    GLookupEntrypoint mode ep rep
  type GAllEntrypoints mode ep (G.D1 meta rep) =
    GAllEntrypoints mode ep rep

  gMkDescs = gMkDescs @mode @ep @rep
  gMkEpLiftSequence = gMkEpLiftSequence @mode @ep @rep
  gMkEntrypointsNotes = gMkEntrypointsNotes @mode @ep @rep

-- | Sum of constructors: both sides are processed and combined into
-- an @or@ node.
instance (GEntrypointsNotes mode epl l, GEntrypointsNotes mode epr r) =>
         GEntrypointsNotes mode ('EPNode epl epr) (l G.:+: r) where
  type GAllEntrypoints mode ('EPNode epl epr) (l G.:+: r) =
    GAllEntrypoints mode epl l ++ GAllEntrypoints mode epr r
  type GLookupEntrypoint mode ('EPNode epl epr) (l G.:+: r) =
    Over2 (<|>) (GLookupEntrypoint mode epl l) (GLookupEntrypoint mode epr r)

  gMkDescs =
    gMkDescs @mode @epl @l <+> gMkDescs @mode @epr @r

  -- Annotations reported by the children become field annotations of the
  -- @or@ node; the node itself reports no annotation upwards.
  gMkEntrypointsNotes = (NTOr noAnn lAnn rAnn lNotes rNotes, noAnn)
    where
      (lNotes, lAnn) = gMkEntrypointsNotes @mode @epl @l
      (rNotes, rAnn) = gMkEntrypointsNotes @mode @epr @r

  -- The left side is searched first, the right side only if the left one
  -- does not know the entrypoint.
  gMkEpLiftSequence label =
    case sing @(GValueType (l G.:+: r)) of
      STOr leftSing _ ->
        case (checkOpPresence leftSing, checkNestedBigMapsPresence leftSing) of
          (OpAbsent, NestedBigMapsAbsent) ->
            case gMkEpLiftSequence @mode @epl @l label of
              EpConstructed inLeft -> EpConstructed (EplWrapLeft inLeft)
              EpConstructionFailed ->
                case gMkEpLiftSequence @mode @epr @r label of
                  EpConstructed inRight -> EpConstructed (EplWrapRight inRight)
                  EpConstructionFailed -> EpConstructionFailed

-- | Constructor which is an entrypoint on its own: its name becomes the
-- entrypoint name and its field becomes the entrypoint argument.
instance ( GHasAnnotation field, KnownSymbol ctor
         , ToT (GExtractField field) ~ GValueType field
         ) =>
         GEntrypointsNotes mode 'EPLeaf (G.C1 ('G.MetaCons ctor _1 _2) field) where
  type GAllEntrypoints mode 'EPLeaf (G.C1 ('G.MetaCons ctor _1 _2) field) =
    '[ '(ctor, GExtractField field) ]
  type GLookupEntrypoint mode 'EPLeaf (G.C1 ('G.MetaCons ctor _1 _2) field) =
    JustOnEq ctor (GExtractField field)

  gMkDescs =
    addDescStep @ctor
      ( EpCallingDesc
          { epcdArg = Proxy
          , epcdEntrypoint = ctorNameToEp @ctor
          , epcdSteps = []
          }
        :& RNil
      )

  -- Succeeds only for the name of this very constructor.
  gMkEpLiftSequence (Label :: Label name) =
    case sing @ctor %== sing @name of
      SFalse -> EpConstructionFailed
      STrue -> EpConstructed EplArgHere

  gMkEntrypointsNotes =
    ( gGetAnnotation @field defaultAnnOptions FollowEntrypoint NotGenerateFieldAnn ^. _1
    , ctorNameToAnn @ctor
    )

-- | Constructor which only groups further entrypoints: notes, lookup and
-- lifting are taken from its contents, while the descriptions additionally
-- receive a step for this constructor.
instance (ep ~ 'EPNode epl epr, GEntrypointsNotes mode ep inner, KnownSymbol ctor) =>
         GEntrypointsNotes mode ('EPNode epl epr) (G.C1 ('G.MetaCons ctor _1 _2) inner) where
  type GLookupEntrypoint mode ('EPNode epl epr) (G.C1 ('G.MetaCons ctor _1 _2) inner) =
    GLookupEntrypoint mode ('EPNode epl epr) inner
  type GAllEntrypoints mode ('EPNode epl epr) (G.C1 ('G.MetaCons ctor _1 _2) inner) =
    GAllEntrypoints mode ('EPNode epl epr) inner

  gMkDescs = addDescStep @ctor (gMkDescs @mode @ep @inner)
  gMkEpLiftSequence = gMkEpLiftSequence @mode @ep @inner
  gMkEntrypointsNotes = gMkEntrypointsNotes @mode @ep @inner

-- | Constructor whose contents are delegated: the constructor is
-- an entrypoint itself, and entrypoints of its contents are added after it.
instance ( ep ~ 'EPDelegate, GEntrypointsNotes mode ep inner
         , KnownSymbol ctor, ToT (GExtractField inner) ~ GValueType inner
         ) =>
         GEntrypointsNotes mode 'EPDelegate (G.C1 ('G.MetaCons ctor _1 _2) inner) where
  type GAllEntrypoints mode 'EPDelegate (G.C1 ('G.MetaCons ctor _1 _2) inner) =
    '(ctor, GExtractField inner) ': GAllEntrypoints mode 'EPDelegate inner
  type GLookupEntrypoint mode 'EPDelegate (G.C1 ('G.MetaCons ctor _1 _2) inner) =
    Over2 (<|>)
      (JustOnEq ctor (GExtractField inner))
      (GLookupEntrypoint mode 'EPDelegate inner)

  gMkDescs =
    addDescStep @ctor
      ( EpCallingDesc
          { epcdArg = Proxy
          , epcdEntrypoint = ctorNameToEp @ctor
          , epcdSteps = []
          }
        :& gMkDescs @mode @ep @inner
      )

  -- The constructor name wins; other names are looked for in the contents.
  gMkEpLiftSequence label@(Label :: Label name) =
    case sing @ctor %== sing @name of
      SFalse -> gMkEpLiftSequence @mode @ep @inner label
      STrue -> EpConstructed EplArgHere

  -- Notes come from the contents, whose own annotation is replaced with
  -- the one made of the constructor name.
  gMkEntrypointsNotes =
    ( fst (gMkEntrypointsNotes @mode @ep @inner)
    , ctorNameToAnn @ctor
    )

-- | Selector metadata: everything is forwarded to the wrapped
-- representation.
instance GEntrypointsNotes mode ep rep =>
         GEntrypointsNotes mode ep (G.S1 meta rep) where
  type GLookupEntrypoint mode ep (G.S1 meta rep) =
    GLookupEntrypoint mode ep rep
  type GAllEntrypoints mode ep (G.S1 meta rep) =
    GAllEntrypoints mode ep rep

  gMkDescs = gMkDescs @mode @ep @rep
  gMkEpLiftSequence = gMkEpLiftSequence @mode @ep @rep
  gMkEntrypointsNotes = gMkEntrypointsNotes @mode @ep @rep

-- | Field in 'EpdRecursive' mode: the traversal continues with the generic
-- representation of the field's type; no annotation is reported upwards.
instance (EntrypointsNotes EpdRecursive ep inner, GenericIsoValue inner) =>
         GEntrypointsNotes EpdRecursive ep (G.Rec0 inner) where
  type GLookupEntrypoint EpdRecursive ep (G.Rec0 inner) =
    LookupEntrypoint EpdRecursive ep inner
  type GAllEntrypoints EpdRecursive ep (G.Rec0 inner) =
    AllEntrypoints EpdRecursive ep inner

  gMkDescs = mkEpDescs @EpdRecursive @ep @inner
  gMkEpLiftSequence = mkEpLiftSequence @EpdRecursive @ep @inner
  gMkEntrypointsNotes =
    ( mkEntrypointsNotes @EpdRecursive @ep @inner
    , noAnn
    )

-- | Field in 'EpdDelegate' mode: everything is taken from the entrypoints
-- declaration of the field's type; no annotation is reported upwards.
instance ParameterDeclaresEntrypoints inner =>
         GEntrypointsNotes EpdDelegate 'EPDelegate (G.Rec0 inner) where
  type GLookupEntrypoint EpdDelegate 'EPDelegate (G.Rec0 inner) =
    LookupParameterEntrypoint inner
  type GAllEntrypoints EpdDelegate 'EPDelegate (G.Rec0 inner) =
    AllParameterEntrypoints inner

  gMkDescs = pepDescs @inner
  gMkEpLiftSequence = pepCall @inner
  -- Only the first component of what 'pepNotes' returns is used.
  gMkEntrypointsNotes = (innerNotes, noAnn)
    where
      (innerNotes, _innerRootAnn) = pepNotes @inner

-- | Absence of fields: no entrypoints are reported, lookup never succeeds
-- and the notes carry no annotations.
instance GEntrypointsNotes mode 'EPLeaf G.U1 where
  type GLookupEntrypoint mode 'EPLeaf G.U1 = Fcf.ConstFn 'Nothing
  type GAllEntrypoints mode 'EPLeaf G.U1 = '[]

  gMkDescs = RNil
  gMkEpLiftSequence _ = EpConstructionFailed
  gMkEntrypointsNotes = (starNotes, noAnn)

-- | Several fields: no entrypoints are reported, lookup never succeeds
-- and the notes carry no annotations.
instance Each '[KnownT] [GValueType l, GValueType r] =>
         GEntrypointsNotes mode 'EPLeaf (l G.:*: r) where
  type GLookupEntrypoint mode 'EPLeaf (l G.:*: r) = Fcf.ConstFn 'Nothing
  type GAllEntrypoints mode 'EPLeaf (l G.:*: r) = '[]

  gMkDescs = RNil
  gMkEpLiftSequence _ = EpConstructionFailed
  gMkEntrypointsNotes = (starNotes, noAnn)

-- | Type-level function which takes an entry of kind @k1@ and evaluates to
-- @'Just payload@ if that entry equals @expected@, and to @'Nothing@
-- otherwise.
type family JustOnEq (expected :: k1) (payload :: k2) :: k1 -> Exp (Maybe k2) where
  JustOnEq expected payload =
    Fcf.Case
      '[ Fcf.Is (TyEqSing expected) ('Just payload)
       , Fcf.Any 'Nothing
       ]

-- | Type of the field located under 'G.C1'.
--
-- A constructor without fields is treated as holding @()@. There is no
-- equation for constructors with several fields.
type family GExtractField (rep :: Type -> Type) where
  GExtractField G.U1 = ()
  GExtractField (G.Rec0 field) = field
  GExtractField (G.S1 _ rep) = GExtractField rep

-- | Prepend an 'EpsWrapIn' step, named after @ctor@, to the calling steps
-- of every description in the given record.
addDescStep
  :: forall ctor eps.
      KnownSymbol ctor
  => Rec EpCallingDesc eps -> Rec EpCallingDesc eps
addDescStep =
  rmap $ \EpCallingDesc{..} ->
    EpCallingDesc{ epcdSteps = wrapStep : epcdSteps, .. }
  where
    -- The very same step is shared by all descriptions.
    wrapStep = EpsWrapIn (symbolValT' @ctor)