{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_GHC -Wno-orphans #-}
{-# OPTIONS_GHC -Wno-x-partial #-}

{- |
Module: Crypto.Curve.Secp256k1.MuSig2.Internal
Copyright: (c) 2025 Jose Storopoli
License: MIT
Maintainer: Jose Storopoli <jose@storopoli.com>

Internal MuSig2 functions - not part of the public API.
-}
module Crypto.Curve.Secp256k1.MuSig2.Internal (
  -- Pubkey functions
  aggPublicKeys,
  -- Key aggregation
  computeKeyAggCoef,
  getSecondKey,
  -- utils/misc
  isEvenPub,
  xBytes,
  bytesToInteger,
  integerToBytes32,
  xorByteStrings,
  encodeLen,
  hashTag,
  hashTagModQ,
  hashProjectivesTag,
) where

import Crypto.Curve.Secp256k1 (Projective, Pub, add, modQ, mul, serialize_point, _CURVE_ZERO)
import Crypto.Hash.SHA256 (hash)
import Data.Bits (shiftR, xor, (.&.))
import Data.ByteString (ByteString)
import qualified Data.ByteString as BS
import Data.ByteString.Builder (toLazyByteString, word64BE)
import qualified Data.ByteString.Lazy as BSL
import Data.Foldable (Foldable (fold), find, toList)
import qualified Data.Foldable as F
import Data.Maybe (fromMaybe)
import Data.Sequence (Seq)
import qualified Data.Sequence as Seq
import Data.Traversable ()

{- | Aggregates a 'Traversable' of 'Pub'keys using the
[Key Aggregation algorithm in BIP327](https://github.com/bitcoin/bips/blob/master/bip-0327.mediawiki).

The algorithm can be briefly described as

\[
Q = \sum_{i=1}^{u} a_i \cdot pk_i
\]

where \(pk_i\) is the \(i\)th participant's public key and \(a_i\) is the
respective key aggregation coefficient computed by 'computeKeyAggCoef'.

Returns 'Nothing' when the input is empty, or when the aggregate \(Q\) is the
point at infinity (BIP327 requires key aggregation to fail in that case).

== WARNING

'aggPublicKeys' does not sort the keys: it aggregates them in the order in
which the 'Traversable' yields them.
-}
aggPublicKeys :: (Traversable t) => t Pub -> Maybe Pub
aggPublicKeys pks = case Seq.viewl weighted of
  Seq.EmptyL -> Nothing
  q0 Seq.:< rest ->
    let q = F.foldl' add q0 rest
     in -- BIP327 KeyAgg: "Fail if is_infinite(Q)".
        if q == _CURVE_ZERO then Nothing else Just q
 where
  pksSeq :: Seq Pub
  pksSeq = Seq.fromList (toList pks)

  -- Each key scaled by its aggregation coefficient a_i.
  -- 'mul' takes the point first, then the scalar.
  weighted :: Seq Pub
  weighted = fmap (\pk -> mul pk (computeKeyAggCoef pk pksSeq)) pksSeq

{- | Computes the BIP327 key aggregation coefficient of a public key.

Arguments:

1. the 'Pub'key whose coefficient is wanted;
2. the full 'Seq' of 'Pub'keys being aggregated, in aggregation order.

The coefficient is @1@ when the key equals the key returned by
'getSecondKey'. Otherwise it is
@hashTag "KeyAgg coefficient" (L <> serialize_point pk)@, read as a
big-endian integer and reduced with 'modQ'. Here @L@ is
@hashProjectivesTag "KeyAgg list" pks@.

Note: @L@ and the second key are recomputed on every call. Computing every
participant's coefficient this way therefore hashes the whole key list once
per key.
-}
computeKeyAggCoef :: Pub -> Seq Pub -> Integer
computeKeyAggCoef pk pks
  | pk == secondKey = 1
  | otherwise = modQ (bytesToInteger coefDigest)
 where
  secondKey = getSecondKey pks
  listDigest = hashProjectivesTag "KeyAgg list" pks
  coefDigest = hashTag "KeyAgg coefficient" (listDigest <> serialize_point pk)

{- | Returns the first key in a 'Seq' of 'Pub'keys that differs from the
sequence's first key.

If the sequence is empty, or every key equals the first one, the point at
infinity ('_CURVE_ZERO') is returned instead.
-}
getSecondKey :: Seq Pub -> Pub
getSecondKey pks = fromMaybe _CURVE_ZERO $ do
  firstKey <- Seq.lookup 0 pks
  -- Scanning from the head is harmless: the head never satisfies (/= firstKey).
  find (/= firstKey) pks

-- | Tagged hash (see 'hashTag') of the concatenation of the serialized
-- points in a 'Seq' of 'Projective's, taken in sequence order.
--
-- The first argument is the tag.
hashProjectivesTag :: ByteString -> Seq Projective -> ByteString
hashProjectivesTag tag ps = hashTag tag (foldMap serialize_point ps)

{- | Tagged hash as used in [BIP327](https://github.com/bitcoin/bips/blob/master/bip-0327.mediawiki):
@hash (hash tag <> hash tag <> msg)@, where @hash@ is SHA-256.

Takes the tag first and the message second.
-}
hashTag :: ByteString -> ByteString -> ByteString
hashTag t s = hash (BS.concat [tagDigest, tagDigest, s])
 where
  -- The tag digest is prepended twice.
  tagDigest = hash t

{- | Tagged hash as used in [BIP327](https://github.com/bitcoin/bips/blob/master/bip-0327.mediawiki),
reduced modulo the curve order.

Computes @hashTag t s@, reads it as a big-endian integer, reduces it with
'modQ', and re-encodes the result as 32 big-endian bytes.
Takes the tag first and the message second.
-}
hashTagModQ :: ByteString -> ByteString -> ByteString
hashTagModQ t s = integerToBytes32 . modQ . bytesToInteger $ hashTag t s

-- | Interprets a 'ByteString' (e.g. a SHA-256 digest) as an unsigned
-- big-endian 'Integer'. The empty string maps to @0@.
bytesToInteger :: ByteString -> Integer
bytesToInteger = BS.foldl' step 0
 where
  -- Shift the accumulator one byte to the left and add the next byte.
  step acc byte = acc * 256 + toInteger byte

-- | Encodes an 'Integer' as exactly 32 big-endian bytes.
--
-- Only the low 256 bits are kept. Larger values are silently truncated, and
-- negative values come out in two's-complement form. The function is
-- intended for values in @[0, 2^256)@.
integerToBytes32 :: Integer -> ByteString
integerToBytes32 i = BS.pack (map byteAt [31, 30 .. 0])
 where
  -- The k-th byte, counting from the least significant end.
  byteAt k = fromInteger (i `shiftR` (8 * k)) .&. 0xff

-- | Byte-wise @XOR@ of two 'ByteString's.
--
-- Intended for inputs of equal length. If the lengths differ, the result is
-- truncated to the length of the shorter input.
xorByteStrings :: ByteString -> ByteString -> ByteString
xorByteStrings a b = BS.pack (BS.zipWith xor a b)

-- | Encodes the length of a 'ByteString' as an 8-byte big-endian unsigned
-- integer.
encodeLen :: ByteString -> ByteString
encodeLen bs = BSL.toStrict (toLazyByteString (word64BE len))
 where
  len = fromIntegral (BS.length bs)

-- | Reports whether a 'Pub'key has an even y-coordinate, judged by the
-- prefix byte of its 'serialize_point' encoding: @0x02@ means even and
-- @0x03@ means odd.
--
-- Calls 'error' if the encoding is empty or starts with any other byte.
isEvenPub :: Pub -> Bool
isEvenPub pub = case BS.uncons (serialize_point pub) of
  Just (0x02, _) -> True
  Just (0x03, _) -> False
  _ -> error "musig2 (isEvenPub): invalid compressed point format"

-- | Returns the x-coordinate of a 'Pub'lic key as a 'ByteString': the
-- 'serialize_point' encoding with its leading prefix byte dropped.
--
-- NOTE(review): assumes the compressed encoding (one prefix byte followed by
-- the x-coordinate), which is also what 'isEvenPub' expects. Confirm against
-- 'serialize_point'.
xBytes :: Pub -> ByteString
xBytes = BS.drop 1 . serialize_point