{-# LANGUAGE OverloadedLists #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.KnownNat.Solver #-} {-# OPTIONS_GHC -fplugin GHC.TypeLits.Normalise #-} -- | Assorted mostly high rank tensor tests. module TestHighRankSimplified (testTrees) where import Prelude import Data.Int (Int64) import GHC.Exts (IsList (..)) import GHC.TypeLits (KnownNat, type (+), type (-), type (<=)) import Test.Tasty import Test.Tasty.HUnit hiding (assert) import Data.Array.Nested qualified as Nested import Data.Array.Nested.Mixed.Shape import Data.Array.Nested.Ranked.Shape import Data.Array.Nested.Shaped.Shape import HordeAd import HordeAd.Core.Ops (tD, tfromPrimal) import CrossTesting import EqEpsilon testTrees :: [TestTree] testTrees = [ testCase "3foo" testFoo , testCase "3bar" testBar -- , testCase "3barS" testBarS , testCase "3fooD T Double [1.1, 2.2, 3.3]" testFooD , testCase "3fooBuild0" testFooBuild0 , testCase "3fooBuildOut" testFooBuildOut , testCase "3fooBuild91" testFooBuild91 , testCase "3fooBuild92" testFooBuild92 , testCase "3fooBuild21" testFooBuild21 , testCase "3fooBuild25" testFooBuild25 , testCase "3fooBuild21S" testFooBuild21S , testCase "3fooBuild25S" testFooBuild25S , testCase "3fooBuildNest21S" testFooBuildNest21S , testCase "3fooBuildNest25S" testFooBuildNest25S , testCase "3fooBuild3" testFooBuild3 , testCase "3fooBuildDt" testFooBuildDt , testCase "3fooBuildDt2" testFooBuildDt2 , testCase "3fooBuild5" testFooBuild5 , testCase "3fooBuild1" testFooBuild1 , testCase "3fooMap" testFooMap , testCase "3fooMap1" testFooMap1 , testCase "3fooNoGo" testFooNoGo , testCase "3fooNoGo10" testFooNoGo10 , testCase "3nestedBuildMap1" testNestedBuildMap1 , testCase "3nestedBuildMap10" testNestedBuildMap10 , testCase "3nestedBuildMap11" testNestedBuildMap11 -- , testCase "3nestedBuildMap7" testNestedBuildMap7 , testCase "3nestedSumBuild1" testNestedSumBuild1 -- , testCase "3nestedSumBuild5" testNestedSumBuild5 , testCase "3nestedSumBuildB" testNestedSumBuildB , testCase "3nestedBuildIndex" 
testNestedBuildIndex , testCase "3barReluADValDt" testBarReluADValDt , testCase "3barReluADValDt2" testBarReluADValDt2 , testCase "3barReluADVal" testBarReluADVal , testCase "3barReluADVal3" testBarReluADVal3 , testCase "3braidedBuilds" testBraidedBuilds , testCase "3braidedBuilds1" testBraidedBuilds1 , testCase "3recycled" testRecycled -- takes too long (can't be helped) , testCase "3recycled1" testRecycled1 , testCase "3concatBuild0" testConcatBuild0 , testCase "3concatBuild1" testConcatBuild1 , testCase "3concatBuild0m" testConcatBuild0m , testCase "3concatBuild1m" testConcatBuild1m , testCase "3concatBuild2" testConcatBuild2 , testCase "3concatBuild22" testConcatBuild22 , testCase "3concatBuild3" testConcatBuild3 , testCase "3logistic0" testLogistic0 , testCase "3logistic5" testLogistic5 , testCase "3logistic52" testLogistic52 , testCase "3logistic0Old" testLogistic0Old , testCase "3logistic5Old" testLogistic5Old , testCase "3logistic52Old" testLogistic52Old , testCase "3logisticA0" testLogisticA0 , testCase "3logisticB0" testLogisticB0 , testCase "3logisticC0" testLogisticC0 ] foo :: RealFloatH a => (a,a,a) -> a foo (x,y,z) = let w = x * sin y in atan2H z w + z * w _fooF :: RealFloatH a => (a,a,a) -> a _fooF (x,y,z) = let w = x * sin y in atan2H z w + z * w testFoo :: Assertion testFoo = assertEqualUpToEpsilon 1e-3 (ringestData [2,2,1, 2,2] [-4.6947093,1.5697206,-1.6332961,0.34882763,1.5697206,-1.0,-0.9784988,-0.9158946,6.6326222,3.6699238,7.85237,-2.9069107,17.976654,0.3914159,32.98194,19.807974], ringestData [2,2,1, 2,2] [6.943779,-1.436789,33.67549,0.22397964,-1.436789,-1.0,-0.975235,-0.90365005,147.06645,-73.022705,-9.238474,-10.042692,-980.2843,-7.900571,-14.451739,436.9084], ringestData [2,2,1, 2,2] [-4.8945336,2.067469,-1.7196897,1.3341143,2.067469,1.0,0.99846554,0.99536234,6.6943173,3.7482092,7.977362,-3.1475093,18.000969,0.48736274,33.01224,19.845064]) (grad (kfromR . rsum0 @5 @(TKScalar Float) . foo) (t16, t16, t16)) bar :: forall a. 
RealFloatH a => (a, a) -> a bar (x, y) = let w = foo (x, y, x) * sin y in atan2H x w + y * w _barF :: forall a. RealFloatH a => (a, a) -> a _barF (x, y) = let w = _fooF (x, y, x) * sin y in atan2H x w + y * w testBar :: Assertion testBar = assertEqualUpToEpsilon 1e-5 (Concrete $ Nested.rfromListLinear [3,1,2,2,1,2,2] [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1915,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596],Concrete $ Nested.rfromListLinear [3,1,2,2,1,2,2] [-5728.7617,24965.113,32825.07,-63505.953,-42592.203,145994.88,-500082.5,-202480.06,-5728.7617,24965.113,32825.07,-63505.953,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601002,-98.97709,2.1931143,-1.9601002,1.8243169,-4.0434446,-1.5266153,2020.9731,-538.0603,-84.28137,62.963814,-34986.996,-9.917454,135.30023,17741.998,-1.9601002,-1.9601002,-1.9601002,-1.9601002,-1.5266153,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-1.5266153,-1.5266153,-1.5266153,-4029.1775,-4029.1775,-4029.1775,-4029.1775]) (cgrad (kfromR . rsum0 . 
bar @(ADVal Concrete (TKR 7 Float))) (t48, t48)) {- TODO: divergent result; bring back when GHC 9.10 dropped: testBarS :: Assertion testBarS = assertEqualUpToEpsilon 1e-5 (sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [304.13867,914.9335,823.0187,1464.4688,5264.3306,1790.0055,1535.4309,3541.6572,304.13867,914.9335,823.0187,1464.4688,6632.4355,6047.113,1535.4309,1346.6815,45.92141,6.4903135,5.5406737,1.4242969,6.4903135,1.1458766,4.6446533,2.3550234,88.783676,27.467598,125.27507,18.177452,647.1915,0.3878851,2177.6152,786.1792,6.4903135,6.4903135,6.4903135,6.4903135,2.3550234,2.3550234,2.3550234,2.3550234,21.783596,2.3550234,2.3550234,2.3550234,21.783596,21.783596,21.783596,21.783596], sconcrete $ Nested.sfromListPrimLinear @_ @'[3, 1, 2, 2, 1, 2, 2] knownShS [-5728.761,24965.113,32825.074,-63505.957,-42592.203,145994.89,-500082.5,-202480.05,-5728.761,24965.113,32825.074,-63505.957,49494.473,-2446.7632,-500082.5,-125885.58,-43.092484,-1.9601007,-98.97708,2.1931143,-1.9601007,1.8243167,-4.0434446,-1.5266151,2020.9731,-538.06036,-84.28139,62.963818,-34986.992,-9.917454,135.3003,17741.996,-1.9601007,-1.9601007,-1.9601007,-1.9601007,-1.5266151,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-1.5266151,-1.5266151,-1.5266151,-4029.1775,-4029.1775,-4029.1775,-4029.1775]) (cgrad (kfromS . ssum0 . barF @(ADVal Concrete (TKS '[3, 1, 2, 2, 1, 2, 2] Float))) (sfromR t48, sfromR t48)) -} -- A dual-number and list-based version of a function that goes -- from `R^3` to `R`. fooD :: forall r n. 
(RealFloatH (ADVal Concrete (TKR n r))) => ListR 3 (ADVal Concrete (TKR n r)) -> ADVal Concrete (TKR n r) fooD (x ::: y ::: z ::: ZR) = let w = x * sin y in atan2H z w + z * w testFooD :: Assertion testFooD = assertEqualUpToEpsilon 1e-10 (fromList [ringestData [1,2,2,1,2,2,2,2,2,1] [18.73108960474591,20.665204824764675,25.821775835995922,18.666613887422585,34.775664100213014,62.54884873632415,37.93303229694526,11.635186977032971,18.73108960474591,20.665204824764675,25.821775835995922,20.600738734367262,34.775664100213014,62.54884873632415,16.663997008808924,3.1300339898598155,1.060799258653783,3.78942741815228,0.1889454555944933,-1.060799258653783,62.54884873632415,37.93303229694526,62.54884873632415,35.99996432769119,62.54884873632415,37.93303229694526,11.635186977032971,18.73108960474591,20.665204824764675,25.821775835995922,20.665204824764675,20.665204824764675,25.821775835995922,34.134947381491145,34.775664100213014,45527.22315787758,-4.488300547708207,2.1475176207684497,8.404498097344806,5.747373381623309,5.096832468946128,-2.4630526910399646,18.666613887422585,1.7769486222994448,-215.8115662030395,16.73214939773215,1.060799258653783,1.060799258653783,1.060799258653783,1.060799258653783,2.1475176207684497,2.1475176207684497,2.1475176207684497,2.1475176207684497,16.08742477551077,16.08742477551077,16.08742477551077,16.08742477551077,2.1475176207684497,2.1475176207684497,2.1475176207684497,2.1475176207684497,16.08742477551077,16.08742477551077,16.08742477551077,16.08742477551077,25.821775835995922,5.096832468946128,7.045006174919766,-1.7808956511653404,16.663997008744435,18.533999054066836,-25.177267779903083,16.60317012020362,25.821775835995922,5.096832468946128,7.045006174919766,-1.7808956511653404,16.663997008744435,18.533999054066836,-12.280721583745471,16.60317012020362,5.161956818274285,18.73108960474591,20.665204824764675,25.821775835995922,20.665204824764675,25.821775835995922,188.11000552192755,34.775664100213014,62.54884873632415,35.99996432769119,62.54
884873632415,55.32933980086011,62.54884873632415,55.32933980086011,11.635186977032971,18.73108960474591,20.665204824764675,25.821775835995922,20.665204824764675,25.821775835995922,20.665204824764675,25.821775835995922,14.152094926881784,34.775664100213014,62.54884873632415,53.39649491503442,62.54884873632415,14.72904006548922,62.54884873632415,37.93303229694526,11.635186977032971,18.73108960474591,20.665204824764675,25.821775835995922,20.665204824764675,25.821775835995922,20.665204824764675,25.821775835995922,57.33025874582143,34.775664100213014,62.54884873632415,36.64432517917614,62.54884873632415,34.06684929392724,62.54884873632415,35.99996432769119], ringestData [1,2,2,1,2,2,2,2,2,1] [647.1354943759653,787.5605199613974,1229.333367336918,642.6917612678424,2229.2701397674327,7210.705208776531,2652.3459120285806,250.02943073785886,647.1354943759653,787.5605199613974,1229.333367336918,782.6578815409038,2229.2701397674327,7210.705208776531,512.2982591657892,18.580536443699742,2.518850510725482,26.993800503829114,0.2243239488720164,2.518850510725482,7210.705208776531,2652.3459120285806,7210.705208776531,2388.9603285490866,7210.705208776531,2652.3459120285806,250.02943073785886,647.1354943759653,787.5605199613974,1229.333367336918,787.5605199613974,787.5605199613974,1229.333367336918,2147.9011858437157,2229.2701397674327,-0.5405182383359878,-0.5328698165396271,-0.5099245509210925,130.7140495214786,61.4116989316311,48.40938174779479,11.696956758139343,642.6917612678424,6.317020301049852,85833.87394976329,516.4928003659018,2.518850510725482,2.518850510725482,2.518850510725482,2.518850510725482,-0.5099245509210925,-0.5099245509210925,-0.5099245509210925,-0.5099245509210925,477.4973215160379,477.4973215160379,477.4973215160379,477.4973215160379,-0.5099245509210925,-0.5099245509210925,-0.5099245509210925,-0.5099245509210925,477.4973215160379,477.4973215160379,477.4973215160379,477.4973215160379,1229.333367336918,48.40938174779479,92.00538642301063,6.3430614471479245,512.298
2591618282,633.5999783697488,1168.7578661039847,508.56903530563443,1229.333367336918,48.40938174779479,92.00538642301063,6.3430614471479245,512.2982591618282,633.5999783697488,278.48156010484087,508.56903530563443,49.64077766932281,647.1354943759653,787.5605199613974,1229.333367336918,787.5605199613974,1229.333367336918,65212.963738386214,2229.2701397674327,7210.705208776531,2388.9603285490866,7210.705208776531,5642.338335044463,7210.705208776531,5642.338335044463,250.02943073785886,647.1354943759653,787.5605199613974,1229.333367336918,787.5605199613974,1229.333367336918,787.5605199613974,1229.333367336918,369.6431004072799,2229.2701397674327,7210.705208776531,5255.048317224881,7210.705208776531,400.3514287686239,7210.705208776531,2652.3459120285806,250.02943073785886,647.1354943759653,787.5605199613974,1229.333367336918,787.5605199613974,1229.333367336918,787.5605199613974,1229.333367336918,6057.774447242021,2229.2701397674327,7210.705208776531,2475.225838667682,7210.705208776531,2139.3419044407133,7210.705208776531,2388.9603285490866], ringestData [1,2,2,1,2,2,2,2,2,1] 
[18.76237979248771,20.69357069589509,25.8444826804669,18.698011972363496,34.7925278085306,62.558226125235436,37.948492946856575,11.685493300971446,18.76237979248771,20.69357069589509,25.8444826804669,20.629193248844963,34.7925278085306,62.558226125235436,16.699160877305292,3.3121428825170947,1.516071490296981,3.9411848287000124,1.0994899188808887,-1.516071490296981,62.558226125235436,37.948492946856575,62.558226125235436,36.01625479268449,62.558226125235436,37.948492946856575,11.685493300971446,18.76237979248771,20.69357069589509,25.8444826804669,20.69357069589509,20.69357069589509,25.8444826804669,34.1521274657041,34.7925278085306,-45527.22317076194,4.617144085155745,-2.4052046956635262,8.474005308282699,5.84854498865513,5.210650526856928,-2.6906888068615635,18.698011972363496,2.0810391881996813,-215.8142842462135,16.767170338627782,1.516071490296981,1.516071490296981,1.516071490296981,1.516071490296981,-2.4052046956635262,-2.4052046956635262,-2.4052046956635262,-2.4052046956635262,16.123846116986126,16.123846116986126,16.123846116986126,16.123846116986126,-2.4052046956635262,-2.4052046956635262,-2.4052046956635262,-2.4052046956635262,16.123846116986126,16.123846116986126,16.123846116986126,16.123846116986126,25.8444826804669,5.210650526856928,7.127782944309438,-2.0844104722608057,16.69916087724094,18.565621417897145,-25.200555362084323,16.638462541261234,25.8444826804669,5.210650526856928,7.127782944309438,-2.0844104722608057,16.69916087724094,18.565621417897145,-12.328394068734287,16.638462541261234,5.2743697149763085,18.76237979248771,20.69357069589509,25.8444826804669,20.69357069589509,25.8444826804669,188.113123824884,34.7925278085306,62.558226125235436,36.01625479268449,62.558226125235436,55.33994055377702,62.558226125235436,55.33994055377702,11.685493300971446,18.76237979248771,20.69357069589509,25.8444826804669,20.69357069589509,25.8444826804669,20.69357069589509,25.8444826804669,14.193483311576621,34.7925278085306,62.558226125235436,53.40747931617656,62.55
8226125235436,14.768811697198851,62.558226125235436,37.948492946856575,11.685493300971446,18.76237979248771,20.69357069589509,25.8444826804669,20.69357069589509,25.8444826804669,20.69357069589509,25.8444826804669,57.34048958248757,34.7925278085306,62.558226125235436,36.660329315674915,62.558226125235436,34.08406370302229,62.558226125235436,36.01625479268449]]) (cgrad (kfromR . rsum0 . fooD) (fromList [ t128 , rreplicate0N [1, 2, 2, 1, 2, 2, 2, 2, 2, 1] (rscalar (0.7 :: Double)) , t128 ])) fooBuild0 :: forall target r n. (ADReady target, GoodScalar r, KnownNat n) => target (TKR (1 + n) r) -> target (TKR (1 + n) r) fooBuild0 v = let r = rsum v in rbuild1 2 $ const r testFooBuild0 :: Assertion testFooBuild0 = assertEqualUpToEpsilon' 1e-10 (ringestData [2,2,1,2,2] [2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0]) (rev' @Double @5 fooBuild0 t16) fooBuildOut :: forall target r n. (ADReady target, GoodScalar r, KnownNat n) => target (TKR (1 + n) r) -> target (TKR (1 + n) r) fooBuildOut v = rbuild1 2 $ \ix -> ifH (ix ==. 0) (rindex v [ix + 1]) -- index out of bounds; guarded (rsum v) testFooBuildOut :: Assertion testFooBuildOut = assertEqualUpToEpsilon' 1e-10 (ringestData [2,2,1,2,2] [1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0,2.0]) (rev' @Double @5 fooBuildOut t16) fooBuild2 :: forall target r n. (ADReady target, GoodScalar r, KnownNat n, Floating (target (TKR n r)), RealFloat r) => target (TKR (1 + n) r) -> target (TKR (1 + n) r) fooBuild2 v = rbuild1 2 $ \ix' -> let ix :: PrimalOf target (TKS '[] Int64) ix = sfromR $ rfromK ix' in ifH (ix - (sprimalPart . sfloor . sfromR) (rsum0 @5 @(TKScalar r) $ rreplicate0N [5,12,11,9,4] (rsum0 v)) - sscalar 10001 >=. sscalar 0 &&* ix - (sprimalPart . sfloor . sfromR) (rsum0 @5 @(TKScalar r) @target $ rreplicate0N [5,12,11,9,4] (rsum0 v)) - sscalar 10001 <=. sscalar 1) (rindex v [kfromR $ rfromS $ ix - (sprimalPart . sfloor . 
sfromR) (rsum0 @5 @(TKScalar r) @target $ rreplicate0N [5,12,11,9,4] (rsum0 v)) - sscalar 10001]) -- index out of bounds; also fine (sqrt $ abs $ rindex v [kfromS $ let rr = (ix - (sfromR . rprimalPart . rfloor) (rsum0 v) - sscalar 10001) `remH` sscalar 2 in ifH (signum rr ==. negate (signum $ sscalar 2)) (rr + sscalar 2) rr]) fooBuild2L :: forall k target r n. (ADReady target, GoodScalar r, KnownNat n, Floating (target (TKR n r)), RealFloat r) => ListR k (target (TKR (1 + n) r)) -> target (TKR (1 + n) r) fooBuild2L = foldr1 (+) . fmap fooBuild2 testFooBuild91 :: Assertion testFooBuild91 = assertEqualUpToEpsilon 1e-8 (fromList $ map (ringestData [2]) [[1.5811388300841895,1.118033988749895],[1.118033988749895,0.9128709291752769],[0.9128709291752769,0.7905694150420948],[0.7905694150420948,0.7071067811865475],[0.7071067811865475,0.6454972243679028],[0.6454972243679028,0.5976143046671968],[0.5976143046671968,0.5590169943749475],[0.5590169943749475,0.5270462766947299],[0.5270462766947299,1.5811388300841895],[1.5811388300841895,0.5270462766947299],[1.5811388300841895,1.118033988749895],[1.118033988749895,0.9128709291752769],[0.9128709291752769,0.7905694150420948],[0.7905694150420948,0.7071067811865475],[0.7071067811865475,0.6454972243679028],[0.6454972243679028,0.5976143046671968],[0.5976143046671968,0.5590169943749475],[0.5590169943749475,0.5270462766947299],[0.5270462766947299,1.5811388300841895],[1.5811388300841895,0.5270462766947299],[1.5811388300841895,1.118033988749895],[1.118033988749895,0.9128709291752769],[0.9128709291752769,0.7905694150420948],[0.7905694150420948,0.7071067811865475],[0.7071067811865475,0.6454972243679028],[0.6454972243679028,0.5976143046671968],[0.5976143046671968,0.5590169943749475],[0.5590169943749475,0.5270462766947299],[0.5270462766947299,1.5811388300841895],[1.5811388300841895,0.5270462766947299],[1.5811388300841895,1.118033988749895],[1.118033988749895,0.9128709291752769],[0.9128709291752769,0.7905694150420948],[0.7905694150420948,0.707106
7811865475],[0.7071067811865475,0.6454972243679028],[0.6454972243679028,0.5976143046671968],[0.5976143046671968,0.5590169943749475],[0.5590169943749475,0.5270462766947299],[0.5270462766947299,1.5811388300841895],[1.5811388300841895,0.5270462766947299],[1.5811388300841895,1.118033988749895],[1.118033988749895,0.9128709291752769],[0.9128709291752769,0.7905694150420948],[0.7905694150420948,0.7071067811865475],[0.7071067811865475,0.6454972243679028],[0.6454972243679028,0.5976143046671968],[0.5976143046671968,0.5590169943749475],[0.5590169943749475,0.5270462766947299],[0.5270462766947299,1.5811388300841895],[1.5811388300841895,0.5270462766947299]]) (cgrad (kfromR . rsum0 @1 . fooBuild2L @50 @(ADVal Concrete) @Double @0) (fromList $ map (ringestData [2]) [[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9]])) testFooBuild92 :: Assertion testFooBuild92 = assertEqualUpToEpsilon 1e-8 (fromList $ map (ringestData [2]) 
[[1.5811388300841895,1.118033988749895],[1.118033988749895,0.9128709291752769],[0.9128709291752769,0.7905694150420948],[0.7905694150420948,0.7071067811865475],[0.7071067811865475,0.6454972243679028],[0.6454972243679028,0.5976143046671968],[0.5976143046671968,0.5590169943749475],[0.5590169943749475,0.5270462766947299],[0.5270462766947299,1.5811388300841895],[1.5811388300841895,0.5270462766947299],[1.5811388300841895,1.118033988749895],[1.118033988749895,0.9128709291752769],[0.9128709291752769,0.7905694150420948],[0.7905694150420948,0.7071067811865475],[0.7071067811865475,0.6454972243679028],[0.6454972243679028,0.5976143046671968],[0.5976143046671968,0.5590169943749475],[0.5590169943749475,0.5270462766947299],[0.5270462766947299,1.5811388300841895],[1.5811388300841895,0.5270462766947299],[1.5811388300841895,1.118033988749895],[1.118033988749895,0.9128709291752769],[0.9128709291752769,0.7905694150420948],[0.7905694150420948,0.7071067811865475],[0.7071067811865475,0.6454972243679028],[0.6454972243679028,0.5976143046671968],[0.5976143046671968,0.5590169943749475],[0.5590169943749475,0.5270462766947299],[0.5270462766947299,1.5811388300841895],[1.5811388300841895,0.5270462766947299],[1.5811388300841895,1.118033988749895],[1.118033988749895,0.9128709291752769],[0.9128709291752769,0.7905694150420948],[0.7905694150420948,0.7071067811865475],[0.7071067811865475,0.6454972243679028],[0.6454972243679028,0.5976143046671968],[0.5976143046671968,0.5590169943749475],[0.5590169943749475,0.5270462766947299],[0.5270462766947299,1.5811388300841895],[1.5811388300841895,0.5270462766947299],[1.5811388300841895,1.118033988749895],[1.118033988749895,0.9128709291752769],[0.9128709291752769,0.7905694150420948],[0.7905694150420948,0.7071067811865475],[0.7071067811865475,0.6454972243679028],[0.6454972243679028,0.5976143046671968],[0.5976143046671968,0.5590169943749475],[0.5590169943749475,0.5270462766947299],[0.5270462766947299,1.5811388300841895],[1.5811388300841895,0.5270462766947299]]) (grad 
(kfromR . rsum0 @1 . fooBuild2L @50 @(AstTensor AstMethodLet FullSpan) @Double @0) (fromList $ map (ringestData [2]) [[0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9], [0.1, 0.2], [0.2, 0.3], [0.3, 0.4], [0.4, 0.5], [0.5, 0.6], [0.6, 0.7], [0.7, 0.8], [0.8, 0.9], [0.9, 0.1], [0.1, 0.9]])) testFooBuild21 :: Assertion testFooBuild21 = assertEqualUpToEpsilon' 1e-10 (ringestData [2] [0.2886751345948129,0.35355339059327373]) (rev' @Double @1 fooBuild2 (ringestData [2] [3.0,2.0])) testFooBuild25 :: Assertion testFooBuild25 = assertEqualUpToEpsilon' 1e-10 (ringestData [2,2,1,2,2] [0.22360679774997896,0.35355339059327373,0.20412414523193154,0.5,-0.35355339059327373,500.0,1.5811388300841895,-1.118033988749895,0.1381447409988844,0.16666666666666666,0.17677669529663687,-0.25,8.574929257125441e-2,0.288948802391873,-8.703882797784893e-2,9.805806756909202e-2]) (rev' @Double @5 fooBuild2 t16) fooBuild2S :: forall k sh target r. (ADReady target, GoodScalar r, KnownNat k, Floating (target (TKS sh r)), RealFloat r, KnownShS sh) => target (TKS (k : sh) r) -> target (TKR (1 + Rank sh) r) fooBuild2S v = rfromS $ sbuild1 @2 $ \ix' -> let ix :: PrimalOf target (TKS '[] Int64) ix = sfromR $ rfromK ix' in ifH (ix - (sprimalPart . sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r) $ sreplicate0N @[5,12,11,9,4] (ssum0 v)) - srepl 10001 >=. srepl 0 &&* ix - (sprimalPart . sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r) $ sreplicate0N @[5,12,11,9,4] (ssum0 v)) - srepl 10001 <=. srepl 1) (sindex v ((kfromS $ ix - (sprimalPart . 
sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r) $ sreplicate0N @[5,12,11,9,4] (ssum0 v)) - srepl 10001) :.$ ZIS )) -- index out of bounds; also fine (sqrt $ abs $ sindex v ((kfromR $ rfromS $ let rr = (ix - (sprimalPart . sfloor) (ssum0 v) - srepl 10001) `remH` srepl 2 in ifH (signum rr ==. negate (signum $ srepl 2)) (rr + srepl 2) rr) :.$ ZIS)) testFooBuild21S :: Assertion testFooBuild21S = assertEqualUpToEpsilon' 1e-10 (ringestData [2] [0.2886751345948129,0.35355339059327373]) (rev' @Double @1 (fooBuild2S @2 @'[] . sfromR) (ringestData [2] [3.0,2.0])) testFooBuild25S :: Assertion testFooBuild25S = assertEqualUpToEpsilon' 1e-10 (ringestData [2,2,1,2,2] [0.22360679774997896,0.35355339059327373,0.20412414523193154,0.5,-0.35355339059327373,500.0,1.5811388300841895,-1.118033988749895,0.1381447409988844,0.16666666666666666,0.17677669529663687,-0.25,8.574929257125441e-2,0.288948802391873,-8.703882797784893e-2,9.805806756909202e-2]) (rev' @Double @5 (fooBuild2S @2 @[2, 1, 2, 2] . sfromR) t16) fooBuildNest2S :: forall k sh target r. (ADReady target, GoodScalar r, KnownNat k, Floating (target (TKS sh r)), RealFloat r, KnownShS sh) => target (TKS (k : sh) r) -> target (TKR (1 + Rank sh) r) fooBuildNest2S v = rfromS $ sbuild1 @2 $ \ix' -> let ix :: PrimalOf target (TKS '[] Int64) ix = sfromR $ rfromK ix' in ifH (ix - (sunNest @_ @'[] @'[] . sprimalPart . snest knownShS . sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r) $ sreplicate0N @[5,12,11,9,4] (ssum0 v)) - srepl 10001 >=. srepl 0 &&* ix - (sprimalPart . sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r) $ sreplicate0N @[5,12,11,9,4] (ssum0 v)) - srepl 10001 <=. srepl 1) -- TODO: (sindex v (ShapedList.singletonIndex (ix - (sprimalPart . sfloor) (ssum0 @[5,12,11,9,4] @r $ sunNest $ treplicate (SNat @5) knownSTK $ snest (knownShS @[12,11]) (sindex v ((kfromR $ rfromS $ ix - (sprimalPart . 
sfloor) (ssum0 @[5,12,11,9,4] @(TKScalar r) @target $ sunNest $ tproject2 $ tfromPrimal knownSTK $ tpair tunit (sprimalPart $ snest (knownShS @[5,12,11]) $ sreplicate0N @[5,12,11,9,4] (ssum0 v))) - srepl 10001) :.$ ZIS)) -- index out of bounds; also fine -- TODO: (sunNest @_ @'[] @sh $ tlet (snest (knownShS @'[]) $ (sfromPrimal ix - sfloor (ssum0 v) - srepl 10001) `remH` srepl 2) $ \rr -> snest (knownShS @'[]) $ sqrt $ abs $ sindex v (ShapedList.singletonIndex (ifH (signum (sprimalPart (sunNest rr)) ==. negate (signum $ srepl 2)) (sprimalPart (sunNest rr) + srepl 2) (sprimalPart (sunNest rr))))) (sunNest @_ @'[] @sh $ tlet ((sfromPrimal ix - sfloor (ssum0 v) - srepl 10001) `remH` srepl 2) $ \rr -> snest (knownShS @'[]) $ sqrt $ abs $ sindex v ((kfromS $ ifH (signum (sprimalPart rr) ==. negate (signum $ srepl 2)) (sprimalPart rr + srepl 2) (sprimalPart rr)) :.$ ZIS)) testFooBuildNest21S :: Assertion testFooBuildNest21S = assertEqualUpToEpsilon' 1e-10 (ringestData [2] [0.2886751345948129,0.35355339059327373]) (rev' @Double @1 (fooBuildNest2S @2 @'[] . sfromR) (ringestData [2] [3.0,2.0])) testFooBuildNest25S :: Assertion testFooBuildNest25S = assertEqualUpToEpsilon' 1e-10 (ringestData [2,2,1,2,2] [0.22360679774997896,0.35355339059327373,0.20412414523193154,0.5,-0.35355339059327373,500.0,1.5811388300841895,-1.118033988749895,0.1381447409988844,0.16666666666666666,0.17677669529663687,-0.25,8.574929257125441e-2,0.288948802391873,-8.703882797784893e-2,9.805806756909202e-2]) (rev' @Double @5 (fooBuildNest2S @2 @[2, 1, 2, 2] . sfromR) t16) fooBuild3 :: forall target r n. 
-- Remainder of the 'fooBuild3' type signature (its head precedes this line).
  ( ADReady target, GoodScalar r, KnownNat n, RealFloatH (target (TKR n r)) )
  => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
-- Builds 22 rows; each applies 'bar' to a ones tensor shaped like one row of
-- @v@ and to a row of @v@. For every build index @ix >= 0@ the index
-- @minH 1 (ix + 1)@ equals 1, so only row 1 of @v@ is read (matching the
-- zero gradient for row 0 expected by 'testFooBuild3').
fooBuild3 v =
  rbuild1 22 $ \ix ->
    bar ( rreplicate0N (shrTail $ rshape v) (rscalar 1)
        , rindex v [minH 1 (ix + 1)] )  -- index not out of bounds

-- | Gradient of 'fooBuild3' at @t16@, checked with @rev'@
-- (NOTE(review): presumably the CrossTesting helper comparing several
-- interpretations — confirm).
testFooBuild3 :: Assertion
testFooBuild3 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [2,2,1,2,2]
       [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,
        423.72976235076516,-260.41676627885636,-17.60047532855961,151.18955028869385,
        -1059.9668424433578,-65.00898015327623,-21.49245448729951,743.7622427949768])
    (rev' @Double @5 fooBuild3 t16)

-- | Mixes 'foo' and 'bar' over the row sum @r = rsum v@ and the minimum
-- element of @v@, building 2 rows. The index @minH 1 (ix + 1)@ always
-- selects row 1 of @v@.
fooBuild5 :: forall target r n.
             ( ADReady target, GoodScalar r, KnownNat n, RealFloatH (target (TKR n r)) )
          => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
fooBuild5 v =
  let r = rsum v
      v' = rreplicate0N (shrTail $ rshape v) $ rminimum $ rflatten v
  in rbuild1 2 $ \ix ->
       r * foo ( rreplicate0N (shrTail $ rshape v) (rscalar 3)
               , rrepl (rshape r) 5 * r
               , r * v')
       + bar (r, rindex v [minH 1 (ix + 1)])  -- index not out of bounds

-- | Vector-Jacobian product of 'fooBuild5' at @t16@ with a cotangent of 42s.
testFooBuildDt :: Assertion
testFooBuildDt =
  assertEqualUpToEpsilon 1e-5
    (rconcrete $ Nested.rfromListPrimLinear [2,2,1,2,2]
       [1.1033568028244503e7,74274.22833989389,-5323238.2765011545,253074.03394016018,
        4.14744804041263e7,242643.98750578283,-1.922371592087736e7,2.730274503834733e7,
        1.135709425204681e7,6924.195066252549,-5345004.080027547,255679.51406100337,
        3.8870981856703006e7,241810.92121468345,-1.9380955730171032e7,2.877024321777493e7])
    (vjp @_ @(TKR 5 Double) fooBuild5 t16 (rreplicate0N [2, 2, 1, 2, 2] (rscalar 42)))

-- | Like 'testFooBuildDt', but the function returns the pair @(y, y)@ and the
-- cotangent is a pair of the same 42s, so the expected values are twice
-- those of 'testFooBuildDt'.
testFooBuildDt2 :: Assertion
testFooBuildDt2 =
  assertEqualUpToEpsilon 1e-5
    (rconcrete $ Nested.rfromListPrimLinear [2,2,1,2,2]
       [2.206713605648901e7,148548.45667978778,-1.0646476553002307e7,506148.0678803204,
        8.294896080825263e7,485287.9750115657,-3.844743184175473e7,5.460549007669466e7,
        2.271418850409362e7,13848.390132505112,-1.0690008160055092e7,511359.0281220066,
        7.774196371340603e7,483621.8424293669,-3.876191146034207e7,5.754048643554987e7])
    (vjp @_ @(TKProduct (TKR 5 Double) (TKR 5 Double))
       (\x -> let y = fooBuild5 x in tpair y y)
       t16
       (let dt = rreplicate0N [2, 2, 1, 2, 2] (rscalar 42) in tpair dt dt))

-- | Gradient of 'fooBuild5' at the rank-7 input @t48@.
testFooBuild5 :: Assertion
testFooBuild5 =
  assertEqualUpToEpsilon' 1e-5
    (ringestData [3,1,2,2,1,2,2]
       [-613291.6547530327,571164.2201603781,-1338602.6247083102,528876.2566682736,1699442.2143691683,2874891.369778316,-3456754.605470273,3239487.8744244366,
        554916.1344235454,-775449.1803684114,3072.200583200206,1165767.8436804386,-1.0686356667942494e7,-6606976.194539241,-6457671.748790982,4791868.42112978,
        -615556.7946425928,569660.3506343022,-1348678.1169100606,534886.9366492515,1696036.143341285,2883992.9672165257,-3456212.5353846983,3240296.690514803,
        629047.8398075115,-794389.5797803313,-1143.8025173051583,1177448.8083517442,-1.15145721735623e7,-6618648.839812404,-6462386.031613377,5358224.852822481,
        -613291.6547530327,571164.2201603781,-1338602.6247083102,528876.2566682736,1699442.2143691683,2874891.369778316,-3456754.605470273,3239487.8744244366,
        554916.1344235454,-775449.1803684114,3072.200583200206,1165767.8436804386,-1.0686356667942494e7,-6606976.194539241,-6457671.748790982,4791868.42112978])
    (rev' @Double @7 fooBuild5 t48)

-- | Variant of 'fooBuild5' that builds 3 rows; @tk@ replicates a scalar to
-- the shape of one row of @v@.
fooBuild1 :: forall target r n.
             ( ADReady target, GoodScalar r, KnownNat n, RealFloatH (target (TKR n r)) )
          => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
fooBuild1 v =
  let r = rsum v
      tk = rreplicate0N (shrTail $ rshape v)
      v' = tk $ rminimum $ rflatten v
  in rbuild1 3 $ \ix ->
       r * foo ( tk (rscalar 3)
               , tk (rscalar 5) * r
               , r * v')
       + bar (r, rindex v [minH 1 (ix + 1)])

-- | Gradient of 'fooBuild1' at @t16@.
testFooBuild1 :: Assertion
testFooBuild1 =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2,2,1,2,2]
       [394056.00100873224,2652.651012139068,-190115.65273218407,9038.358355005721,
        1481231.4430045108,8665.8566966351,-686561.2828884773,975098.0370838332,
        405610.50900167174,247.29268093759174,-190893.00285812665,9131.411216464405,
        1388249.3520251075,8636.104329095837,-692176.9903632513,1027508.6863491047])
    (rev' @Double @5 fooBuild1 t16)

-- | Applies 'fooBuild1' to a tensor of shape @sh@ filled with @r * r@, then
-- maps @\\x -> x * r + 5@ over every element.
fooMap1 :: (ADReady target, GoodScalar r, KnownNat n, Differentiable r)
        => IShR (1 + n) -> target (TKR 0 r) -> target (TKR (1 + n) r)
fooMap1 sh r =
  let v = fooBuild1 $ rreplicate0N sh (r * r)
  in rmap0N (\x -> x * r + rscalar 5) v

-- | Gradient of @fooMap1 [130]@ at 0.1, in 'Float'.
testFooMap :: Assertion
testFooMap =
  assertEqualUpToEpsilon' 1e-3
    (rscalar 2.7518227)
    (rev' @Float @1 (fooMap1 [130]) (rscalar 0.1))

-- | Gradient of the total sum of 'fooMap1' at a rank-7 shape.
-- Reduced test, because this takes forever with Ast but without vectorization.
testFooMap1 :: Assertion
testFooMap1 =
  assertEqualUpToEpsilon 1e-6
    (rscalar 3901.312463734578)
    (grad (kfromR @_ @Double . rsum0 @7 . fooMap1 [4, 3, 2, 3, 4, 5, 3]) (rscalar 0.1))

-- | Builds 3 rows mixing nested 'bar' calls with a data-dependent 'ifH',
-- divides by @rslice 1 3@ of @v@ with every element clamped to at most
-- @rsum0 v@, and multiplies by a 3-row tensor of ones. Row @ix * 2@ of @v@
-- is read for @ix@ in 0..2, i.e. rows 0, 2 and 4 (both tests below use
-- 5-row inputs).
fooNoGo :: forall target r n.
           ( ADReady target, GoodScalar r, KnownNat n, Differentiable r )
        => target (TKR (1 + n) r) -> target (TKR (1 + n) r)
fooNoGo v =
  let r = rsum v
      r0 = rsum0 v
      shTail = shrTail (rshape v)
  in rbuild1 3 (\ix -> bar ( rreplicate0N shTail (rscalar 3.14)
                           , bar ( rrepl shTail 3.14
                                 , rindex v [ix]) )
                       + ifH (rindex v (ix * 2 :.: ZIR)
                                <=. rreplicate0N shTail (rscalar 0)
                              &&* 6 >. abs ix)  -- second conjunct holds for all ix in 0..2
                             r (rreplicate0N shTail (rscalar 5) * r))
     / rslice 1 3 (rmap0N (\x -> ifH (x >. r0) r0 x) v)
     * rbuild1 3 (const $ rrepl shTail 1)

-- | Gradient of 'fooNoGo' at a 5-element vector.
testFooNoGo :: Assertion
testFooNoGo =
  assertEqualUpToEpsilon' 1e-6
    (ringestData [5]
       [344.3405885672822,-396.1811403813819,7.735358041386672,-0.8403418295960372,5.037878787878787])
    (rev' @Double @1 fooNoGo (ringestData [5] [1.1 :: Double, 2.2, 3.3, 4, 5]))

-- | Gradient of the (1e-9-scaled) total sum of 'fooNoGo' over 5 copies of
-- @t48@ scaled by 0.01.
testFooNoGo10 :: Assertion
testFooNoGo10 =
  assertEqualUpToEpsilon 1e-10
    (ringestData [5, 3, 1, 2, 2, 1, 2, 2]
       [8.096867407436072e-8,9.973025492756426e-8,9.976696178938985e-8,5.614458707681111e-8,-1.8338500573636686e-7,-2.144970334428336e-7,7.354143606421902e-7,-1.8140041785503643e-7,
        8.096867407436072e-8,9.973025492756426e-8,9.976696178938985e-8,5.614458707681111e-8,-2.01381292700262e-7,-2.221588091014473e-7,7.354143606421902e-7,-1.9951065225263367e-7,
        1.7230532848112822e-7,4.5426218104870796e-7,1.430886696893587e-7,9.354993295163118e-7,-5.225515010723883e-7,1.019433073376504e-6,9.64067025472343e-6,-4.872227980305747e-6,
        8.089200625992941e-8,9.924319994964371e-8,1.092480101004153e-7,-2.8478802468285825e-7,9.641049518625974e-8,2.9624147815716037e-7,-1.950868158558337e-7,9.547754822865364e-8,
        4.5426218104870796e-7,4.5426218104870796e-7,4.5426218104870796e-7,4.5426218104870796e-7,-4.872227980305747e-6,-4.872227980305747e-6,-4.872227980305747e-6,-4.872227980305747e-6,
        9.361277121832246e-8,-4.872227980305747e-6,-4.872227980305747e-6,-4.872227980305747e-6,9.361277121832246e-8,9.361277121832246e-8,9.361277121832246e-8,9.361277121832246e-8,
        -5.488572216677945e-7,-1.8496203182958057e-7,-1.4603644180845103e-7,-1.2145268106051633e-7,-2.817402689957553e-7,-2.9913537180597976e-7,6.272804203945257e-7,-2.3697344464172694e-7,
        -5.488572216677945e-7,-1.8496203182958057e-7,-1.4603644180845103e-7,-1.2145268106051633e-7,-2.613973017956691e-7,-3.0013408634207794e-7,6.272804203945257e-7,-2.916736028401805e-7,
        -7.0114505846358575e-6,-4.303381366239431e-5,-4.897282418246382e-6,-1.710952247892854e-4,-4.2040039667393255e-5,-2.0204742564752248e-4,-1.7017980671040968e-2,-4.247008401789142e-3,
        -1.056090348050961e-6,-2.210187184450231e-6,-2.7842041329045203e-6,-1.0402806498987974e-5,-1.2967382896879757e-7,-1.9315601705070884e-5,-2.40087090725031e-7,-2.4419692405172046e-7,
        -4.303381366239431e-5,-4.303381366239431e-5,-4.303381366239431e-5,-4.303381366239431e-5,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,
        -2.683138631810477e-7,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,-2.683138631810477e-7,-2.683138631810477e-7,-2.683138631810477e-7,-2.683138631810477e-7,
        -5.488572216677945e-7,-1.8496203182958057e-7,-1.4603644180845103e-7,-1.2145268106051633e-7,-2.817402689957553e-7,-2.9913537180597976e-7,6.272804203945257e-7,-2.3697344464172694e-7,
        -5.488572216677945e-7,-1.8496203182958057e-7,-1.4603644180845103e-7,-1.2145268106051633e-7,-2.613973017956691e-7,-3.0013408634207794e-7,6.272804203945257e-7,-2.916736028401805e-7,
        -7.0114505846358575e-6,-4.303381366239431e-5,-4.897282418246382e-6,-1.710952247892854e-4,-4.2040039667393255e-5,-2.0204742564752248e-4,-1.7017980671040968e-2,-4.247008401789142e-3,
        -1.056090348050961e-6,-2.210187184450231e-6,-2.7842041329045203e-6,-1.0402806498987974e-5,-1.2967382896879757e-7,-1.9315601705070884e-5,-2.40087090725031e-7,-2.4419692405172046e-7,
        -4.303381366239431e-5,-4.303381366239431e-5,-4.303381366239431e-5,-4.303381366239431e-5,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,
        -2.683138631810477e-7,-4.247008401789142e-3,-4.247008401789142e-3,-4.247008401789142e-3,-2.683138631810477e-7,-2.683138631810477e-7,-2.683138631810477e-7,-2.683138631810477e-7,
        -5.469529675653596e-7,-2.331458950045675e-7,-1.9907443163522408e-7,-1.4019078434680374e-7,-6.95091094132346e-8,-5.685763846730528e-8,-9.268594848659335e-8,-3.010367762029461e-8,
        -5.469529675653596e-7,-2.331458950045675e-7,-1.9907443163522408e-7,-1.4019078434680374e-7,-3.415394012988984e-8,-5.069973314807702e-8,-9.268594848659335e-8,-6.380451815099858e-8,
        -6.883755913116986e-6,-4.273807584344302e-5,-4.79037108793574e-6,-1.705307241188017e-4,-4.2267488166320864e-5,-2.0143642393829028e-4,-1.701262134129569e-2,-4.2496361738088365e-3,
        -1.0224785375169973e-6,-2.1427637177332083e-6,-2.705952143004936e-6,-1.0493018474305117e-5,-1.819666770962338e-7,-1.911089472080586e-5,-9.045482032374276e-8,-2.819821645880664e-7,
        -4.273807584344302e-5,-4.273807584344302e-5,-4.273807584344302e-5,-4.273807584344302e-5,-4.2496361738088365e-3,-4.2496361738088365e-3,-4.2496361738088365e-3,-4.2496361738088365e-3,
        -3.019273543907303e-7,-4.2496361738088365e-3,-4.2496361738088365e-3,-4.2496361738088365e-3,-3.019273543907303e-7,-3.019273543907303e-7,-3.019273543907303e-7,-3.019273543907303e-7,
        8.287292817679557e-8,5.154639175257732e-8,4.672897196261682e-8,3.740648379052369e-8,2.884615384615385e-8,2.7780699895840894e-8,1.5447991761071065e-8,2.546934916639589e-8,
        8.287292817679557e-8,5.154639175257732e-8,4.672897196261682e-8,3.740648379052369e-8,2.5862068965517245e-8,2.7275544092553562e-8,1.5447991761071065e-8,2.8358432436548274e-8,
        3.0000000000000004e-7,7.500000000000001e-7,2.5000000000000004e-7,1.5000000000000002e-6,-7.500000000000001e-7,1.6304347826086957e-6,1.5000000000000002e-5,-7.500000000000001e-6,
        1.1450381679389314e-7,1.6666666666666668e-7,1.8750000000000003e-7,-3.7500000000000006e-7,4.411764705882353e-8,5.00948462422186e-7,-4.545454545454546e-8,5.76923076923077e-8,
        7.500000000000001e-7,7.500000000000001e-7,7.500000000000001e-7,7.500000000000001e-7,-7.500000000000001e-6,-7.500000000000001e-6,-7.500000000000001e-6,-7.500000000000001e-6,
        5.999928000863991e-8,-7.500000000000001e-6,-7.500000000000001e-6,-7.500000000000001e-6,5.999928000863991e-8,5.999928000863991e-8,5.999928000863991e-8,5.999928000863991e-8])
    (grad (kfromR @_ @Double . rsum0 @8 . rmap0N (* rscalar 0.000000001) . fooNoGo)
          (rmap0N (* rscalar 0.01) $ rreplicate 5 t48))

-- | Nested builds and maps over a scalar @r@. The result has shape @3@
-- followed by the first @n@ dimensions of @[2,4,2,1,3,2]@ (taken with
-- @shrTake \@n \@(6 - n)@, hence @n <= 6@). @variableLengthBuild iy@ reads 7
-- consecutive elements of @v'@ starting at @iy@.
nestedBuildMap :: forall target n r.
                  (ADReady target, GoodScalar r, n <= 6, KnownNat n, Differentiable r)
               => target (TKR 0 r) -> target (TKR (1 + n) r)
nestedBuildMap r =
  let w x = rreplicate0N [4] x :: target (TKR 1 r)
      v' = rreplicate0N (177 :$: ZSR) r
      nestedMap x = rmap0N (x /) (w x)
      variableLengthBuild iy = rbuild1 7 (\ix -> rindex v' (ix + iy :.: ZIR))
      doublyBuild =
        rbuild1 3 (rreplicate0N (shrTake @n @(6 - n)
                                 $ 2 :$: 4 :$: 2 :$: 1 :$: 3 :$: 2 :$: ZSR)
                   . rminimum . variableLengthBuild)
  in rmap0N (\x -> x * rsum0 (rbuild1 3 (\ix -> bar (x, rindex v' [ix]))
                              + fooBuild1 (nestedMap x)
                              / fooMap1 [3] x)
            ) doublyBuild

-- | Gradient of 'nestedBuildMap' (rank-2 result) at 0.6.
testNestedBuildMap1 :: Assertion
testNestedBuildMap1 =
  assertEqualUpToEpsilon' 1e-8
    (rscalar 22.673212907588812)
    (rev' @Double @1 nestedBuildMap (rscalar 0.6))

-- | 'cgrad' of the total sum of 'nestedBuildMap' at 0.1, 0.2 .. 1,
-- repeated three times.
testNestedBuildMap10 :: Assertion
testNestedBuildMap10 =
  assertEqualUpToEpsilon 1e-8
    (map rscalar
       [109.62086996459126,106.70290239773645,103.05843225947055,98.11825678264942,67.8014491889543,
        22.67321290758882,-163.40832575807545,376.4240286600336,-1996.9068313949347,249.28292226561257,
        109.62086996459126,106.70290239773645,103.05843225947055,98.11825678264942,67.8014491889543,
        22.67321290758882,-163.40832575807545,376.4240286600336,-1996.9068313949347,249.28292226561257,
        109.62086996459126,106.70290239773645,103.05843225947055,98.11825678264942,67.8014491889543,
        22.67321290758882,-163.40832575807545,376.4240286600336,-1996.9068313949347,249.28292226561257])
    (map (cgrad (kfromR . rsum0 @1 @(TKScalar Double) . nestedBuildMap))
         (map (Concrete . Nested.rscalar)
          $ [0.1, 0.2 .. 1] ++ [0.1, 0.2 .. 1] ++ [0.1, 0.2 .. 1]))

-- | Same inputs and expected values as 'testNestedBuildMap10', but using
-- 'grad' instead of 'cgrad'.
testNestedBuildMap11 :: Assertion
testNestedBuildMap11 =
  assertEqualUpToEpsilon 1e-8
    (map rscalar
       [109.62086996459126,106.70290239773645,103.05843225947055,98.11825678264942,67.8014491889543,
        22.67321290758882,-163.40832575807545,376.4240286600336,-1996.9068313949347,249.28292226561257,
        109.62086996459126,106.70290239773645,103.05843225947055,98.11825678264942,67.8014491889543,
        22.67321290758882,-163.40832575807545,376.4240286600336,-1996.9068313949347,249.28292226561257,
        109.62086996459126,106.70290239773645,103.05843225947055,98.11825678264942,67.8014491889543,
        22.67321290758882,-163.40832575807545,376.4240286600336,-1996.9068313949347,249.28292226561257])
    (map (grad (kfromR . rsum0 @1 @(TKScalar Double) . nestedBuildMap))
         (map (Concrete . Nested.rscalar)
          $ [0.1, 0.2 .. 1] ++ [0.1, 0.2 .. 1] ++ [0.1, 0.2 .. 1]))

{- testNestedBuildMap7 :: Assertion
testNestedBuildMap7 =
  assertEqualUpToEpsilon' 1e-8
    (rscalar 2176.628439128524)
    (rev' @Double @7 nestedBuildMap (rscalar 0.6))
-}

-- The n <= 4 is necessary despite what GHC claims. Applying @(2 + n)
-- to nestedBuildMap doesn't help.
-- (nestedBuildMap's result is indexed with 3 indices to yield rank n, so it
-- is used at its own n = 2 + n, and its n <= 6 constraint becomes n <= 4.)
--
-- | Builds a 13x4 grid. Where @ix2 > ix1@ it uses a scaled @sqrt . abs@ of
-- 'nestedBuildMap' at @rsum0 v@; elsewhere 'nestedBuildMap' at the constant
-- 0.00042, which does not depend on @v@.
nestedSumBuild :: forall target n r.
                  (ADReady target, GoodScalar r, n <= 4, KnownNat n, Differentiable r)
               => target (TKR n r) -> target (TKR (2 + n) r)
nestedSumBuild v =
  rbuild1 13 $ \ix1 ->
  rbuild1 4 $ \ix2 ->
    ifH (ix2 >. ix1)
        (rmap0N ((* rscalar (-0.00000003)) . sqrt . abs)
         $ nestedBuildMap (rsum0 v)
           `rindex` (ix2 `remH` 3 :.: minH 1 ix1 :.: minH ix1 3 :.: ZIR))
        (nestedBuildMap (rscalar 0.00042)
         `rindex` (ix2 `remH` 3 :.: minH 1 ix1 :.: minH ix1 3 :.: ZIR))

-- | Gradient of the total sum of 'nestedSumBuild'. @v@ enters only through
-- @rsum0 v@, so all five gradient components are equal.
testNestedSumBuild1 :: Assertion
testNestedSumBuild1 =
  assertEqualUpToEpsilon 1e-6
    (ringestData [5]
       [5.738943380972744e-6,5.738943380972744e-6,5.738943380972744e-6,5.738943380972744e-6,5.738943380972744e-6])
    (grad (kfromR . rsum0 @3 @(TKScalar Double) . nestedSumBuild)
          (ringestData [5] [1.1, 2.2, 3.3, 4, -5.22]))

{- testNestedSumBuild5 :: Assertion
testNestedSumBuild5 =
  assertEqualUpToEpsilon' 1e-6
    (ringestData [1,2,2] [3.5330436757054903e-3,3.5330436757054903e-3,3.5330436757054903e-3,3.5330436757054903e-3])
    (rev' @Double @5 nestedSumBuild (rsum (rsum t16)))
-}

-- | Rank-2 build of shape [13, 4] that, at index @[ix, ix2]@, selects
-- element @ix2@ of a list of rank-1 tensors (some depending on @ix@, some on
-- @v@). The catch-all alternative exists only for pattern completeness, as
-- its error message says.
nestedSumBuildB :: forall target n r.
                   (ADReady target, GoodScalar r, KnownNat n)
                => target (TKR (1 + n) r) -> target (TKR 3 r)
nestedSumBuildB v =
  rbuild @2 [13, 4, 2] $ \case
    [ix, ix2] ->
      flip rindex [ix2]
        (rfromList
          [ rbuild1 2 rfromIndex0
          , rsum $ rbuild [9, 2] $ const $ rfromIndex0 ix
          , rindex v (fromList
                      $ replicate (rlength v - 1)
                          (maxH 0 $ minH 1 $ ix2 `quotH` 2 + ix `quotH` 4 - 1))
          , rbuild1 2 (\_ -> rsum0 v)
          , rsum (rbuild1 7 (\ix7 -> rreplicate 2 (rfromIndex0 ix7))) ])
    _ -> error "nestedSumBuildB: impossible pattern needlessly required"

-- | Gradient of 'nestedSumBuildB' at a rank-5 input derived from @t48@.
testNestedSumBuildB :: Assertion
testNestedSumBuildB =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2,3,2,2,2]
       [30.0,30.0,26.0,26.0,26.0,26.0,26.0,26.0,
        26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,
        26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,
        26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0,
        26.0,26.0,26.0,26.0,26.0,26.0,35.0,35.0,
        26.0,26.0,26.0,26.0,26.0,26.0,26.0,26.0])
    (rev' @Double @3 nestedSumBuildB (rsum $ rsum $ rtranspose [1, 4, 2, 0, 3] t48))

-- | Nested builds and indexing that reduce to
-- @result[ix2] = v ! (ix2 `remH` 2, ix2, 0)@ for @ix2@ in 0..1, so only
-- @v[0,0,0]@ and @v[1,1,0]@ are read (the 1.0 entries expected by
-- 'testNestedBuildIndex').
nestedBuildIndex :: forall target r.
                    (ADReady target, GoodScalar r)
                 => target (TKR 5 r) -> target (TKR 3 r)
nestedBuildIndex v =
  rbuild1 2 $ \ix2 ->
    rindex (rbuild1 3 $ \ix3 ->
              rindex (rbuild1 3 $ \ix4 ->
                        rindex v (ix4 `remH` 2 :.: ix2 :.: 0 :.: ZIR))
                     [ix3])
           (ix2 :.: ZIR)

-- | Gradient of 'nestedBuildIndex' at @t16@.
testNestedBuildIndex :: Assertion
testNestedBuildIndex =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [2,2,1,2,2]
       [1.0,1.0,1.0,1.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1.0,1.0,1.0,1.0])
    (rev' @Double @3 nestedBuildIndex t16)

-- | 'relu' of 'bar' applied to @t = 0.001 * x@ and @relu t@.
barRelu :: ( ADReady target, GoodScalar r, KnownNat n, Differentiable r )
        => target (TKR n r) -> target (TKR n r)
barRelu x =
  let t = rreplicate0N (rshape x) (rscalar 0.001) * x
  in relu $ bar (t, relu t)

-- | Vector-Jacobian product of 'barRelu' at @t16@ with a cotangent of 42.2s.
testBarReluADValDt :: Assertion
testBarReluADValDt =
  assertEqualUpToEpsilon 1e-6
    (rconcrete $ Nested.rfromListPrimLinear [2,2,1,2,2]
       [1.2916050471365906e-2,1.2469757606504572e-2,1.3064120086501589e-2,1.2320300700062944e-2,
        0.0,1.217049789428711e-2,1.2185494267265312e-2,0.0,
        1.4105363649830907e-2,1.3506236503127638e-2,1.3359213691150671e-2,0.0,
        1.7066665416485535e-2,1.2618022646204737e-2,0.0,1.595161947206668e-2])
    (vjp @_ @(TKR 5 Double) barRelu t16 (rreplicate0N [ 2 , 2 , 1 , 2 , 2 ] (rscalar 42.2)))

-- | Vector-Jacobian product of the triple
-- @(rsum x, rcast (barRelu x), sfromR (barRelu x))@. Entries where
-- 'testBarReluADValDt' expects 0.0 are exactly 84.4 here, i.e. only the
-- @rsum@ component contributes there.
testBarReluADValDt2 :: Assertion
testBarReluADValDt2 =
  assertEqualUpToEpsilon 1e-6
    (rconcrete $ Nested.rfromListPrimLinear [2,2,1,2,2]
       [84.42583210117625,84.42493951543845,84.4261282404092,84.42464060162287,
        84.4,84.42434099465609,84.4243709887547,84.4,
        84.42821072755468,84.42701247325044,84.42671842762383,84.4,
        84.43413333114152,84.42523604552053,84.4,84.43190323923253])
    (vjp @_ @(TKProduct (TKR 4 Double) (TKProduct (TKR 5 Float) (TKS [2,2,1,2,2] Double)))
       (\x -> tpair (rsum x) (tpair (rcast $ barRelu x) (sfromR $ barRelu x)))
       t16
       (let dt = rreplicate0N [ 2 , 2 , 1 , 2 , 2 ] (rscalar 42.2) in tpair (rsum dt) (tpair (rcast dt) (sfromR dt))))

-- | Gradient of 'barRelu' at the rank-7 input @t48@.
testBarReluADVal :: Assertion
testBarReluADVal =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [3,1,2,2,1,2,2]
       [3.513740871835189e-4,3.8830416352632824e-4,3.981974371104471e-4,4.2420226755643853e-4,4.6186212581292275e-4,4.6805323209889415e-4,5.933633926875981e-4,4.8311739820100107e-4,
        3.513740871835189e-4,3.8830416352632824e-4,3.981974371104471e-4,4.2420226755643853e-4,4.803836032226148e-4,4.7114455958615145e-4,5.933633926875981e-4,4.6464270870595213e-4,
        3.060675467148428e-4,2.954918864100193e-4,3.095763053673437e-4,2.9195025355591045e-4,0.0,2.9166656928452994e-4,2.887557883241243e-4,0.0,
        3.342503234557057e-4,3.2005299770444394e-4,3.165690448140097e-4,0.0,4.0442335110155446e-4,2.990052759764126e-4,0.0,3.780004614233832e-4,
        2.954918864100193e-4,2.954918864100193e-4,2.954918864100193e-4,2.954918864100193e-4,0.0,0.0,0.0,0.0,
        3.7466025157760897e-4,0.0,0.0,0.0,3.7466025157760897e-4,3.7466025157760897e-4,3.7466025157760897e-4,3.7466025157760897e-4])
    (rev' @Double @7 barRelu t48)

-- | Gradient of 'barRelu' at @t48@ scaled by 0.001.
testBarReluADVal3 :: Assertion
testBarReluADVal3 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [3,1,2,2,1,2,2]
       [2.8846476339094805e-4,2.885038541771792e-4,2.885145151321922e-4,2.8854294397024206e-4,2.885852309100301e-4,2.885923176600045e-4,2.887454843457817e-4,2.886097295122454e-4,
        2.8846476339094805e-4,2.885038541771792e-4,2.885145151321922e-4,2.8854294397024206e-4,2.8860655161315664e-4,2.88595871110374e-4,2.887454843457817e-4,2.885884088500461e-4,
        2.884182085399516e-4,2.884075468755327e-4,2.8842176240868867e-4,2.8840399312321096e-4,0.0,2.8840370860416445e-4,2.884007943794131e-4,0.0,
        2.884469945274759e-4,2.8843242392031246e-4,2.884288700806792e-4,0.0,2.885212670262263e-4,2.884110805753153e-4,0.0,2.8849283778617973e-4,
        2.884075468755327e-4,2.884075468755327e-4,2.884075468755327e-4,2.884075468755327e-4,0.0,0.0,0.0,0.0,
        2.884892851579934e-4,0.0,0.0,0.0,2.884892851579934e-4,2.884892851579934e-4,2.884892851579934e-4,2.884892851579934e-4])
    (rev' @Double @7 barRelu (rmap0N (* rscalar 0.001) t48))

-- | 3x4 build where row @ix1@ picks element @ix1@ of
-- @[ix2, 7, rsum0 (rslice 1 1 r), -0.2]@. Only row 2 depends on @r@, via
-- @rslice 1 1 r@, so row 1 of @r@ gets gradient 4 (one per @ix2@) and the
-- rest gets 0, as the two tests below expect.
braidedBuilds :: forall target n r.
                 (ADReady target, GoodScalar r, KnownNat n, Differentiable r)
              => target (TKR (1 + n) r) -> target (TKR 2 r)
braidedBuilds r =
  rbuild1 3 (\ix1 ->
    rbuild1 4 (\ix2 ->
      rindex (rfromList [rfromIndex0 ix2, rscalar 7, rsum0 (rslice 1 1 r), rscalar (-0.2)])
             (ix1 :.: ZIR)))

-- | Gradient of 'braidedBuilds' (at rank 1) on a 4-element vector.
testBraidedBuilds :: Assertion
testBraidedBuilds =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [4] [0.0,4.0,0.0,0.0])
    (rev' @Double @2 (braidedBuilds @_ @0) (rreplicate0N [4] (rscalar 3.4)))

-- | Gradient of 'braidedBuilds' at @t16@.
testBraidedBuilds1 :: Assertion
testBraidedBuilds1 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [2,2,1,2,2]
       [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0])
    (rev' @Double @2 braidedBuilds t16)

-- | Wraps 'nestedSumBuildB' applied to @rreplicate 4 r@ in four builds of
-- sizes 2, 4, 2 and 3 that ignore their indices.
recycled :: (ADReady target, GoodScalar r, KnownNat n)
         => target (TKR n r) -> target (TKR 7 r)
recycled r =
  rbuild1 2 $ \_ -> rbuild1 4 $ \_ -> rbuild1 2 $ \_ -> rbuild1 3 $ \_ ->
    nestedSumBuildB (rreplicate 4 r)

-- | Gradient of 'recycled' (at rank 1) on a 2-element vector.
testRecycled :: Assertion
testRecycled =
  assertEqualUpToEpsilon' 1e-6
    (rrepl [2] 5616)
    (rev' @Double @7 (recycled @_ @_ @1) (rreplicate0N [2] (rscalar 1.0001)))

{- testRecycled1 :: Assertion
testRecycled1 =
  assertEqualUpToEpsilon' 1e-6
    (ringestData [5, 4, 2]
       [5184.0,5184.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,
        4992.0,4992.0,5424.0,5424.0,4992.0,4992.0,4992.0,4992.0,
        4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,
        4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,
        4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0,4992.0])
    (rev' @Double @7 (recycled @_ @_ @3) (rreplicate0N [5, 4, 2] (rscalar 0.0002)))
-}

-- | For each of 7 indices @i@, concatenates along the outer dimension:
-- 5 copies of @r@; 1 copy of @r@ scaled by @j - i@; 11 copies of @r@ scaled
-- by an integer built from the indices, 'rfloor' of sums of @r@ and
-- 'rminIndex'; and 13 copies of a tensor whose rows all equal row 0 of @r@
-- (via @rslice 0 1 r@).
concatBuild :: forall target r n.
               (ADReady target, GoodScalar r, KnownNat n, Differentiable r)
            => target (TKR (1 + n) r) -> target (TKR (3 + n) r)
concatBuild r =
  rbuild1 7 (\i ->
    rconcat [ rbuild1 5 (const r)
            , rbuild1 1 (\j -> rmap0N (* rfromIndex0 (j - i)) r)
            , rbuild1 11 (\j ->
                rmap0N (* (rfromIndex0
                  (kfromR (rprimalPart @target (rscalar 125))
                   * (j `remH` (abs (signum i + abs i) + 1))
                   + maxH j (i `quotH` (j + 1))
                     * (kfromR . rprimalPart . rfloor) (rsum0 r)
                   - ifH (r <=. r &&* i <. j)
                         (kfromR $ rprimalPart $ rminIndex (rflatten r))
                         ((kfromR . rprimalPart . rfloor) $ rsum0
                          $ r ! ((i * j) `remH` 7 :.: ZIR)))))
                       r)
            , rbuild1 13 (\_k ->
                rsum $ rtr $ rreplicate (rwidth r) (rslice 0 1 r))
            ])

-- | Gradient of 'concatBuild' on a 7-element vector. The extra
-- 637 = 7 * 13 * 7 on element 0 is consistent with the last 'rconcat' piece
-- of 'concatBuild' repeating element 0 of @r@.
testConcatBuild0 :: Assertion
testConcatBuild0 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [7] [16917.0,16280.0,16280.0,16280.0,16280.0,16280.0,16280.0])
    (rev' @Double @3 concatBuild (ringestData [7] [0.651,0.14,0.3414,-0.14,0.0014,0.0020014,0.9999]))

-- | Gradient of the total sum of 'concatBuild' at @t48@ scaled by 1e-7;
-- row 0 differs from the other rows.
testConcatBuild1 :: Assertion
testConcatBuild1 =
  assertEqualUpToEpsilon 1e-10
    (ringestData [3,1,2,2,1,2,2]
       [1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,
        1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,1.4816999999999999e-3,
        1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,
        1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,
        1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,
        1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3,1.4544e-3])
    (grad (kfromR . rsum0 @9 @(TKScalar Double) . concatBuild . rmap0N (* rscalar 1e-7)) t48)

-- | Builds 7 copies of @r@; copy @i@ is scaled by 'rfloor' of the sum of
-- row @i@ of @r@.
concatBuildm :: forall target r n.
                (ADReady target, GoodScalar r, KnownNat n, Differentiable r)
             => target (TKR (1 + n) r) -> target (TKR (2 + n) r)
concatBuildm r =
  rbuild1 7 (\i ->
    rmap0N (* (rfromIndex0 ((kfromR . rprimalPart . rfloor) $ rsum0 $ r ! (i :.: ZIR))))
           r)

-- | Gradient of 'concatBuildm' on a 7-element vector: the floors of the
-- inputs are [0,0,0,-1,0,0,0], summing to -1 for every element.
testConcatBuild0m :: Assertion
testConcatBuild0m =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [7] [-1.0,-1.0,-1.0,-1.0,-1.0,-1.0,-1.0])
    (rev' @Double @2 concatBuildm (ringestData [7] [0.651,0.14,0.3414,-0.14,0.0014,0.0020014,0.9999]))

-- | Gradient of 'concatBuildm' at @t48@ scaled by 1e-7.
testConcatBuild1m :: Assertion
testConcatBuild1m =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [3,1,2,2,1,2,2]
       [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,
        0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,
        0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,
        0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,
        0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,
        0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])
    (rev' @Double @8 (concatBuildm . rmap0N (* rscalar 1e-7)) t48)

-- | 5x2 build of @r@ scaled by @maxH j (i `quotH` (j + 1))@.
concatBuild2 :: (ADReady target, GoodScalar r, KnownNat n)
             => target (TKR (1 + n) r) -> target (TKR (3 + n) r)
concatBuild2 r =
  rbuild1 5 (\i ->
    rbuild1 2 (\j ->
      rmap0N (* rfromIndex0 (maxH j (i `quotH` (j + 1)))) r))

-- | Gradient of 'concatBuild2': the scale factors sum to
-- (0+1+2+3+4) for j = 0 plus (1+1+1+1+2) for j = 1, i.e. 16.
testConcatBuild2 :: Assertion
testConcatBuild2 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [3] [16.0,16.0,16.0])
    (rev' @Double @3 concatBuild2 (ringestData [3] [0.651,0.14,0.3414]))

-- | Same as 'testConcatBuild2' at the rank-7 input @t48@.
testConcatBuild22 :: Assertion
testConcatBuild22 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [3,1,2,2,1,2,2]
       [16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,
        16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,
        16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,
        16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,
        16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0,
        16.0,16.0,16.0,16.0,16.0,16.0,16.0,16.0])
    (rev' @Double @9 concatBuild2 t48)

-- | 5x2 build of @maxH j (i `quotH` (j + 1))@; ignores its argument.
concatBuild3 :: (ADReady target, GoodScalar r)
             => target (TKR 1 r) -> target (TKR 2 r)
concatBuild3 _r =
  rbuild1 5 (\i ->
    rbuild1 2 (\j -> rfromIndex0 (maxH j (i `quotH` (j + 1)))))

-- | Gradient of 'concatBuild3' on an empty vector.
testConcatBuild3 :: Assertion
testConcatBuild3 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [0] [])
    (rev' @Double @2 concatBuild3 (ringestData [0] []))

-- | Derivative of 'logistic' at 3.
testLogistic0 :: Assertion
testLogistic0 =
  assertEqualUpToEpsilon' 1e-10
    (rscalar 4.5176659730912e-2)
    (rev' @Double @0 logistic (rscalar 3))

-- | Gradient of 'logistic' at @t16@.
testLogistic5 :: Assertion
testLogistic5 =
  assertEqualUpToEpsilon' 1e-10
    (rfromListLinear [2,2,1,2,2]
       [6.648056670790033e-3,0.10499358540350662,2.466509291359931e-3,0.19661193324148185,
        0.1049935854035065,0.2499999999999375,0.24937604019289197,0.24751657271185995,
        2.0452222584760427e-6,1.2337934976493025e-4,3.3523767075636815e-4,1.7662706213291118e-2,
        1.7763568394002473e-15,4.540945566439111e-2,4.6588861451033536e-15,5.109024314693943e-12])
    (rev' @Double @5 logistic t16)

-- | Gradient of @logistic . logistic@ at @t16@.
testLogistic52 :: Assertion
testLogistic52 =
  assertEqualUpToEpsilon' 1e-10
    (rfromListLinear [2,2,1,2,2]
       [1.3111246391159124e-3,2.1750075272657612e-2,4.8549901151740267e-4,4.312916016242333e-2,
        2.6155373652699744e-2,5.8750924453083504e-2,5.8238333583278255e-2,5.8847120842749026e-2,
        4.0211548220008027e-7,2.425923569564766e-5,6.592194028602285e-5,4.415319450597324e-3,
        3.4925295232121135e-16,9.12284655835676e-3,1.1647215362758384e-15,1.0044951474920845e-12])
    (rev' @Double @5 (logistic . logistic) t16)

-- | Hand-written logistic: primal @y = 1 / (1 + exp (-d))@, with the dual
-- part of @d@ scaled by @y * (1 - y)@ and combined via 'tD'.
logisticOld :: forall target r n.
               ( BaseTensor target, LetTensor target
               , BaseTensor (PrimalOf target), KnownNat n, GoodScalar r
               , Floating (PrimalOf target (TKR n r)) )
            => target (TKR n r) -> target (TKR n r)
logisticOld d0 = tlet d0 $ \d ->  -- d is used in both rprimalPart and rdualPart
  let sh = rshape d
      y0 = recip (rrepl sh 1 + exp (- rprimalPart @target d))
  in tlet (rfromPrimal @target y0) $ \y1 ->
       let y = rprimalPart @target y1
       in tD knownSTK y (rScale @target (y * (rrepl sh 1 - y))
                         $ rdualPart @target d)

-- | 'logisticOld' must agree with 'testLogistic0'.
testLogistic0Old :: Assertion
testLogistic0Old =
  assertEqualUpToEpsilon' 1e-10
    (rscalar 4.5176659730912e-2)
    (rev' @Double @0 logisticOld (rscalar 3))

-- | 'logisticOld' must agree with 'testLogistic5'.
testLogistic5Old :: Assertion
testLogistic5Old =
  assertEqualUpToEpsilon' 1e-10
    (rfromListLinear [2,2,1,2,2]
       [6.648056670790033e-3,0.10499358540350662,2.466509291359931e-3,0.19661193324148185,
        0.1049935854035065,0.2499999999999375,0.24937604019289197,0.24751657271185995,
        2.0452222584760427e-6,1.2337934976493025e-4,3.3523767075636815e-4,1.7662706213291118e-2,
        1.7763568394002473e-15,4.540945566439111e-2,4.6588861451033536e-15,5.109024314693943e-12])
    (rev' @Double @5 logisticOld t16)

-- | @logisticOld . logistic@ must agree with 'testLogistic52'.
testLogistic52Old :: Assertion
testLogistic52Old =
  assertEqualUpToEpsilon' 1e-10
    (rfromListLinear [2,2,1,2,2]
       [1.3111246391159124e-3,2.1750075272657612e-2,4.8549901151740267e-4,4.312916016242333e-2,
        2.6155373652699744e-2,5.8750924453083504e-2,5.8238333583278255e-2,5.8847120842749026e-2,
        4.0211548220008027e-7,2.425923569564766e-5,6.592194028602285e-5,4.415319450597324e-3,
        3.4925295232121135e-16,9.12284655835676e-3,1.1647215362758384e-15,1.0044951474920845e-12])
    (rev' @Double @5 (logisticOld . logistic) t16)

-- | Like 'logisticOld', but assembles the result as
-- @rfromPrimal y + rfromDual (...)@ instead of via 'tD'.
logisticA :: forall target r n.
             ( BaseTensor target, LetTensor target
             , BaseTensor (PrimalOf target), KnownNat n, GoodScalar r
             , Floating (PrimalOf target (TKR n r)) )
          => target (TKR n r) -> target (TKR n r)
logisticA d0 = tlet d0 $ \d ->  -- d is used in both rprimalPart and rdualPart
  let sh = rshape d
      y0 = recip (rrepl sh 1 + exp (- rprimalPart @target d))
  in tlet (rfromPrimal @target y0) $ \y1 ->
       let y = rprimalPart @target y1
       in rfromPrimal y + rfromDual (rScale @target (y * (rrepl sh 1 - y))
                                     $ rdualPart @target d)

-- | 'logisticA' must agree with 'testLogistic0'.
testLogisticA0 :: Assertion
testLogisticA0 =
  assertEqualUpToEpsilon' 1e-10
    (rscalar 4.5176659730912e-2)
    (rev' @Double @0 logisticA (rscalar 3))

-- | Like 'logisticA', but passes the dual part of @d@ through unscaled, so
-- the expected derivative is 1 ('testLogisticB0').
logisticB :: forall target r n.
             ( BaseTensor target, LetTensor target
             , BaseTensor (PrimalOf target), KnownNat n, GoodScalar r
             , Floating (PrimalOf target (TKR n r)) )
          => target (TKR n r) -> target (TKR n r)
logisticB d0 = tlet d0 $ \d ->  -- d is used in both rprimalPart and rdualPart
  let sh = rshape d
      y0 = recip (rrepl sh 1 + exp (- rprimalPart @target d))
  in tlet (rfromPrimal @target y0) $ \y1 ->
       let y = rprimalPart @target y1
       in rfromPrimal y + rfromDual (rdualPart @target d)

-- | Derivative of 'logisticB' at 3.
testLogisticB0 :: Assertion
testLogisticB0 =
  assertEqualUpToEpsilon' 1e-10
    (rscalar 1)
    (rev' @Double @0 logisticB (rscalar 3))

-- | Keeps only the primal part of its argument, so the expected derivative
-- is 0 ('testLogisticC0').
logisticC :: forall target r n.
             ( BaseTensor target, LetTensor target
             , KnownNat n, GoodScalar r )
          => target (TKR n r) -> target (TKR n r)
logisticC d0 = tlet d0 $ \d ->  -- d is used only in rprimalPart
  let y0 = rprimalPart @target d
  in rfromPrimal @target y0

-- | Derivative of 'logisticC' at 3.
testLogisticC0 :: Assertion
testLogisticC0 =
  assertEqualUpToEpsilon' 1e-10
    (rscalar 0)
    (rev' @Double @0 logisticC (rscalar 3))