{-# LANGUAGE AllowAmbiguousTypes, OverloadedLists #-}
{-# OPTIONS_GHC -fno-cse #-}
{-# OPTIONS_GHC -Wno-incomplete-uni-patterns #-}
-- | Tests of convolution and disparity cost volume defined using the build
-- operation of ranked tensors.
module TestConvSimplified (testTrees) where

import Prelude

import Control.Exception.Assert.Sugar
import GHC.Exts (IsList (..))
import GHC.TypeLits (KnownNat)
import Test.Tasty
import Test.Tasty.HUnit hiding (assert)

import Data.Array.Nested qualified as Nested
import Data.Array.Nested.Ranked.Shape

import HordeAd
import HordeAd.Core.AstEnv
import HordeAd.Core.AstFreshId (resetVarCounter)
import HordeAd.Core.AstInterpret
import HordeAd.Core.AstTools
import HordeAd.Core.CarriersAst
import HordeAd.Core.Delta
import HordeAd.Core.Ops
import HordeAd.Core.OpsAst

import CrossTesting
import EqEpsilon

-- | All test cases of this module, each registered under its display name.
-- Entries commented out with @--@ are disabled and are not run.
testTrees :: [TestTree]
testTrees =
  [ testCase "KonstG0Rev" testKonstG0Rev
  , testCase "KonstG0Tiny1" testKonstG0Tiny1
  , testCase "KonstG0TinyS" testKonstG0TinyS
  , testCase "KonstG0TinyA" testKonstG0TinyA
  , testCase "KonstG0LittleA" testKonstG0LittleA
  , testCase "Replicate0Rev" testReplicate0Rev
  , testCase "Replicate0Tiny1" testReplicate0Tiny1
  , testCase "Replicate0TinyS" testReplicate0TinyS
  , testCase "Replicate0TinyA" testReplicate0TinyA
  , testCase "Replicate0LittleA" testReplicate0LittleA
  , testCase "Konst5LittleB" testKonst5LittleB
  , testCase "Konst5LittleC" testKonst5LittleC
  , testCase "Konst5BigB" testKonst5BigB
  , testCase "KonstNotBigB" testKonstNotBigB
  , testCase "Konst5BigC" testKonst5BigC
  , testCase "KonstNotBigC" testKonstNotBigC
  , testCase "Konst5LittleB128b" testKonst5LittleB128b
  , testCase "Konst5LittleC128b" testKonst5LittleC128b
  , testCase "Konst5BigB128b" testKonst5BigB128b
  , testCase "KonstNotBigB128b" testKonstNotBigB128b
  , testCase "Konst5BigC128b" testKonst5BigC128b
  , testCase "KonstNotBigC128b" testKonstNotBigC128b
  , testCase "Konst5LittleB128c" testKonst5LittleB128c
  , testCase "Konst5LittleC128c" testKonst5LittleC128c
  , testCase "Konst5BigB128c" testKonst5BigB128c
  , testCase "KonstNotBigB128c" testKonstNotBigB128c
  , testCase "Konst5BigC128c" testKonst5BigC128c
  , testCase "KonstNotBigC128c" testKonstNotBigC128c
--  , testCase "Konst5LittleB128bc" testKonst5LittleB128bc
--  , testCase "Konst5LittleC128bc" testKonst5LittleC128bc
--  , testCase "Konst5BigB128bc" testKonst5BigB128bc
--  , testCase "KonstNotBigB128cb" testKonstNotBigB128cb
--  , testCase "Konst5BigC128cb" testKonst5BigC128cb
--  , testCase "KonstNotBigC128cb" testKonstNotBigC128cb
  , testCase "Replicate0RevLaborious" testReplicate0RevLaborious
  , testCase "Replicate0Tiny1Laborious" testReplicate0Tiny1Laborious
  , testCase "Replicate0TinySLaborious" testReplicate0TinySLaborious
  , testCase "Replicate0TinyALaborious" testReplicate0TinyALaborious
  , testCase "Replicate0LittleALaborious" testReplicate0LittleALaborious
  , testCase "Konst5LittleBLaborious" testKonst5LittleBLaborious
  , testCase "Konst5LittleCLaborious" testKonst5LittleCLaborious
  , testCase "Konst5BigBLaborious" testKonst5BigBLaborious
  , testCase "KonstNotBigBLaborious" testKonstNotBigBLaborious
  , testCase "Konst5BigCLaborious" testKonst5BigCLaborious
  , testCase "KonstNotBigCLaborious" testKonstNotBigCLaborious
  , testCase "Konst5LittleBLaborious128b" testKonst5LittleBLaborious128b
  , testCase "Konst5LittleCLaborious128b" testKonst5LittleCLaborious128b
--  , testCase "Konst5BigBLaborious128b" testKonst5BigBLaborious128b
--  , testCase "KonstNotBigBLaborious128b" testKonstNotBigBLaborious128b
--  , testCase "Konst5BigCLaborious128b" testKonst5BigCLaborious128b
--  , testCase "KonstNotBigCLaborious128b" testKonstNotBigCLaborious128b
  , testCase "Konst5LittleBLaborious128c" testKonst5LittleBLaborious128c
  , testCase "Konst5LittleCLaborious128c" testKonst5LittleCLaborious128c
--  , testCase "Konst5BigBLaborious128c" testKonst5BigBLaborious128c
--  , testCase "KonstNotBigBLaborious128c" testKonstNotBigBLaborious128c
--  , testCase "Konst5BigCLaborious128c" testKonst5BigCLaborious128c
--  , testCase "KonstNotBigCLaborious128c" testKonstNotBigCLaborious128c
--  , testCase "Konst5LittleBLaborious128bc" testKonst5LittleBLaborious128bc
--  , testCase "Konst5LittleCLaborious128bc" testKonst5LittleCLaborious128bc
--  , testCase "Konst5BigBLaborious128bc" testKonst5BigBLaborious128bc
--  , testCase "KonstNotBigBLaborious128cb" testKonstNotBigBLaborious128cb
--  , testCase "Konst5BigCLaborious128cb" testKonst5BigCLaborious128cb
--  , testCase "KonstNotBigCLaborious128cb" testKonstNotBigCLaborious128cb
--  , testCase "Replicate0RevPadded" testReplicate0RevPadded
  , testCase "Replicate0Tiny1Padded" testReplicate0Tiny1Padded
  , testCase "Replicate0TinySPadded" testReplicate0TinySPadded
  , testCase "Replicate0TinyAPadded" testReplicate0TinyAPadded
  , testCase "Replicate0LittleAPadded" testReplicate0LittleAPadded
--  , testCase "Konst5LittleBPadded" testKonst5LittleBPadded
--  , testCase "Konst5LittleCPadded" testKonst5LittleCPadded
--  , testCase "Konst5BigBPadded" testKonst5BigBPadded
--  , testCase "KonstNotBigBPadded" testKonstNotBigBPadded
--  , testCase "Konst5BigCPadded" testKonst5BigCPadded
--  , testCase "KonstNotBigCPadded" testKonstNotBigCPadded
--  , testCase "Konst5LittleBPadded128b" testKonst5LittleBPadded128b
--  , testCase "Konst5LittleCPadded128b" testKonst5LittleCPadded128b
--  , testCase "Konst5BigBPadded128b" testKonst5BigBPadded128b
--  , testCase "KonstNotBigBPadded128b" testKonstNotBigBPadded128b
--  , testCase "Konst5BigCPadded128b" testKonst5BigCPadded128b
--  , testCase "KonstNotBigCPadded128b" testKonstNotBigCPadded128b
--  , testCase "Konst5LittleBPadded128c" testKonst5LittleBPadded128c
--  , testCase "Konst5LittleCPadded128c" testKonst5LittleCPadded128c
--  , testCase "Konst5BigBPadded128c" testKonst5BigBPadded128c
--  , testCase "KonstNotBigBPadded128c" testKonstNotBigBPadded128c
--  , testCase "Konst5BigCPadded128c" testKonst5BigCPadded128c
--  , testCase "KonstNotBigCPadded128c" testKonstNotBigCPadded128c
--  , testCase "Konst5LittleBPadded128bc" testKonst5LittleBPadded128bc
--  , testCase "Konst5LittleCPadded128bc" testKonst5LittleCPadded128bc
--  , testCase "Konst5BigBPadded128bc" testKonst5BigBPadded128bc
--  , testCase "KonstNotBigBPadded128cb" testKonstNotBigBPadded128cb
--  , testCase "Konst5BigCPadded128cb" testKonst5BigCPadded128cb
--  , testCase "KonstNotBigCPadded128cb" testKonstNotBigCPadded128cb
  , testCase "disparityKonst" test_disparityKonst
  , testCase "disparityKonst2" test_disparityKonst2
  , testCase "disparitySmall" test_disparitySmall
  , testCase "ConvTomsSliceRev" testTomsSliceRev
  , testCase "ConvTomsSlice" testTomsSlice
  , testCase "ConvTomsSlicePP" testTomsSlicePP
  , testCase "minimizedCNNOPP0c" testCNNOPP0c
  , testCase "minimizedCNNOPP0b" testCNNOPP0b
  , testCase "minimizedCNNOPP1e" testCNNOPP1e
  , testCase "minimizedCNNOPP2" testCNNOPP2
  , testCase "minimizedCNNOPP2b" testCNNOPP2b
--  , testCase "minimizedCNNOPP3" testCNNOPP3
  , testCase "minimizedCNNOPP3b" testCNNOPP3b
  , testCase "minimizedCNNOPP4" testCNNOPP4
  , testCase "minimizedCNNOPP4b" testCNNOPP4b
  , testCase "minimizedCNNOPP5" testCNNOPP5
  , testCase "minimizedCNNOPP5b" testCNNOPP5b
  , testCase "minimizedCNNOPP6" testCNNOPP6
  , testCase "minimizedCNNOPP6b" testCNNOPP6b
  , testCase "minimizedCNNOPP7" testCNNOPP7
  , testCase "minimizedCNNOPP7b" testCNNOPP7b
--  , testCase "minimizedPaddedCNNOPP0c" testPaddedCNNOPP0c
--  , testCase "minimizedPaddedCNNOPP0b" testPaddedCNNOPP0b
--  , testCase "minimizedPaddedCNNOPP1e" testPaddedCNNOPP1e
  , testCase "minimizedPaddedCNNOPP1b" testPaddedCNNOPP1b
  , testCase "minimizedPaddedCNNOPPLet" testPaddedCNNOPPLet
  , testCase "minimizedPaddedCNNOPPLet2" testPaddedCNNOPPLet2
--  , testCase "minimizedPaddedCNNOPP2" testPaddedCNNOPP2
  , testCase "minimizedCNNOPP0cW" testCNNOPP0cW
  , testCase "minimizedCNNOPP0bW" testCNNOPP0bW
  , testCase "minimizedCNNOPP1bW" testCNNOPP1bW
  , testCase "minimizedCNNOPP4bW" testCNNOPP4bW
  , testCase "minimizedCNNOPP4bD" testCNNOPP4bD
  , testCase "minimizedCNNOPP5aW" testCNNOPP5aW
  , testCase "minimizedCNNOPP5bW" testCNNOPP5bW
  , testCase "minimizedCNNOPP5cW" testCNNOPP5cW
  , testCase "minimizedCNNOPP5dW" testCNNOPP5dW
  ]

-- The examples reproduced and transformed in this file are borrowed
-- from https://github.com/benl23x5/adops.
-- Here they are defined using ranked tensors.

-- * A non-laborious version (depends on indexing OOB giving 0 consistently)

-- | 'conv2dUnpadded' with its first argument fixed to the constant tensor of
-- shape [1, 1, 1, 1] holding @-0.2@. The first argument is presumably the
-- kernel; confirm against conv2dUnpadded's definition.
conv2d1
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2d1 = conv2dUnpadded $ rconcrete $ Nested.rfromListPrimLinear (fromList [1, 1, 1, 1]) [-0.2]

-- | 'conv2dUnpadded' with its first argument fixed to the constant tensor of
-- shape [1, 2, 1, 1] holding @-0.2@ and @25.0003@.
conv2dA
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dA = conv2dUnpadded $ rconcrete $ Nested.rfromListPrimLinear (fromList [1, 2, 1, 1]) [-0.2, 25.0003]

-- | 'conv2dUnpadded' with its first argument fixed to the test data @t16b@.
conv2dB
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dB = conv2dUnpadded (rconcrete $ unConcrete t16b)

-- | Checks, using 'grad', the gradient of @kfromR . rsum0 . conv2dB@
-- (presumably the sum of all of conv2dB's output) at the input
-- @rrepl [2, 2, 2, 2] 0@.
testKonstG0Rev :: Assertion
testKonstG0Rev =
  assertEqualUpToEpsilon 1e-4
    (rconcrete $ Nested.rfromListPrimLinear [2, 2, 2, 2] [18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dB) (rrepl [2, 2, 2, 2] 0))

-- | Checks, via rev', the gradient of 'conv2d1' at the input
-- @rrepl [1, 1, 1, 1] 0@.
testKonstG0Tiny1 :: Assertion
testKonstG0Tiny1 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 1, 1, 1] [-0.2])
    (rev' @Double @4 conv2d1 (rrepl [1, 1, 1, 1] 0))

-- | Checks the gradient of a convolution whose first argument is the
-- [1, 1, 1, 1] tensor built from @rsum0@ of t16b (presumably the sum of all
-- of t16b's elements).
testKonstG0TinyS :: Assertion
testKonstG0TinyS =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 1, 1, 1] [582665.99432])
    (rev' @Double @4
       (conv2dUnpadded $ rreplicate0N [1, 1, 1, 1] (rsum0 (rconcrete $ unConcrete t16b)))
       (ringestData [1, 1, 1, 1] [0]))

-- | Checks the gradient of 'conv2dA' at the input @rrepl [1, 2, 1, 1] 0@.
testKonstG0TinyA :: Assertion
testKonstG0TinyA =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 2, 1, 1] [-0.2,25.0003])
    (rev' @Double @4 conv2dA (rrepl [1, 2, 1, 1] 0))

-- | Checks the gradient of 'conv2dA' at the input @rrepl [2, 2, 2, 2] 0@.
testKonstG0LittleA :: Assertion
testKonstG0LittleA =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [2, 2, 2, 2] [-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003,-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003])
    (rev' @Double @4 conv2dA (rrepl [2, 2, 2, 2] 0))

-- | 'conv2dUnpadded' with its second argument fixed to t16b; this function's
-- argument becomes conv2dUnpadded's first argument.
conv2dC
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dC = flip conv2dUnpadded (rconcrete $ unConcrete t16b)

-- | Like 'conv2dB', but with the test data @t128b@.
conv2dB128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dB128b = conv2dUnpadded (rconcrete $ unConcrete t128b)

-- | Like 'conv2dC', but with the test data @t128b@.
conv2dC128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dC128b = flip conv2dUnpadded (rconcrete $ unConcrete t128b)

-- | Like 'conv2dB', but with the test data @t128c@.
conv2dB128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dB128c = conv2dUnpadded (rconcrete $ unConcrete t128c)

-- | Like 'conv2dC', but with the test data @t128c@.
conv2dC128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dC128c = flip conv2dUnpadded (rconcrete $ unConcrete t128c)

-- | Same body as 'testKonstG0Rev'.
testReplicate0Rev :: Assertion
testReplicate0Rev =
  assertEqualUpToEpsilon 1e-4
    (rconcrete $ Nested.rfromListPrimLinear [2, 2, 2, 2] [18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dB) (rrepl [2, 2, 2, 2] 0))

-- | Same body as 'testKonstG0Tiny1'.
testReplicate0Tiny1 :: Assertion
testReplicate0Tiny1 =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 1, 1, 1] [-0.2])
    (rev' @Double @4 conv2d1 (rrepl [1, 1, 1, 1] 0))

-- | Same body as 'testKonstG0TinyS'.
testReplicate0TinyS :: Assertion
testReplicate0TinyS =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 1, 1, 1] [582665.99432])
    (rev' @Double @4
       (conv2dUnpadded $ rreplicate0N [1, 1, 1, 1] (rsum0 (rconcrete $ unConcrete t16b)))
       (ringestData [1, 1, 1, 1] [0]))

-- | Same body as 'testKonstG0TinyA'.
testReplicate0TinyA :: Assertion
testReplicate0TinyA =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [1, 2, 1, 1] [-0.2,25.0003])
    (rev' @Double @4 conv2dA (rrepl [1, 2, 1, 1] 0))

-- | Same body as 'testKonstG0LittleA'.
testReplicate0LittleA :: Assertion
testReplicate0LittleA =
  assertEqualUpToEpsilon' 1e-10
    (ringestData [2, 2, 2, 2] [-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003,-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003])
    (rev' @Double @4 conv2dA (rrepl [2, 2, 2, 2] 0))

-- with data t16b (via conv2dB and conv2dC)

-- | Gradient of 'conv2dB' at a constant-5 input of shape [2, 2, 2, 2].
testKonst5LittleB :: Assertion
testKonst5LittleB =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001])
    (rev' @Double @4 conv2dB (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

-- | Gradient of 'conv2dC' at a constant-5 input of shape [2, 2, 2, 2].
testKonst5LittleC :: Assertion
testKonst5LittleC =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [2, 2, 2, 2] [40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8,40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8])
    (rev' @Double @4 conv2dC (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

-- | Gradient of 'conv2dB' at a constant-5 input of shape [3, 2, 4, 2].
testKonst5BigB :: Assertion
testKonst5BigB =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001])
    (rev' @Double @4 conv2dB (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The expected gradient is the same as above even though the input differs:
-- the other convolution argument is unchanged, and convolution is linear in
-- the varied argument, so the gradient does not depend on that argument's
-- values.
testKonstNotBigB :: Assertion
testKonstNotBigB =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001])
    (rev' @Double @4 conv2dB (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

-- | Gradient of 'conv2dC' at a constant-5 input of shape [3, 2, 4, 2].
testKonst5BigC :: Assertion
testKonst5BigC =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0,40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0,40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0])
    (rev' @Double @4 conv2dC (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The expected gradient of the next test is the same as above even though the
-- input differs: the other convolution argument is unchanged, and convolution
-- is linear in the varied argument, so the gradient does not depend on that
-- argument's values.
-- | Gradient of 'conv2dC' at the non-constant input 37, 36 .. -10
-- (48 values, shape [3, 2, 4, 2]).
testKonstNotBigC :: Assertion
testKonstNotBigC =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dC inputArg)
 where
  expectedGrad =
    ringestData [3, 2, 4, 2]
      [ 40.1, 8.0, 11.0, -3.0, 0.0, 0.0, 0.0, 0.0
      , 582625.8943200001, 28.794320000000003, -309.09999999999997, 25.8, 0.0, 0.0, 0.0, 0.0
      , 40.1, 8.0, 11.0, -3.0, 0.0, 0.0, 0.0, 0.0
      , 582625.8943200001, 28.794320000000003, -309.09999999999997, 25.8, 0.0, 0.0, 0.0, 0.0
      , 40.1, 8.0, 11.0, -3.0, 0.0, 0.0, 0.0, 0.0
      , 582625.8943200001, 28.794320000000003, -309.09999999999997, 25.8, 0.0, 0.0, 0.0, 0.0
      ]
  inputArg = rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])

-- with data t128b

-- | Gradient of 'conv2dB128b' at a constant-5 input of shape [2, 2, 2, 2].
testKonst5LittleB128b :: Assertion
testKonst5LittleB128b =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dB128b inputArg)
 where
  expectedGrad =
    ringestData [2, 2, 2, 2]
      [ 112.3003, 251.5006, 209.49462, 482.69492000000014
      , 3.000000000000032, 65.90000000000003, 164.10000000000002, 365.89432000010004
      , 112.3003, 251.5006, 209.49462, 482.69492000000014
      , 3.000000000000032, 65.90000000000003, 164.10000000000002, 365.89432000010004
      ]
  inputArg = rreplicate0N [2, 2, 2, 2] (rscalar 5)

-- | Gradient of 'conv2dC128b' at a constant-5 input of shape [2, 2, 2, 2].
testKonst5LittleC128b :: Assertion
testKonst5LittleC128b =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dC128b inputArg)
 where
  expectedGrad =
    ringestData [2, 2, 2, 2]
      [ 1627.8210700004993, 1571.2321300004994, 1132.9261600005002, 1188.6375200005
      , 2725.0393200008984, 1831.7390200008983, 2551.139320000898, 1660.8390200008987
      , 1627.8210700004993, 1571.2321300004994, 1132.9261600005002, 1188.6375200005
      , 2725.0393200008984, 1831.7390200008983, 2551.139320000898, 1660.8390200008987
      ]
  inputArg = rreplicate0N [2, 2, 2, 2] (rscalar 5)

-- | Gradient of 'conv2dB128b' at a constant-5 input of shape [3, 2, 4, 2].
testKonst5BigB128b :: Assertion
testKonst5BigB128b =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dB128b inputArg)
 where
  expectedGrad =
    ringestData [3, 2, 4, 2]
      [ 112.3003, 251.5006, 209.49462, 482.69492000000014, 229.49462000000003, 610.5892400000002, 56.58894000000004, 580.6778800001001
      , 3.000000000000032, 65.90000000000003, 164.10000000000002, 365.89432000010004, 667.2003000000001, 1060.8778800002, 893.3003, 1465.6665200003993
      , 112.3003, 251.5006, 209.49462, 482.69492000000014, 229.49462000000003, 610.5892400000002, 56.58894000000004, 580.6778800001001
      , 3.000000000000032, 65.90000000000003, 164.10000000000002, 365.89432000010004, 667.2003000000001, 1060.8778800002, 893.3003, 1465.6665200003993
      , 112.3003, 251.5006, 209.49462, 482.69492000000014, 229.49462000000003, 610.5892400000002, 56.58894000000004, 580.6778800001001
      , 3.000000000000032, 65.90000000000003, 164.10000000000002, 365.89432000010004, 667.2003000000001, 1060.8778800002, 893.3003, 1465.6665200003993
      ]
  inputArg = rreplicate0N [3, 2, 4, 2] (rscalar 5)

-- Same expected gradient as the test above, at a non-constant input: the other
-- convolution argument is unchanged and convolution is linear in the varied
-- argument, so that argument's values do not affect the gradient.
testKonstNotBigB128b :: Assertion
testKonstNotBigB128b =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dB128b inputArg)
 where
  expectedGrad =
    ringestData [3, 2, 4, 2]
      [ 112.3003, 251.5006, 209.49462, 482.69492000000014, 229.49462000000003, 610.5892400000002, 56.58894000000004, 580.6778800001001
      , 3.000000000000032, 65.90000000000003, 164.10000000000002, 365.89432000010004, 667.2003000000001, 1060.8778800002, 893.3003, 1465.6665200003993
      , 112.3003, 251.5006, 209.49462, 482.69492000000014, 229.49462000000003, 610.5892400000002, 56.58894000000004, 580.6778800001001
      , 3.000000000000032, 65.90000000000003, 164.10000000000002, 365.89432000010004, 667.2003000000001, 1060.8778800002, 893.3003, 1465.6665200003993
      , 112.3003, 251.5006, 209.49462, 482.69492000000014, 229.49462000000003, 610.5892400000002, 56.58894000000004, 580.6778800001001
      , 3.000000000000032, 65.90000000000003, 164.10000000000002, 365.89432000010004, 667.2003000000001, 1060.8778800002, 893.3003, 1465.6665200003993
      ]
  inputArg = rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])

-- | Gradient of 'conv2dC128b' at a constant-5 input of shape [3, 2, 4, 2].
testKonst5BigC128b :: Assertion
testKonst5BigC128b =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dC128b inputArg)
 where
  expectedGrad =
    ringestData [3, 2, 4, 2]
      [ 1627.8210700004993, 1571.2321300004994, 1132.9261600005002, 1188.6375200005, 675.7488800003999, 828.6545600004001, 215.6659200003, 388.5716000003
      , 2725.0393200008984, 1831.7390200008983, 2551.139320000898, 1660.8390200008987, 1903.750080000699, 1174.5497800006997, 854.9778800004001, 628.8778800004001
      , 1627.8210700004993, 1571.2321300004994, 1132.9261600005002, 1188.6375200005, 675.7488800003999, 828.6545600004001, 215.6659200003, 388.5716000003
      , 2725.0393200008984, 1831.7390200008983, 2551.139320000898, 1660.8390200008987, 1903.750080000699, 1174.5497800006997, 854.9778800004001, 628.8778800004001
      , 1627.8210700004993, 1571.2321300004994, 1132.9261600005002, 1188.6375200005, 675.7488800003999, 828.6545600004001, 215.6659200003, 388.5716000003
      , 2725.0393200008984, 1831.7390200008983, 2551.139320000898, 1660.8390200008987, 1903.750080000699, 1174.5497800006997, 854.9778800004001, 628.8778800004001
      ]
  inputArg = rreplicate0N [3, 2, 4, 2] (rscalar 5)

-- The expected gradient of the next test equals the one above even though the
-- input differs: the other convolution argument is unchanged and convolution
-- is linear in the varied argument, so that argument's values do not matter.
-- | Gradient of 'conv2dC128b' at the non-constant input 37, 36 .. -10
-- (48 values, shape [3, 2, 4, 2]).
testKonstNotBigC128b :: Assertion
testKonstNotBigC128b =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dC128b inputArg)
 where
  expectedGrad =
    ringestData [3, 2, 4, 2]
      [ 1627.8210700004993, 1571.2321300004994, 1132.9261600005002, 1188.6375200005, 675.7488800003999, 828.6545600004001, 215.6659200003, 388.5716000003
      , 2725.0393200008984, 1831.7390200008983, 2551.139320000898, 1660.8390200008987, 1903.750080000699, 1174.5497800006997, 854.9778800004001, 628.8778800004001
      , 1627.8210700004993, 1571.2321300004994, 1132.9261600005002, 1188.6375200005, 675.7488800003999, 828.6545600004001, 215.6659200003, 388.5716000003
      , 2725.0393200008984, 1831.7390200008983, 2551.139320000898, 1660.8390200008987, 1903.750080000699, 1174.5497800006997, 854.9778800004001, 628.8778800004001
      , 1627.8210700004993, 1571.2321300004994, 1132.9261600005002, 1188.6375200005, 675.7488800003999, 828.6545600004001, 215.6659200003, 388.5716000003
      , 2725.0393200008984, 1831.7390200008983, 2551.139320000898, 1660.8390200008987, 1903.750080000699, 1174.5497800006997, 854.9778800004001, 628.8778800004001
      ]
  inputArg = rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])

-- with data t128c

-- | Gradient of 'conv2dB128c' at a constant-5 input of shape [2, 2, 2, 2].
testKonst5LittleB128c :: Assertion
testKonst5LittleB128c =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dB128c inputArg)
 where
  expectedGrad =
    ringestData [2, 2, 2, 2]
      [ 54.100300000000004, 111.20060000000001, 119.09462, 270.29492000000005
      , 58.2, 140.3, 90.4, 212.4
      , 54.100300000000004, 111.20060000000001, 119.09462, 270.29492000000005
      , 58.2, 140.3, 90.4, 212.4
      ]
  inputArg = rreplicate0N [2, 2, 2, 2] (rscalar 5)

-- | Gradient of 'conv2dC128c' at a constant-5 input of shape [2, 2, 2, 2].
testKonst5LittleC128c :: Assertion
testKonst5LittleC128c =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dC128c inputArg)
 where
  expectedGrad =
    ringestData [2, 2, 2, 2]
      [ 2640.8154000007976, 1836.3264600007988, 2412.414800000798, 1662.026160000799
      , 1712.044990000598, 1566.644690000599, 1445.5506800005985, 1358.3503800005992
      , 2640.8154000007976, 1836.3264600007988, 2412.414800000798, 1662.026160000799
      , 1712.044990000598, 1566.644690000599, 1445.5506800005985, 1358.3503800005992
      ]
  inputArg = rreplicate0N [2, 2, 2, 2] (rscalar 5)

-- | Gradient of 'conv2dB128c' at a constant-5 input of shape [3, 2, 4, 2].
testKonst5BigB128c :: Assertion
testKonst5BigB128c =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dB128c inputArg)
 where
  expectedGrad =
    ringestData [3, 2, 4, 2]
      [ 54.100300000000004, 111.20060000000001, 119.09462, 270.29492000000005, 109.09462000000002, 318.19492, 174.08894000000004, 477.28924000000006
      , 58.2, 140.3, 90.4, 212.4, 120.4, 292.39432000000005, -117.5, 103.38864000010005
      , 54.100300000000004, 111.20060000000001, 119.09462, 270.29492000000005, 109.09462000000002, 318.19492, 174.08894000000004, 477.28924000000006
      , 58.2, 140.3, 90.4, 212.4, 120.4, 292.39432000000005, -117.5, 103.38864000010005
      , 54.100300000000004, 111.20060000000001, 119.09462, 270.29492000000005, 109.09462000000002, 318.19492, 174.08894000000004, 477.28924000000006
      , 58.2, 140.3, 90.4, 212.4, 120.4, 292.39432000000005, -117.5, 103.38864000010005
      ]
  inputArg = rreplicate0N [3, 2, 4, 2] (rscalar 5)

-- The expected gradient of the next test equals the one above even though the
-- input differs: the other convolution argument is unchanged and convolution
-- is linear in the varied argument, so that argument's values do not matter.
-- | Gradient of 'conv2dB128c' at the non-constant input 37, 36 .. -10
-- (48 values, shape [3, 2, 4, 2]).
testKonstNotBigB128c :: Assertion
testKonstNotBigB128c =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dB128c inputArg)
 where
  expectedGrad =
    ringestData [3, 2, 4, 2]
      [ 54.100300000000004, 111.20060000000001, 119.09462, 270.29492000000005, 109.09462000000002, 318.19492, 174.08894000000004, 477.28924000000006
      , 58.2, 140.3, 90.4, 212.4, 120.4, 292.39432000000005, -117.5, 103.38864000010005
      , 54.100300000000004, 111.20060000000001, 119.09462, 270.29492000000005, 109.09462000000002, 318.19492, 174.08894000000004, 477.28924000000006
      , 58.2, 140.3, 90.4, 212.4, 120.4, 292.39432000000005, -117.5, 103.38864000010005
      , 54.100300000000004, 111.20060000000001, 119.09462, 270.29492000000005, 109.09462000000002, 318.19492, 174.08894000000004, 477.28924000000006
      , 58.2, 140.3, 90.4, 212.4, 120.4, 292.39432000000005, -117.5, 103.38864000010005
      ]
  inputArg = rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])

-- | Gradient of 'conv2dC128c' at a constant-5 input of shape [3, 2, 4, 2].
testKonst5BigC128c :: Assertion
testKonst5BigC128c =
  assertEqualUpToEpsilon' 1e-8 expectedGrad (rev' @Double @4 conv2dC128c inputArg)
 where
  expectedGrad =
    ringestData [3, 2, 4, 2]
      [ 2640.8154000007976, 1836.3264600007988, 2412.414800000798, 1662.026160000799, 2121.6375200006987, 1436.2432000006995, 1953.5375200006988, 1258.1432000006998
      , 1712.044990000598, 1566.644690000599, 1445.5506800005985, 1358.3503800005992, 1279.150680000599, 1224.1503800005996, 987.1677200004992, 962.1674200005002
      , 2640.8154000007976, 1836.3264600007988, 2412.414800000798, 1662.026160000799, 2121.6375200006987, 1436.2432000006995, 1953.5375200006988, 1258.1432000006998
      , 1712.044990000598, 1566.644690000599, 1445.5506800005985, 1358.3503800005992, 1279.150680000599, 1224.1503800005996, 987.1677200004992, 962.1674200005002
      , 2640.8154000007976, 1836.3264600007988, 2412.414800000798, 1662.026160000799, 2121.6375200006987, 1436.2432000006995, 1953.5375200006988, 1258.1432000006998
      , 1712.044990000598, 1566.644690000599, 1445.5506800005985, 1358.3503800005992, 1279.150680000599, 1224.1503800005996, 987.1677200004992, 962.1674200005002
      ]
  inputArg = rreplicate0N [3, 2, 4, 2] (rscalar 5)

-- The expected gradient of the next test equals the one above even though the
-- input differs: the other convolution argument is unchanged and convolution
-- is linear in the varied argument, so that argument's values do not matter.
testKonstNotBigC128c :: Assertion testKonstNotBigC128c = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002]) (rev' @Double @4 conv2dC128c (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. 
-10]))) -- with data t128b and t128c {- testKonst5LittleB128bc :: Assertion testKonst5LittleB128bc = assertEqualUpToEpsilon' 1e-8 (ringestData [2,2,8,4] [112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.06144
00004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898]) (rev' @Double @4 conv2dB128b t128c) testKonst5LittleC128bc :: Assertion testKonst5LittleC128bc = assertEqualUpToEpsilon' 1e-8 (ringestData [2,2,8,4] [1627.8210700004993,1571.2321300004994,1047.1431900004002,393.6715900002,1132.9261600005002,1188.6375200005,803.7488800004002,316.57160000019996,675.7488800003999,828.6545600004001,577.7659200003001,220.57728000019998,215.6659200003,388.5716000003,245.5772800002,94.68864000010001,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2725.0393200008984,1831.7390200008983,1259.3728000004999,568.6722000005001,2551.139320000898,1660.8390200008987,1151.3728000005,501.6722000005,1903.750080000699,1174.5497800006997,803.9778800004001,340.5775800004001,854.9778800004001,628.8778800004001,450.1892400002,198.8889400002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1627.8210700004993,1571.2321300004994,1047.1431900004002,393.6715900002,1132.9261600005002,1188.6375200005,803.7488800004002,316.57160000019996,675.7488800003999,828.6545600004001,577.7659200003001,220.57728000019998,215.6659200003,388.5716000003,245.5772800002,94.68864000010001,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2725.0393200008984,1831.7390200008983,1259.3728000004999,568.6722000005001,2551.139320000898,1660.8390200008987,1151.3728000005,501.6722000005,1903.750080000699,1174.5497800006997,803.9778800004001,340.5775800004001,854.9778800004001,628.8778800004001,450.1892400002,198.8889400002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]) (rev' @Double @4 conv2dC128b t128c) testKonst5BigB128bc :: Assertion testKonst5BigB128bc = assertEqualUpToEpsilon' 1e-8 
(ringestData [2,2,8,4] [112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.0393
20000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898]) (rev' @Double @4 conv2dB128b t128c) -- The gradient is the same as above, because one argument is the same -- and convolution is linear. testKonstNotBigB128cb :: Assertion testKonstNotBigB128cb = assertEqualUpToEpsilon' 1e-8 (ringestData [4,2,4,4] [54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,68
7.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002]) (rev' @Double @4 conv2dB128c t128b) testKonst5BigC128cb :: Assertion testKonst5BigC128cb = assertEqualUpToEpsilon' 1e-8 (ringestData [4,2,4,4] [2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002
,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003]) (rev' @Double @4 conv2dC128c t128b) -- The gradient is the same as above, because one argument is the same -- and convolution is linear. testKonstNotBigC128cb :: Assertion testKonstNotBigC128cb = assertEqualUpToEpsilon' 1e-8 (ringestData [4,2,4,4] [2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.15068
0000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003]) (rev' @Double @4 conv2dC128c t128b) -} -- * A laborious version (meaning, out of bounds indexing is handled explicitly) -- | Unpadded full convolution, -- where the output size is the same as the input size. -- -- It guards the out of bounds indexing behind a conditional -- to prevent changed values after vectorization, -- but the guarding is no longer needed, so this is only for testing. -- -- BTW, the indexing lower bounds in the code are spurious, -- so they get simplified away in the resulting AST program. 
conv2dUnpaddedL
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpaddedL kernelArr inputArr =
  rbuild outShape $ \ixOut -> case ixOut of
    [iImg, iCout, iBh, iBw] ->
      -- One output element: the dot product of the input window anchored
      -- at (iImg, 0, iBh, iBw) with the kernel slice for channel iCout;
      -- both slices have shape windowShape.
      rdot0 (slicezL windowShape inputArr [iImg, 0, iBh, iBw])
            (slicezL windowShape kernelArr [iCout, 0, 0, 0])
    _ -> error "conv2dUnpaddedL: impossible pattern needlessly required"
  where
    [nImgs, nCinpA, nAh, nAw] = rshape inputArr
    [nCoutK, nCinpK, nKh, nKw] = rshape kernelArr
    -- The input-channel counts of the image and of the kernel must agree.
    nCinp = assert (nCinpA == nCinpK `blame` (nCinpA, nCinpK)) nCinpA
    -- The output keeps the spatial size of the input.
    outShape = [nImgs, nCoutK, nAh, nAw]
    windowShape = [1, nCinp, nKh, nKw]

-- | Cut out a section of shape @shOut@ from @d@, whose first element is
-- located at @ixBase@ in @d@.
--
-- Positions of the section that fall outside of @d@ read as zero,
-- because every element is fetched with 'indexz0L'.
slicezL
  :: (ADReady target, GoodScalar r, KnownNat n)
  => IShR n -> target (TKR n r) -> IxROf target n -> target (TKR n r)
slicezL shOut d ixBase = rbuild shOut elemAt
  where
    -- Shift the section-local index by the base offset, then read @d@.
    elemAt ixResult = indexz0L d (ixrZipWith (+) ixBase ixResult)

-- | Read the element of @d@ at @ix@, or zero when @ix@ lies outside
-- the shape of @d@.
--
-- Caveat: @ix@ occurs twice in this definition and 'within0' uses each of
-- its components twice more. Since there is no tlet here, only use this
-- variant when @ix@ is known to be small and of constant size. For example,
-- it should not contain conditionals comparing big tensors or their minimal
-- elements, unless those tensors are bound with tlet and referenced only
-- through the variables representing them.
indexz0L
  :: forall target r n. (ADReady target, GoodScalar r, KnownNat n)
  => target (TKR n r) -> IxROf target n -> target (TKR 0 r)
indexz0L d ix = ifH inBounds (d ! ix) (rscalar 0)
  where
    inBounds = within0 @target (rshape @target d) ix

-- | Whether every component @i@ of @ix@ satisfies @0 <= i < dim@, where
-- @dim@ is the corresponding dimension of @sh@.
-- Each component of @ix@ is referenced twice, so @ix@ should be shared
-- by the caller.
within0
  :: forall target n. (ADReady target, KnownNat n)
  => IShR n -> IxROf target n -> BoolOf target
within0 sh ix =
  foldr (\(i, dim) acc -> inRange i dim &&* acc) true
        (zip (toList ix) bounds)
  where
    bounds :: [IntOf target]
    bounds = map fromIntegral (toList sh)
    inRange :: IntOf target -> IntOf target -> BoolOf target
    inRange i dim = 0 <=. i &&* dim >. i

-- The wrappers below fix one argument of 'conv2dUnpaddedL' to a constant.
-- The "B" variants fix the kernel (first argument), so they are functions
-- of the input. The "C" variants fix the input (second argument), so they
-- are functions of the kernel.

conv2d1Laborious
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2d1Laborious inputArr =
  conv2dUnpaddedL
    (rconcrete (Nested.rfromListPrimLinear (fromList [1, 1, 1, 1]) [-0.2]))
    inputArr

conv2dALaborious
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dALaborious inputArr =
  conv2dUnpaddedL
    (rconcrete (Nested.rfromListPrimLinear (fromList [1, 2, 1, 1]) [-0.2, 25.0003]))
    inputArr

conv2dBLaborious
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dBLaborious inputArr =
  conv2dUnpaddedL (rconcrete (unConcrete t16b)) inputArr

conv2dCLaborious
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dCLaborious kernelArr =
  conv2dUnpaddedL kernelArr (rconcrete (unConcrete t16b))

conv2dBLaborious128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dBLaborious128b inputArr =
  conv2dUnpaddedL (rconcrete (unConcrete t128b)) inputArr

conv2dCLaborious128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dCLaborious128b kernelArr =
  conv2dUnpaddedL kernelArr (rconcrete (unConcrete t128b))

conv2dBLaborious128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dBLaborious128c inputArr =
  conv2dUnpaddedL (rconcrete (unConcrete t128c)) inputArr

conv2dCLaborious128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dCLaborious128c kernelArr =
  conv2dUnpaddedL kernelArr (rconcrete (unConcrete t128c))

testReplicate0RevLaborious :: Assertion
testReplicate0RevLaborious =
  assertEqualUpToEpsilon 1e-4 expected
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBLaborious)
          (rrepl [2, 2, 2, 2] 0))
  where
    expected =
      rconcrete $ Nested.rfromListPrimLinear [2, 2, 2, 2]
        [ 18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001
        , 18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001 ]

testReplicate0Tiny1Laborious :: Assertion
testReplicate0Tiny1Laborious =
  assertEqualUpToEpsilon' 1e-10 expected
    (rev' @Double @4 conv2d1Laborious (rrepl [1, 1, 1, 1] 0))
  where
    expected = ringestData [1, 1, 1, 1] [-0.2]

testReplicate0TinySLaborious :: Assertion
testReplicate0TinySLaborious =
  assertEqualUpToEpsilon' 1e-10 expected
    (rev' @Double @4
          (conv2dUnpaddedL $ rreplicate0N [1, 1, 1, 1] (rsum0 (rconcrete $ unConcrete t16b)))
          (ringestData [1, 1, 1, 1] [0]))
  where
    expected = ringestData [1, 1, 1, 1] [582665.99432]

testReplicate0TinyALaborious :: Assertion
testReplicate0TinyALaborious =
  assertEqualUpToEpsilon' 1e-10 expected
    (rev' @Double @4 conv2dALaborious (rrepl [1, 2, 1, 1] 0))
  where
    expected = ringestData [1, 2, 1, 1] [-0.2,25.0003]

testReplicate0LittleALaborious :: Assertion
testReplicate0LittleALaborious =
  assertEqualUpToEpsilon' 1e-10 expected
    (rev' @Double @4 conv2dALaborious (rrepl [2, 2, 2, 2] 0))
  where
    expected =
      ringestData [2, 2, 2, 2]
        [ -0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003
        , -0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003 ]

-- with data t16

testKonst5LittleBLaborious :: Assertion
testKonst5LittleBLaborious =
  assertEqualUpToEpsilon' 1e-8 expected
    (rev' @Double @4 conv2dBLaborious (rreplicate0N [2, 2, 2, 2] (rscalar 5)))
  where
    expected =
      ringestData [2, 2, 2, 2]
        [ 18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001
        , 18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001 ]

testKonst5LittleCLaborious :: Assertion
testKonst5LittleCLaborious =
  assertEqualUpToEpsilon' 1e-8 expected
    (rev' @Double @4 conv2dCLaborious (rreplicate0N [2, 2, 2, 2] (rscalar 5)))
  where
    expected =
      ringestData [2, 2, 2, 2]
        [ 40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8
        , 40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8 ]

testKonst5BigBLaborious :: Assertion
testKonst5BigBLaborious =
  assertEqualUpToEpsilon' 1e-8 expected
    (rev' @Double @4 conv2dBLaborious (rreplicate0N [3, 2, 4, 2] (rscalar 5)))
  where
    expected =
      ringestData [3, 2, 4, 2]
        [ 18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001
        , 18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001
        , 18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001 ]

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigBLaborious :: Assertion
testKonstNotBigBLaborious =
  assertEqualUpToEpsilon' 1e-8 expected
    (rev' @Double @4 conv2dBLaborious
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))
  where
    expected =
      ringestData [3, 2, 4, 2]
        [ 18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001
        , 18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001
        , 18.1,29.1,32.1,40.1,32.1,40.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,582597.1,582625.8943200001,582597.1,582625.8943200001 ]

testKonst5BigCLaborious :: Assertion
testKonst5BigCLaborious =
  assertEqualUpToEpsilon' 1e-8 expected
    (rev' @Double @4 conv2dCLaborious (rreplicate0N [3, 2, 4, 2] (rscalar 5)))
  where
    expected =
      ringestData [3, 2, 4, 2]
        [ 40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0
        , 40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0
        , 40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0 ]

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigCLaborious :: Assertion
testKonstNotBigCLaborious =
  assertEqualUpToEpsilon' 1e-8 expected
    (rev' @Double @4 conv2dCLaborious
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))
  where
    expected =
      ringestData [3, 2, 4, 2]
        [ 40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0
        , 40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0
        , 40.1,8.0,11.0,-3.0,0.0,0.0,0.0,0.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,0.0,0.0,0.0,0.0 ]

-- with data t128b

-- Gradient w.r.t. the input, with t128b fixed as the kernel.
testKonst5LittleBLaborious128b :: Assertion
testKonst5LittleBLaborious128b =
  assertEqualUpToEpsilon' 1e-8 expected
    (rev' @Double @4 conv2dBLaborious128b (rreplicate0N [2, 2, 2, 2] (rscalar 5)))
  where
    expected =
      ringestData [2, 2, 2, 2]
        [ 112.3003,251.5006,209.49462,482.69492000000014,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004
        , 112.3003,251.5006,209.49462,482.69492000000014,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004 ]

-- Gradient w.r.t. the kernel, with t128b fixed as the input.
testKonst5LittleCLaborious128b :: Assertion
testKonst5LittleCLaborious128b =
  assertEqualUpToEpsilon' 1e-8 expected
    (rev' @Double @4 conv2dCLaborious128b (rreplicate0N [2, 2, 2, 2] (rscalar 5)))
  where
    expected =
      ringestData [2, 2, 2, 2]
        [ 1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987
        , 1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987 ]

{-
testKonst5BigBLaborious128b :: Assertion
testKonst5BigBLaborious128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2]
[112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993]) (rev' @Double @4 conv2dBLaborious128b (rreplicate0N [3, 2, 4, 2] (rscalar 5))) -- The gradient is the same as above, because one argument is the same -- and convolution is linear. testKonstNotBigBLaborious128b :: Assertion testKonstNotBigBLaborious128b = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993]) (rev' @Double @4 conv2dBLaborious128b (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. 
-10]))) testKonst5BigCLaborious128b :: Assertion testKonst5BigCLaborious128b = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001]) (rev' @Double @4 conv2dCLaborious128b (rreplicate0N [3, 2, 4, 2] (rscalar 5))) -- The gradient is the same as above, because one argument is the same -- and convolution is linear. 
testKonstNotBigCLaborious128b :: Assertion testKonstNotBigCLaborious128b = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001]) (rev' @Double @4 conv2dCLaborious128b (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. 
-10]))) -}

-- with data t128c

-- Gradient w.r.t. the input, with t128c fixed as the kernel.
testKonst5LittleBLaborious128c :: Assertion
testKonst5LittleBLaborious128c =
  assertEqualUpToEpsilon' 1e-8 expected
    (rev' @Double @4 conv2dBLaborious128c (rreplicate0N [2, 2, 2, 2] (rscalar 5)))
  where
    expected =
      ringestData [2, 2, 2, 2]
        [ 54.100300000000004,111.20060000000001,119.09462,270.29492000000005,58.2,140.3,90.4,212.4
        , 54.100300000000004,111.20060000000001,119.09462,270.29492000000005,58.2,140.3,90.4,212.4 ]

-- Gradient w.r.t. the kernel, with t128c fixed as the input.
testKonst5LittleCLaborious128c :: Assertion
testKonst5LittleCLaborious128c =
  assertEqualUpToEpsilon' 1e-8 expected
    (rev' @Double @4 conv2dCLaborious128c (rreplicate0N [2, 2, 2, 2] (rscalar 5)))
  where
    expected =
      ringestData [2, 2, 2, 2]
        [ 2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992
        , 2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992 ]

{-
testKonst5BigBLaborious128c :: Assertion
testKonst5BigBLaborious128c =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2]
      [ 54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005
      , 54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005
      , 54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005 ])
    (rev' @Double @4 conv2dBLaborious128c (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigBLaborious128c :: Assertion testKonstNotBigBLaborious128c = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005]) (rev' @Double @4 conv2dBLaborious128c (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10]))) testKonst5BigCLaborious128c :: Assertion testKonst5BigCLaborious128c = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002]) (rev' @Double @4 conv2dCLaborious128c (rreplicate0N [3, 2, 4, 2] (rscalar 5))) -- The gradient is the same as above, because one 
argument is the same -- and convolution is linear. testKonstNotBigCLaborious128c :: Assertion testKonstNotBigCLaborious128c = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002]) (rev' @Double @4 conv2dCLaborious128c (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. 
-10]))) -} -- with data t128b and t128c {- testKonst5LittleBLaborious128bc :: Assertion testKonst5LittleBLaborious128bc = assertEqualUpToEpsilon' 1e-8 (ringestData [2,2,8,4] [112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.278
1800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898]) (rev' @Double @4 conv2dBLaborious128b t128c) testKonst5LittleCLaborious128bc :: Assertion testKonst5LittleCLaborious128bc = assertEqualUpToEpsilon' 1e-8 (ringestData [2,2,8,4] [1627.8210700004993,1571.2321300004994,1047.1431900004002,393.6715900002,1132.9261600005002,1188.6375200005,803.7488800004002,316.57160000019996,675.7488800003999,828.6545600004001,577.7659200003001,220.57728000019998,215.6659200003,388.5716000003,245.5772800002,94.68864000010001,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2725.0393200008984,1831.7390200008983,1259.3728000004999,568.6722000005001,2551.139320000898,1660.8390200008987,1151.3728000005,501.6722000005,1903.750080000699,1174.5497800006997,803.9778800004001,340.5775800004001,854.9778800004001,628.8778800004001,450.1892400002,198.8889400002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,1627.8210700004993,1571.2321300004994,1047.1431900004002,393.6715900002,1132.9261600005002,1188.6375200005,803.7488800004002,316.57160000019996,675.7488800003999,828.6545600004001,577.7659200003001,220.57728000019998,215.6659200003,388.5716000003,245.5772800002,94.68864000010001,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,2725.0393200008984,1831.7390200008983,1259.3728000004999,568.6722000005001,2551.139320000898,1660.8390200008987,1151.3728000005,501.6722000005,1903.750080000699,1174.5497800006997,803.9778800004001,340.5775800004001,854.9778800004001,628.8778800004001,450.1892400002,198.8889400002,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]) (rev' @Double @4 conv2dCLaborious128b t128c) testKonst5BigBLaborious128bc :: 
Assertion testKonst5BigBLaborious128bc = assertEqualUpToEpsilon' 1e-8 (ringestData [2,2,8,4] [112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,112.3003,251.5006,417.79492000000005,494.89491000000015,209.49462,482.69492000000014,778.9778800001002,952.0721900001002,229.49462000000003,610.5892400000002,1113.1722000001,1412.1551500001997,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,56.58894000000004,580.6778800001001,1234.1494800003,1627.8210700004993,3.000000000000032,65.90000000000003,106.90000000000003,173.90000000000006,164.10000000000002,365.89432000010004,593.1946200001,821.2892400002004,667.2003000000001,1060.8778800002,1500.2781800001994,1870.0614400004986,893.3003,1465.6665200003993,2156.3671200003987,2725
.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898,893.3003,1465.6665200003993,2156.3671200003987,2725.039320000898]) (rev' @Double @4 conv2dBLaborious128b t128c) -- The gradient is the same as above, because one argument is the same -- and convolution is linear. testKonstNotBigBLaborious128cb :: Assertion testKonstNotBigBLaborious128cb = assertEqualUpToEpsilon' 1e-8 (ringestData [4,2,4,4] [54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,270.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002,54.100300000000004,111.20060000000001,191.4006,228.4006,119.09462,27
0.29492000000005,435.28356000009995,519.1778800001,109.09462000000002,318.19492,563.3835600001,687.2778800001003,174.08894000000004,477.28924000000006,774.2665200002001,931.9551600002003,58.2,140.3,226.39432,266.49431000000004,90.4,212.4,343.69432000000006,432.89431000000013,120.4,292.39432000000005,549.78864,724.8772700001001,-117.5,103.38864000010005,459.88296000009996,695.8659100003002]) (rev' @Double @4 conv2dBLaborious128c t128b) testKonst5BigCLaborious128cb :: Assertion testKonst5BigCLaborious128cb = assertEqualUpToEpsilon' 1e-8 (ringestData [4,2,4,4] [2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.
044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003]) (rev' @Double @4 conv2dCLaborious128c t128b) -- The gradient is the same as above, because one argument is the same -- and convolution is linear. 
testKonstNotBigCLaborious128cb :: Assertion testKonstNotBigCLaborious128cb = assertEqualUpToEpsilon' 1e-8 (ringestData [4,2,4,4] [2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003,2640.8154000007976,1836.3264600007988,1163.4488800005001,483.7716000003,2412.414800000798,1662.026160000799,1046.2488800005003,446.7716000003,2121.6375
200006987,1436.2432000006995,914.5659200004003,399.8772800003,1953.5375200006988,1258.1432000006998,794.3659200004003,359.8772800003,1712.044990000598,1566.644690000599,1143.0671100004001,478.5721900004001,1445.5506800005985,1358.3503800005992,1016.8728000004002,438.47220000040005,1279.150680000599,1224.1503800005996,922.5728000004001,389.3722000004,987.1677200004992,962.1674200005002,710.5841600003,303.48356000030003]) (rev' @Double @4 conv2dCLaborious128c t128b) -}

-- * A padded version (out of bounds indexing is not possible)

-- | Convolution of the input @arrA@ with the kernel @arrK@. The input is
-- first copied into a larger, explicitly zero-padded array, so that the
-- result has the same spatial size as the input.
--
-- As read off the 'rshape' patterns below, @arrK@ has shape
-- @[nCoutK, nCinpK, nKh, nKw]@ and @arrA@ has shape
-- @[nImgs, nCinpA, nAh, nAw]@. The result has shape
-- @[nImgs, nCoutK, nAh, nAw]@. The channel counts @nCinpA@ and @nCinpK@
-- must agree, which is checked with 'assert'.
--
-- The padded input has @div nKh 2@ zero rows before the image and
-- @nKh - div nKh 2@ zero rows after it, and likewise for columns.
-- Each output point @[iImg, iCout, iBh, iBw]@ is the 'rdot0' of two
-- @[1, nCinp, nKh, nKw]@ slices: one of the padded input taken at
-- @[iImg, 0, iBh, iBw]@, and one of the kernel taken at @[iCout, 0, 0, 0]@.
--
-- Assuming 'slicezL' takes the window that starts at the given index,
-- the padding is one row and one column larger than strictly needed:
-- the last padded row and column are never read. Input points near the
-- image borders are also read fewer times than interior points, so not
-- every input point is read the same number of times. The expected
-- gradients in the tests below show this, e.g. the last row differs in
-- @_testKonst5BigBPadded@.
--
-- The same result could be accomplished by tweaking indexes slightly
-- in conv2dUnpadded. Here, however, all indexing into arrays stays in
-- bounds by construction, so the bounds checks are spurious and will be
-- simplified away in the resulting AST program.
conv2dPadded
  :: forall target r. (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r) -> target (TKR 4 r)
conv2dPadded arrK arrA =
  let [nImgs, nCinpA, nAh, nAw] = rshape arrA
      [nCoutK, nCinpK, nKh, nKw] = rshape arrK
      -- Room for the image plus nKh extra rows and nKw extra columns
      -- of zeros (see the padding note in the Haddock above).
      shAPadded = [nImgs, nCinpA, nAh + nKh, nAw + nKw]
      arrAPadded = rbuild @4 @0 @(TKScalar r) @target shAPadded $ \case
        [iImg, iCinp, iPh, iPw] ->
          -- Zero outside the image area; inside it, the input shifted
          -- by the leading padding (div nKh 2 rows, div nKw 2 columns).
          ifH (iPh <. fromIntegral (nKh `div` 2)
               ||* iPw <. fromIntegral (nKw `div` 2)
               ||* iPh >=. fromIntegral (nAh + nKh `div` 2)
               ||* iPw >=. fromIntegral (nAw + nKw `div` 2))
              (rscalar 0)
              (arrA ! [ iImg
                      , iCinp
                      , iPh - fromIntegral (nKh `div` 2)
                      , iPw - fromIntegral (nKw `div` 2) ])
      -- Input and kernel must agree on the number of input channels.
      nCinp = assert (nCinpA == nCinpK `blame` (nCinpA, nCinpK)) nCinpA
      shB = [nImgs, nCoutK, nAh, nAw]
      -- Shape of one window: a single image (or kernel), all input
      -- channels, kernel-sized spatially.
      shK1 = [1, nCinp, nKh, nKw]
  in rbuild shB $ \case
    [iImg, iCout, iBh, iBw] ->
      let arrAt = slicezL shK1 arrAPadded [iImg, 0, iBh, iBw]
          arrKt = slicezL shK1 arrK [iCout, 0, 0, 0]
      in rdot0 arrAt arrKt
    _ -> error "conv2dPadded: impossible pattern needlessly required"

-- | 'conv2dPadded' with a fixed kernel of shape @[1, 1, 1, 1]@ holding
-- @-0.2@; the argument is the input.
conv2d1Padded
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2d1Padded =
  conv2dPadded $ rconcrete $ Nested.rfromListPrimLinear (fromList [1, 1, 1, 1]) [-0.2]

-- | 'conv2dPadded' with a fixed kernel of shape @[1, 2, 1, 1]@ holding
-- @-0.2@ and @25.0003@; the argument is the input.
conv2dAPadded
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dAPadded =
  conv2dPadded $ rconcrete $ Nested.rfromListPrimLinear (fromList [1, 2, 1, 1]) [-0.2, 25.0003]

-- | 'conv2dPadded' with @t16b@ as the kernel; the argument is the input.
conv2dBPadded
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dBPadded = conv2dPadded (rconcrete $ unConcrete t16b)

-- | 'conv2dPadded' with @t16b@ as the input; the argument is the kernel.
conv2dCPadded
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dCPadded = flip conv2dPadded (rconcrete $ unConcrete t16b)

-- | 'conv2dPadded' with @t128b@ as the kernel; the argument is the input.
conv2dBPadded128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dBPadded128b = conv2dPadded (rconcrete $ unConcrete t128b)

-- | 'conv2dPadded' with @t128b@ as the input; the argument is the kernel.
conv2dCPadded128b
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dCPadded128b = flip conv2dPadded (rconcrete $ unConcrete t128b)

-- | 'conv2dPadded' with @t128c@ as the kernel; the argument is the input.
-- Only referenced from commented-out tests, hence the underscore.
_conv2dBPadded128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
_conv2dBPadded128c = conv2dPadded (rconcrete $ unConcrete t128c)

-- | 'conv2dPadded' with @t128c@ as the input; the argument is the kernel.
-- Only referenced from commented-out tests, hence the underscore.
_conv2dCPadded128c
  :: (ADReady target, GoodScalar r, Differentiable r)
  => target (TKR 4 r) -> target (TKR 4 r)
_conv2dCPadded128c = flip conv2dPadded (rconcrete $ unConcrete t128c)

-- The tests below differentiate the sum of all outputs
-- (kfromR . rsum0) of a padded convolution with respect to one argument
-- and compare the gradient with the expected values.
-- Underscore-prefixed tests are marked "TODO: OOMs".
-- NOTE(review): presumably they are not registered in testTrees for
-- that reason; confirm.

-- TODO: OOMs
_testReplicate0RevPadded :: Assertion
_testReplicate0RevPadded =
  assertEqualUpToEpsilon 1e-4
    (rconcrete $ Nested.rfromListPrimLinear [2, 2, 2, 2] [40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8,40.1,8.0,11.0,-3.0,582625.89432,28.79432,-309.09999999999997,25.8])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBPadded)
          (rrepl [2, 2, 2, 2] 0))

-- One input point and a one-element kernel: the gradient is the
-- kernel value.
testReplicate0Tiny1Padded :: Assertion
testReplicate0Tiny1Padded =
  assertEqualUpToEpsilon 1e-10
    (ringestData [1, 1, 1, 1] [-0.2])
    (cgrad (kfromR . rsum0 @4 @(TKScalar Double) . conv2d1Padded)
           (rrepl [1, 1, 1, 1] 0))

-- The kernel is a single value equal to the sum of all entries of t16b,
-- so the gradient at the single input point equals that sum.
testReplicate0TinySPadded :: Assertion
testReplicate0TinySPadded =
  assertEqualUpToEpsilon 1e-10
    (ringestData [1, 1, 1, 1] [582665.99432])
    (grad (kfromR . rsum0 @4 @(TKScalar Double)
           . (conv2dPadded $ rreplicate0N [1, 1, 1, 1]
                                          (rsum0 (rconcrete $ unConcrete t16b))))
          (ringestData [1, 1, 1, 1] [0]))

testReplicate0TinyAPadded :: Assertion
testReplicate0TinyAPadded =
  assertEqualUpToEpsilon 1e-10
    (ringestData [1, 2, 1, 1] [-0.2,25.0003])
    (cgrad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dAPadded)
           (rrepl [1, 2, 1, 1] 0))

-- With the 1x1 spatial kernel of conv2dAPadded, every input point of
-- channel 0 gets gradient -0.2 and every point of channel 1 gets 25.0003.
testReplicate0LittleAPadded :: Assertion
testReplicate0LittleAPadded =
  assertEqualUpToEpsilon 1e-10
    (ringestData [2, 2, 2, 2] [-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003,-0.2,-0.2,-0.2,-0.2,25.0003,25.0003,25.0003,25.0003])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dAPadded)
          (rrepl [2, 2, 2, 2] 0))

-- with data t16b

-- TODO: OOMs
_testKonst5LittleBPadded :: Assertion
_testKonst5LittleBPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [2, 2, 2, 2] [40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,-309.09999999999997,25.8])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBPadded)
          (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

-- TODO: OOMs
_testKonst5LittleCPadded :: Assertion
_testKonst5LittleCPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [2, 2, 2, 2] [18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001,18.1,29.1,32.1,40.1,582932.0,582934.99432,582597.1,582625.8943200001])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dCPadded)
          (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

-- TODO: OOMs
_testKonst5BigBPadded :: Assertion
_testKonst5BigBPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [3, 2, 4, 2] [40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBPadded)
          (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- TODO: OOMs
-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
_testKonstNotBigBPadded :: Assertion
_testKonstNotBigBPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [3, 2, 4, 2] [40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8,40.1,8.0,40.1,8.0,40.1,8.0,11.0,-3.0,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,582625.8943200001,28.794320000000003,-309.09999999999997,25.8])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBPadded)
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

-- TODO: OOMs
_testKonst5BigCPadded :: Assertion
_testKonst5BigCPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [3, 2, 4, 2] [0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997,0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997,0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dCPadded)
          (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- TODO: OOMs
-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
_testKonstNotBigCPadded :: Assertion
_testKonstNotBigCPadded =
  assertEqualUpToEpsilon 1e-8
    (ringestData [3, 2, 4, 2] [0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997,0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997,0.0,0.0,18.1,29.1,32.1,40.1,14.0,11.0,0.0,0.0,582932.0,582934.99432,582597.1,582625.8943200001,-334.9,-309.09999999999997])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dCPadded)
          (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10])))

-- with data t128b

-- TODO: OOMs
_testKonst5LittleBPadded128b :: Assertion
_testKonst5LittleBPadded128b =
  assertEqualUpToEpsilon 1e-8
    (ringestData [2, 2, 2, 2] [578.1829600001,558.1716000002,608.0772800002001,577.7659200003001,729.1778800002002,701.1835600003001,833.9722000003002,803.9778800004001,578.1829600001,558.1716000002,608.0772800002001,577.7659200003001,729.1778800002002,701.1835600003001,833.9722000003002,803.9778800004001])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dBPadded128b)
          (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

-- TODO: OOMs
_testKonst5LittleCPadded128b :: Assertion
_testKonst5LittleCPadded128b =
  assertEqualUpToEpsilon 1e-8
    (ringestData [2, 2, 2, 2] [1113.1722000001,1412.1551500001997,1234.1494800003002,1627.8210700004993,1500.2781800001994,1870.0614400004986,2156.3671200003987,2725.0393200008984,1113.1722000001,1412.1551500001997,1234.1494800003002,1627.8210700004993,1500.2781800001994,1870.0614400004986,2156.3671200003987,2725.0393200008984])
    (grad (kfromR . rsum0 @4 @(TKScalar Double) . conv2dCPadded128b)
          (rreplicate0N [2, 2, 2, 2] (rscalar 5)))

{- testKonst5BigBPadded128b :: Assertion
testKonst5BigBPadded128b =
  assertEqualUpToEpsilon' 1e-8
    (ringestData [3, 2, 4, 2] [112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993])
    (rev' @Double @4 conv2dBPadded128b (rreplicate0N [3, 2, 4, 2] (rscalar 5)))

-- The gradient is the same as above, because one argument is the same
-- and convolution is linear.
testKonstNotBigBPadded128b :: Assertion testKonstNotBigBPadded128b = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993,112.3003,251.5006,209.49462,482.69492000000014,229.49462000000003,610.5892400000002,56.58894000000004,580.6778800001001,3.000000000000032,65.90000000000003,164.10000000000002,365.89432000010004,667.2003000000001,1060.8778800002,893.3003,1465.6665200003993]) (rev' @Double @4 conv2dBPadded128b (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10]))) testKonst5BigCPadded128b :: Assertion testKonst5BigCPadded128b = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001]) (rev' @Double @4 
conv2dCPadded128b (rreplicate0N [3, 2, 4, 2] (rscalar 5))) -- The gradient is the same as above, because one argument is the same -- and convolution is linear. testKonstNotBigCPadded128b :: Assertion testKonstNotBigCPadded128b = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001,1627.8210700004993,1571.2321300004994,1132.9261600005002,1188.6375200005,675.7488800003999,828.6545600004001,215.6659200003,388.5716000003,2725.0393200008984,1831.7390200008983,2551.139320000898,1660.8390200008987,1903.750080000699,1174.5497800006997,854.9778800004001,628.8778800004001]) (rev' @Double @4 conv2dCPadded128b (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. 
-10]))) -- with data t128c testKonst5LittleBPadded128c :: Assertion testKonst5LittleBPadded128c = assertEqualUpToEpsilon' 1e-8 (ringestData [2, 2, 2, 2] [186.7886400001,121.7829600001,269.09432000009997,261.3943200001,210.9943200001,231.79432000010002,160.00030000000004,194.00060000000005,186.7886400001,121.7829600001,269.09432000009997,261.3943200001,210.9943200001,231.79432000010002,160.00030000000004,194.00060000000005]) (rev' @Double @4 conv2dBPadded128c (rreplicate0N [2, 2, 2, 2] (rscalar 5))) testKonst5LittleCPadded128c :: Assertion testKonst5LittleCPadded128c = assertEqualUpToEpsilon' 1e-8 (ringestData [2, 2, 2, 2] [1772.649480000399,2138.4267600005987,2157.0438000004983,2640.8154000007976,961.7781800001002,1359.4557500003987,1233.4728000001987,1712.044990000598,1772.649480000399,2138.4267600005987,2157.0438000004983,2640.8154000007976,961.7781800001002,1359.4557500003987,1233.4728000001987,1712.044990000598]) (rev' @Double @4 conv2dCPadded128c (rreplicate0N [2, 2, 2, 2] (rscalar 5))) testKonst5BigBPadded128c :: Assertion testKonst5BigBPadded128c = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005]) (rev' @Double @4 conv2dBPadded128c (rreplicate0N [3, 2, 4, 2] (rscalar 5))) -- The gradient is the same as above, because one argument is the same -- and convolution is linear. 
testKonstNotBigBPadded128c :: Assertion testKonstNotBigBPadded128c = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005,54.100300000000004,111.20060000000001,119.09462,270.29492000000005,109.09462000000002,318.19492,174.08894000000004,477.28924000000006,58.2,140.3,90.4,212.4,120.4,292.39432000000005,-117.5,103.38864000010005]) (rev' @Double @4 conv2dBPadded128c (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. -10]))) testKonst5BigCPadded128c :: Assertion testKonst5BigCPadded128c = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002]) (rev' @Double @4 conv2dCPadded128c (rreplicate0N [3, 2, 4, 2] (rscalar 5))) -- The gradient is the same as above, because one argument is the 
same -- and convolution is linear. testKonstNotBigCPadded128c :: Assertion testKonstNotBigCPadded128c = assertEqualUpToEpsilon' 1e-8 (ringestData [3, 2, 4, 2] [2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002,2640.8154000007976,1836.3264600007988,2412.414800000798,1662.026160000799,2121.6375200006987,1436.2432000006995,1953.5375200006988,1258.1432000006998,1712.044990000598,1566.644690000599,1445.5506800005985,1358.3503800005992,1279.150680000599,1224.1503800005996,987.1677200004992,962.1674200005002]) (rev' @Double @4 conv2dCPadded128c (rfromList0N [3, 2, 4, 2] (map rscalar [37, 36 .. 
-10]))) -} -- with data t128b and t128c {- testKonst5LittleBPadded128bc :: Assertion testKonst5LittleBPadded128bc = assertEqualUpToEpsilon' 1e-8 (ringestData [2,2,8,4] [1113.1722000001,1412.1551500001997,1182.6605300001997,801.5659100002002,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,816.3545600003002,1132.9261600005,1188.6375200004998,803.7488800004002,455.17160000019993,675.7488800004002,828.6545600004,577.7659200003001,1500.2781800001994,1870.0614400004986,1202.8611400004997,809.1835600003001,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2049.4671200003986,2551.139320000898,1660.8390200008987,1151.3728000004999,1563.172500000299,1903.7500800006983,1174.5497800006997,803.9778800004001,1113.1722000001,1412.1551500001997,1182.6605300001997,801.5659100002002,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,816.3545600003002,1132.9261600005,1188.6375200004998,803.7488800004002,455.17160000019993,675.7488800004002,828.6545600004,577.7659200003001,1500.2781800001994,1870.0614400004986,1202.8611400004997,809.1835600003001,2156.3671200003987,2725.0393200
00898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2049.4671200003986,2551.139320000898,1660.8390200008987,1151.3728000004999,1563.172500000299,1903.7500800006983,1174.5497800006997,803.9778800004001]) (rev' @Double @4 conv2dBPadded128b t128c) testKonst5LittleCPadded128bc :: Assertion testKonst5LittleCPadded128bc = assertEqualUpToEpsilon' 1e-8 (ringestData [2,2,8,4] [0.0,0.0,0.0,0.0,251.5006,417.7949200000001,494.8949100000001,382.59461000000005,482.69492000000014,778.9778800001002,952.0721900001002,742.5775700001002,610.5892400000002,1113.1722000001,1412.1551500001997,1182.6605300002002,580.6778800001001,1234.1494800003002,1627.8210700004993,1571.2321300004994,329.17728000010004,816.3545600003002,1132.9261600005002,1188.6375200005,97.98296000010004,455.17160000020016,675.7488800003999,828.6545600004001,-29.9113599999,120.97728000019995,215.6659200003,388.5716000003,0.0,0.0,0.0,0.0,65.90000000000003,106.90000000000003,173.90000000000006,170.90000000000003,365.89432000010004,593.1946200001,821.2892400002003,657.1892400002001,1060.8778800002,1500.2781800001994,1870.0614400004986,1202.8611400005,1465.6665200003995,2156.3671200003987,2725.0393200008984,1831.7390200008983,1399.7665200003996,2049.4671200003986,2551.139320000898,1660.8390200008987,1099.7722000003,1563.1725000002994,1903.750080000699,1174.5497800006997,404.7886400002001,656.0889400002,854.9778800004001,628.8778800004001,0.0,0.0,0.0,0.0,251.5006,417.7949200000001,494.8949100000001,382.59461000000005,482.69492000000014,778.9778800001002,952.0721900001002,742.5775700001002,610.5892400000002,1113.1722000001,1412.1551500001997,1182.6605300002002,580.6778800001001,1234.1494800003002,1627.8210700004993,1571.2321300004994,329.1
7728000010004,816.3545600003002,1132.9261600005002,1188.6375200005,97.98296000010004,455.17160000020016,675.7488800003999,828.6545600004001,-29.9113599999,120.97728000019995,215.6659200003,388.5716000003,0.0,0.0,0.0,0.0,65.90000000000003,106.90000000000003,173.90000000000006,170.90000000000003,365.89432000010004,593.1946200001,821.2892400002003,657.1892400002001,1060.8778800002,1500.2781800001994,1870.0614400004986,1202.8611400005,1465.6665200003995,2156.3671200003987,2725.0393200008984,1831.7390200008983,1399.7665200003996,2049.4671200003986,2551.139320000898,1660.8390200008987,1099.7722000003,1563.1725000002994,1903.750080000699,1174.5497800006997,404.7886400002001,656.0889400002,854.9778800004001,628.8778800004001]) (rev' @Double @4 conv2dCPadded128b t128c) testKonst5BigBPadded128bc :: Assertion testKonst5BigBPadded128bc = assertEqualUpToEpsilon' 1e-8 (ringestData [2,2,8,4] [1113.1722000001,1412.1551500001997,1182.6605300001997,801.5659100002002,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,816.3545600003002,1132.9261600005,1188.6375200004998,803.7488800004002,455.17160000019993,675.7488800004002,828.6545600004,577.7659200003001,1500.2781800001994,1870.0614400004986,1202.8611400004997,809.1835600003001,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2049.4671200003986,2551.139320000898,1660.8390200008987,1151.3728000004999,1563.172500000299,190
3.7500800006983,1174.5497800006997,803.9778800004001,1113.1722000001,1412.1551500001997,1182.6605300001997,801.5659100002002,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,1234.1494800003,1627.8210700004993,1571.2321300004992,1047.1431900004,816.3545600003002,1132.9261600005,1188.6375200004998,803.7488800004002,455.17160000019993,675.7488800004002,828.6545600004,577.7659200003001,1500.2781800001994,1870.0614400004986,1202.8611400004997,809.1835600003001,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2156.3671200003987,2725.039320000898,1831.7390200008986,1259.3728000004996,2049.4671200003986,2551.139320000898,1660.8390200008987,1151.3728000004999,1563.172500000299,1903.7500800006983,1174.5497800006997,803.9778800004001]) (rev' @Double @4 conv2dBPadded128b t128c) -- The gradient is the same as above, because one argument is the same -- and convolution is linear. 
testKonstNotBigBPadded128cb :: Assertion testKonstNotBigBPadded128cb = assertEqualUpToEpsilon' 1e-8 (ringestData [4,2,4,4] [606.6659200002001,754.4545600002001,651.5659200002001,373.6659200002,720.1772800002001,917.0659200003003,749.9716000003002,467.1772800001999,1209.2659200003,1451.1488800004995,884.9545600005001,547.1716000002999,1382.7772800002997,1708.860240000599,1078.4602400006002,708.7829600003,316.58864000010004,552.3716000003001,707.9716000003,538.0829600002,328.18894000010005,579.9722000003001,735.8722000002999,565.9835600002,411.9895400001,634.5784800003,706.4781800003,507.58924000020005,773.5898400001,1016.1790800003,753.2787800002999,550.5898400002,606.6659200002001,754.4545600002001,651.5659200002001,373.6659200002,720.1772800002001,917.0659200003003,749.9716000003002,467.1772800001999,1209.2659200003,1451.1488800004995,884.9545600005001,547.1716000002999,1382.7772800002997,1708.860240000599,1078.4602400006002,708.7829600003,316.58864000010004,552.3716000003001,707.9716000003,538.0829600002,328.18894000010005,579.9722000003001,735.8722000002999,565.9835600002,411.9895400001,634.5784800003,706.4781800003,507.58924000020005,773.5898400001,1016.1790800003,753.2787800002999,550.5898400002,606.6659200002001,754.4545600002001,651.5659200002001,373.6659200002,720.1772800002001,917.0659200003003,749.9716000003002,467.1772800001999,1209.2659200003,1451.1488800004995,884.9545600005001,547.1716000002999,1382.7772800002997,1708.860240000599,1078.4602400006002,708.7829600003,316.58864000010004,552.3716000003001,707.9716000003,538.0829600002,328.18894000010005,579.9722000003001,735.8722000002999,565.9835600002,411.9895400001,634.5784800003,706.4781800003,507.58924000020005,773.5898400001,1016.1790800003,753.2787800002999,550.5898400002,606.6659200002001,754.4545600002001,651.5659200002001,373.6659200002,720.1772800002001,917.0659200003003,749.9716000003002,467.1772800001999,1209.2659200003,1451.1488800004995,884.9545600005001,547.1716000002999,1382.7772800002997,1
708.860240000599,1078.4602400006002,708.7829600003,316.58864000010004,552.3716000003001,707.9716000003,538.0829600002,328.18894000010005,579.9722000003001,735.8722000002999,565.9835600002,411.9895400001,634.5784800003,706.4781800003,507.58924000020005,773.5898400001,1016.1790800003,753.2787800002999,550.5898400002]) (rev' @Double @4 conv2dBPadded128c t128b) testKonst5BigCPadded128cb :: Assertion testKonst5BigCPadded128cb = assertEqualUpToEpsilon' 1e-8 (ringestData [4,2,4,4] [720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.86651000029
99,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992]) (rev' @Double @4 conv2dCPadded128c t128b) -- The gradient is the same as above, because one argument is the same -- and convolution is linear. testKonstNotBigCPadded128cb :: Assertion testKonstNotBigCPadded128cb = assertEqualUpToEpsilon' 1e-8 (ringestData [4,2,4,4] 
[720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992,720.1835600001002,1155.4608400002999,1436.2438000003995,1150.0548600004,1222.1722000002,1772.649480000399,2138.4267600005987,1463.1378200005997,1477.3665200002997,2157.0438000004983,2640.8154000007976,1836.3264600007988,1366.1659200002998
,1965.6432000004988,2412.414800000798,1662.026160000799,226.3886400001001,671.8832600001001,1012.8665100002999,1078.3665100003,419.3835600001001,961.7781800001002,1359.4557500003987,1310.9554500003997,568.9778800002001,1233.4728000001987,1712.044990000598,1566.644690000599,428.6778800002001,1007.0784800001993,1445.5506800005985,1358.3503800005992]) (rev' @Double @4 conv2dCPadded128c t128b) -}

-- * Disparity and misc

-- | Disparity cost volume.
--
-- Take two arrays of multi-channel 2D images, where the first contains
-- left views of the scene and the second contains right views.
--
-- For each pair of images, slide the right image over the left image
-- and, for each offset, produce the L1 distance indicating how well
-- corresponding multi-channel image elements in the right image match
-- those in the left.
--
-- Concretely, with both arrays indexed as @[image, channel, row, column]@,
-- output element @[i, d, r, c]@ (for @d@ in @[0, nCount)@) is the sum over
-- channels @ch@ of @abs (arrL[i, ch, r, c] - arrR[i, ch, r, c - iStart - d])@.
-- The output shape is @[nImgs, nCount, nRows, nCols]@. The image, channel,
-- row and column counts are all read from @arrL@; @arrR@ is presumably
-- expected to have the same shape (not checked here). When the right-image
-- column @c - iStart - d@ falls outside the array, the read yields zero, so
-- that position contributes just @abs (arrL[i, ch, r, c])@.
--
-- Described in:
-- Anytime Stereo Image Depth Estimation on Mobile Devices
-- Wang, Lai et al, ICRA 2019
-- https://arxiv.org/abs/1810.11408
-- Section III b).
--
costVolume
  :: forall r target. (ADReady target, GoodScalar r)
  => Int               -- ^ @iStart@: disparity of output slice 0
  -> Int               -- ^ @nCount@: number of disparity slices computed
  -> target (TKR 4 r)  -- ^ @arrL@: left images
  -> target (TKR 4 r)  -- ^ @arrR@: right images
  -> target (TKR 4 r)
costVolume iStart nCount arrL arrR =
  let [nImgs, nChas, nRows, nCols] = rshape arrL
      shO = [nImgs, nCount, nRows, nCols]
  in rbuild shO $ \[iImg, iDisp, iRow, iCol] ->
    let
      -- Both channel-vector shapes are written with ':$:' so that their
      -- rank (1) is stated explicitly instead of being left to inference,
      -- as it would be with a list-literal shape.
      arrVecL = rbuild (nChas :$: ZSR) $ \[iCha] ->
        rindex0 arrL [iImg, iCha, iRow, iCol]
      -- Column of the right image compared with column iCol of the left
      -- one; for some (iDisp, iCol) pairs it lies outside the right image.
      iSrc = iCol - fromIntegral iStart - iDisp
      arrVecR = rbuild (nChas :$: ZSR) $ \[iCha] ->
        rindex0 arrR [iImg, iCha, iRow, iSrc]
    in rsum0 $ rzipWith1 (\xL xR -> abs (xL - xR)) arrVecL arrVecR

test_disparityKonst :: Assertion test_disparityKonst = do let arrL :: ADReady target => target (TKR 4 Double) arrL = rreplicate0N [1, 2, 4, 6] (rscalar (-0.2)) arrR :: ADReady target => target (TKR 4 Double) arrR = rreplicate0N [1, 2, 4, 6] (rscalar 0.3) arrO = costVolume @Double 0 4 arrL arrR arrDL = vjp (\aL -> costVolume 0 4 aL (rfromPrimal arrR)) arrL arrO arrDR = vjp (\aR -> costVolume 0 4 (rfromPrimal arrL) aR) arrR arrO assertEqualUpToEpsilon 1e-7 (rconcrete $ Nested.rfromListPrimLinear [1,4,4,6] [1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,1.0,0.4,1.0,1.0,1.0,1.0,1.0,0.4,1.0,1.0,1.0,1.0,1.0,0.4,1.0,1.0,1.0,1.0,1.0,0.4,1.0,1.0,1.0,1.0,1.0,0.4,0.4,1.0,1.0,1.0,1.0,0.4,0.4,1.0,1.0,1.0,1.0,0.4,0.4,1.0,1.0,1.0,1.0,0.4,0.4,1.0,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0]) arrO assertEqualUpToEpsilon 1e-7 (rconcrete $ Nested.rfromListPrimLinear [1,2,4,6] [-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0,-2.2,-2.8,-3.4,-4.0,-4.0,-4.0]) arrDL assertEqualUpToEpsilon 1e-7 (rconcrete $ Nested.rfromListPrimLinear [1,2,4,6] 
[4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0]) arrDR assertEqualUpToEpsilon' 1e-7 (ringestData [1,2,4,6] [4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0,4.0,4.0,4.0,3.0,2.0,1.0]) (rev' @Double @4 (costVolume 0 4 arrL) arrR) assertEqualUpToEpsilon' 1e-7 (ringestData [1,2,4,6] [-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0,-2.0]) (rev' @Double @4 (\aL -> costVolume 0 2 aL arrR) arrL) assertEqualUpToEpsilon' 1e-7 (ringestData [1,2,4,6] [2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0,2.0,2.0,2.0,2.0,2.0,1.0]) (rev' @Double @4 (costVolume 0 2 arrL) arrR) test_disparityKonst2 :: Assertion test_disparityKonst2 = do let arrL :: (BaseTensor target, GoodScalar r, Differentiable r) => target (TKR 4 r) arrL = ringestData [1, 2, 4, 6] [0.4,0.4,0.4,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0,0.4,0.4,0.4,1.0,1.0,1.0, 1.7041241452319316,1.21999,0.21355339059327375,0.7867666666666666,0.7331698975466578,0.6964466094067263,1.1,1.1041141452319316,0.42000000000000004,0.3536533905932737,0.78,1.253169897546658,1.1,0.50001,0.42000000000000004,0.2801,0.78,1.3,1.1,0.50001,0.42000000000000004,0.2801,0.78,1.3,2.808238290463863,1.21999,-0.5672067811865474,0.7867666666666666,1.986339795093316,0.6964466094067263] arrR :: (BaseTensor target, GoodScalar r, Differentiable r) => target (TKR 4 r) arrR = ringestData [1, 2, 4, 6] [0.2, 0.5, -0.2, 0.0001, 0.44, 0.9, -0.9, 0.00001, -0.22, -0.28, -0.34, -0.40, -0.40,-0.22,-0.28,-0.34, 
0.22360679774997896,0.35355339059327373,0.20412414523193154,0.5, -0.35355339059327373,0.16666666666666666,0.17677669529663687,-0.25, -2.808238290463863,-1.21999,-0.5672067811865474,-0.7867666666666666,-1.986339795093316,-0.6964466094067263,2.808238290463863,1.21999,-0.5672067811865474,0.7867666666666666,0.6964466094067263,0.42000000000000004,0.3536533905932737,0.78,1.253169897546658,0.50001,0.42000000000000004,0.2801,0.78,1.1,0.50001,0.42000000000000004,0.2801,0.78] arrO = rreplicate0N [1, 4, 4, 6] (rscalar (1 :: Double)) res1 = rconcrete $ Nested.rfromListPrimLinear [1,2,4,6] [4.0,2.0,2.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,2.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,4.0,2.0,0.0,0.0,-2.0,0.0,4.0,4.0,2.0,0.0,-4.0,1.0,4.0,4.0,4.0,-4.0,2.0,4.0,2.0] res2 = rconcrete $ Nested.rfromListPrimLinear [1,2,4,6] [-4.0,0.0,-4.0,-3.0,-2.0,-1.0,-4.0,-4.0,-4.0,-3.0,-2.0,-1.0,-4.0,-4.0,-4.0,-3.0,-2.0,-1.0,-4.0,-2.0,-4.0,-3.0,-2.0,-1.0,-4.0,-4.0,-4.0,-3.0,-2.0,-1.0,4.0,4.0,-4.0,1.0,-2.0,-1.0,-2.0,3.0,2.0,-1.0,-2.0,-1.0,-2.0,0.0,-2.0,-3.0,-2.0,1.0] arrDL :: Concrete (TKR 4 Double) arrDL = vjp (\aL -> costVolume 0 4 aL (rfromPrimal arrR)) arrL arrO arrDR :: Concrete (TKR 4 Double) arrDR = vjp (costVolume 0 4 (rfromPrimal arrL)) arrR arrO assertEqualUpToEpsilon 1e-7 res1 arrDL assertEqualUpToEpsilon 1e-7 res2 arrDR assertEqualUpToEpsilon' 1e-7 res1 (rev' @Double @4 (\aL -> costVolume 0 4 aL (rfromPrimal arrR)) arrL) assertEqualUpToEpsilon' 1e-7 res2 (rev' @Double @4 (costVolume 0 4 (rfromPrimal arrL)) arrR) test_disparitySmall :: Assertion test_disparitySmall = do let arrL :: ADReady target => target (TKR 4 Double) arrL = ringestData [1, 2, 3, 2] [0.2, 0.5, -0.2, 0.0001, 0.44, 0.9, -0.9, 0.00001, -0.22, -0.28, -0.34, -0.40] arrR :: ADReady target => target (TKR 4 Double) arrR = ringestData [1, 2, 3, 2] [-0.40,-0.22,-0.28,-0.34, 0.22360679774997896,0.35355339059327373,0.20412414523193154,0.5, -0.35355339059327373,0.16666666666666666,0.17677669529663687,-0.25] 
arrO = costVolume @Double 0 4 arrL arrR arrDL = vjp (\aL -> costVolume 0 4 aL (rfromPrimal arrR)) arrL arrO arrDR = vjp (\aR -> costVolume 0 4 (rfromPrimal arrL) aR) arrR arrO assertEqualUpToEpsilon 1e-7 (rconcrete $ Nested.rfromListPrimLinear [1,4,3,2] [1.7041241452319316,1.21999,0.21355339059327375,0.7867666666666666,0.7331698975466578,0.6964466094067263,1.1,1.1041141452319316,0.42000000000000004,0.3536533905932737,0.78,1.253169897546658,1.1,0.50001,0.42000000000000004,0.2801,0.78,1.3,1.1,0.50001,0.42000000000000004,0.2801,0.78,1.3]) arrO assertEqualUpToEpsilon' 1e-7 (ringestData [1,2,3,2] [-2.0,-1.0,-2.0,-1.0,-2.0,-1.0,2.0,1.0,-2.0,1.0,2.0,1.0]) (rev' @Double @4 (costVolume 0 4 arrL) arrR) assertEqualUpToEpsilon 1e-7 (rconcrete $ Nested.rfromListPrimLinear [1,2,3,2] [5.004124145231932,3.3241241452319317,-1.0464466094067264,1.7006200572599404,3.0731698975466575,4.5496165069533845,-5.004124145231932,-1.3240841452319316,-1.0464466094067264,-0.9933132760733929,-3.0731698975466575,-4.5496165069533845]) arrDL assertEqualUpToEpsilon 1e-7 (rconcrete $ Nested.rfromListPrimLinear [1,2,3,2] [-2.808238290463863,-1.21999,-0.5672067811865474,-0.7867666666666666,-1.986339795093316,-0.6964466094067263,2.808238290463863,1.21999,-0.5672067811865474,0.7867666666666666,1.986339795093316,0.6964466094067263]) arrDR assertEqualUpToEpsilon' 1e-7 (ringestData [1,2,3,2] [-1.0,0.0,-1.0,0.0,-1.0,0.0,1.0,0.0,-1.0,0.0,1.0,0.0]) (rev' @Double @4 (costVolume 1 4 arrL) arrR) assertEqualUpToEpsilon' 1e-7 (ringestData [1,2,3,2] [2.0,2.0,-2.0,2.0,2.0,2.0,-2.0,2.0,-2.0,-2.0,-2.0,-2.0]) (rev' @Double @4 (\aL -> costVolume 2 2 aL arrR) arrL) assertEqualUpToEpsilon' 1e-7 (ringestData [1,2,3,2] [-1.0,0.0,-1.0,0.0,-1.0,0.0,1.0,0.0,-1.0,0.0,1.0,0.0]) (rev' @Double @4 (costVolume 1 2 arrL) arrR)

-- | Builds two @n x (m - 1)@ slices of the @n x m@ matrix @a@ with 'rbuild'
-- and 'rindex0': @a1@ holds columns @0 .. m-2@ and @a2@ holds columns
-- @1 .. m-1@. Their elementwise product is summed to a scalar @s@. The
-- result is the sum, over all indices @[i, j]@ of an @n x m@ grid, of
-- @i * s@, i.e. @s * m * (0 + 1 + ... + (n-1))@.
--
-- NOTE(review): testTomsSlicePP compares the pretty-printed AST of this
-- function verbatim, so even semantically equivalent rewrites of the body
-- change that test's expected strings.
codeTomsSlice :: ADReady target
              => target (TKR 2 Double) -> target (TKR 0 Double)
codeTomsSlice a =
  let (n, m) = case rshape a of
        [n', m'] -> (n', m')
        -- Unreachable for a rank-2 argument; it only makes the
        -- list-literal pattern match total.
        _ -> error "codeTomsSlice"
      -- Columns 0 .. m-2 of a.
      a1 = rbuild @2 @0 [n,m-1] (\[i',j'] -> rindex0 a [i',j'])
      -- Columns 1 .. m-1 of a (a shifted left by one column).
      a2 = rbuild [n,m-1] (\[i',j'] -> rindex0 a [i',j' + 1])
  in rsum0 @2 $ rbuild [n,m] $ \[i, _j] -> rfromIndex0 i * rsum0 (a1 * a2)

testTomsSliceRev :: Assertion testTomsSliceRev = do assertEqualUpToEpsilon 1e-5 (ringestData [32,4] [63686.39999999999,137292.80000000002,121222.4,79558.40000000002,192646.40000000005,223971.0617601984,228556.80000000005,116846.33088019838,63686.39999999999,137292.80000000002,127174.4,79558.40000000002,192646.40000000005,158499.06176019844,202566.40000000005,51374.330880198424,11904.0,5952.0,7936.0,1984.0,116846.33088019838,385292.8000000001,227740.66176039676,192646.40000000005,116846.33088019838,228556.80000000005,174580.73088019836,35910.399999999994,79558.40000000002,127372.79999999997,143244.80000000002,63686.39999999999,105152.0,186683.13088000007,105151.98016,107124.73088000003,-396.79999999999995,26188.8,17459.2,25990.399999999998,-7936.0,73408.0,-1995.2691200000017,57536.0,51584.0,-660672.0,55552.0,3968.0,3968.0,3571.2,3571.2,-396.79999999999995,-396.79999999999995,49203.79519999998,49203.79519999998,49600.59519999998,49600.59519999998,49203.79519999998,49203.79519999998,-396.79999999999995,-396.79999999999995,49203.79519999998,49203.79519999998,49600.59519999998,49600.59519999998,129158.9952,65472.59519999998,79558.40000000002,-5952.0,73198.33087999995,51175.930880000036,51374.33087999995,51187.20000000001,1984.0000000000146,67059.20000000001,79558.40000000002,-5952.0,73198.33087999995,51175.930880000036,51374.33087999995,51187.20000000001,-21823.99999999993,108921.6,16070.400000000005,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,771974.4,218019.0617601984,192646.40000000005,170414.3308801984,385292.8000000001,340828.6617603968,192646.40000000005,57734.399999999994,99596.79999999999,137292.80000000002,63686.39999999999,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,236294.40000000005,271587.0617601984,192646.4000000
0005,45422.33088019842,385292.8000000001,162268.6617603968,192646.40000000005,57734.399999999994,99596.79999999999,137292.80000000002,63686.39999999999,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,369222.4,220003.0617601984,192646.40000000005,104942.33088019838,385292.8000000001,215836.66176039676,192646.40000000005]) (grad (kfromR . codeTomsSlice) (rreshape [32, 4] t128)) testTomsSlice :: Assertion testTomsSlice = do assertEqualUpToEpsilon' 1e-5 (ringestData [32,4] [63686.39999999999,137292.80000000002,121222.4,79558.40000000002,192646.40000000005,223971.0617601984,228556.80000000005,116846.33088019838,63686.39999999999,137292.80000000002,127174.4,79558.40000000002,192646.40000000005,158499.06176019844,202566.40000000005,51374.330880198424,11904.0,5952.0,7936.0,1984.0,116846.33088019838,385292.8000000001,227740.66176039676,192646.40000000005,116846.33088019838,228556.80000000005,174580.73088019836,35910.399999999994,79558.40000000002,127372.79999999997,143244.80000000002,63686.39999999999,105152.0,186683.13088000007,105151.98016,107124.73088000003,-396.79999999999995,26188.8,17459.2,25990.399999999998,-7936.0,73408.0,-1995.2691200000017,57536.0,51584.0,-660672.0,55552.0,3968.0,3968.0,3571.2,3571.2,-396.79999999999995,-396.79999999999995,49203.79519999998,49203.79519999998,49600.59519999998,49600.59519999998,49203.79519999998,49203.79519999998,-396.79999999999995,-396.79999999999995,49203.79519999998,49203.79519999998,49600.59519999998,49600.59519999998,129158.9952,65472.59519999998,79558.40000000002,-5952.0,73198.33087999995,51175.930880000036,51374.33087999995,51187.20000000001,1984.0000000000146,67059.20000000001,79558.40000000002,-5952.0,73198.33087999995,51175.930880000036,51374.33087999995,51187.20000000001,-21823.99999999993,108921.6,16070.400000000005,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,771974.4,218019.0617601984,192646.40000000005,170414.3308801984,38529
2.8000000001,340828.6617603968,192646.40000000005,57734.399999999994,99596.79999999999,137292.80000000002,63686.39999999999,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,236294.40000000005,271587.0617601984,192646.40000000005,45422.33088019842,385292.8000000001,162268.6617603968,192646.40000000005,57734.399999999994,99596.79999999999,137292.80000000002,63686.39999999999,79558.40000000002,127372.79999999997,159116.80000000005,63686.39999999999,107124.73088000003,369222.4,220003.0617601984,192646.40000000005,104942.33088019838,385292.8000000001,215836.66176039676,192646.40000000005]) (rev' codeTomsSlice (rreshape [32, 4] t128)) -- * PP Tests testTomsSlicePP :: Assertion testTomsSlicePP = do resetVarCounter let artifactRev = revArtifactAdapt UseIncomingCotangent codeTomsSlice (FTKR [32, 4] FTKScalar) printArtifactPrimalPretty (simplifyArtifact artifactRev) @?= "\\m1 -> rfromS (sscalar 4.0 * sdot0 (sconcrete (sfromListLinear [32] [0.0,1.0,2.0,3.0,4.0,5.0,6.0,7.0,8.0,9.0,10.0,11.0,12.0,13.0,14.0,15.0,16.0,17.0,18.0,19.0,20.0,21.0,22.0,23.0,24.0,25.0,26.0,27.0,28.0,29.0,30.0,31.0])) (sreplicate @32 (sdot0 (sslice (SNat @0) (SNat @3) (str (sfromR m1))) (sslice (SNat @1) (SNat @3) (str (sfromR m1))))))" printArtifactPrimalPretty artifactRev @?= "\\m1 -> let v8 = sreplicate @32 (ssum @96 (sreshape @[96] (str (sslice (SNat @0) (SNat @3) (str (sfromR m1))) * str (sslice (SNat @1) (SNat @3) (str (sfromR m1)))))) in rfromS (ssum @128 (sreshape @[128] (str (sreplicate @4 (siota (SNat @32) * v8)))))" printArtifactPretty artifactRev @?= "\\dret m1 -> let m10 = sreshape @[32,3] (sreplicate @96 (ssum @32 (siota (SNat @32) * ssum @4 (str (sreshape @[32,4] (sreplicate @128 (sfromR dret))))))) in rfromS (str (sappend (sconcrete (sfromListLinear [0,32] [])) (sappend (str (str (sslice (SNat @1) (SNat @3) (str (sfromR m1))) * m10)) (sconcrete (sreplicate [1,32] 0.0)))) + str (sappend (sconcrete (sreplicate [1,32] 0.0)) (sappend (str (str 
(sslice (SNat @0) (SNat @3) (str (sfromR m1))) * m10)) (sconcrete (sfromListLinear [0,32] [])))))" printArtifactPretty (simplifyArtifact artifactRev) @?= "\\dret m1 -> rfromS (let x10 = sdot0 (sconcrete (sfromListLinear [32] [0.0,4.0,8.0,12.0,16.0,20.0,24.0,28.0,32.0,36.0,40.0,44.0,48.0,52.0,56.0,60.0,64.0,68.0,72.0,76.0,80.0,84.0,88.0,92.0,96.0,100.0,104.0,108.0,112.0,116.0,120.0,124.0])) (sreplicate @32 (sfromR dret)) in str (sappend (sslice (SNat @1) (SNat @3) (str (sfromR m1)) * sreplicate @3 (sreplicate @32 x10)) (sconcrete (sreplicate [1,32] 0.0))) + str (sappend (sconcrete (sreplicate [1,32] 0.0)) (sslice (SNat @0) (SNat @3) (str (sfromR m1)) * sreplicate @3 (sreplicate @32 x10))))" testCNNOPP0c :: Assertion testCNNOPP0c = do resetVarCounter let artifactRev = revArtifactAdapt UseIncomingCotangent conv2dCLaborious (FTKR [2, 2, 2, 2] (FTKScalar @Double)) printArtifactPrimalPretty (simplifyArtifact artifactRev) @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i86, i88] -> [i86 + i88]))) (\\[i40, i41] -> [i40 + i41])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i42, i43, i44, i45] -> [ifH (notB (2 <=. i42 + i44) &&* notB (2 <=. 
i43 + i45)) 0 1, i42, i43, i44, i45])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))))" printArtifactPrimalPretty artifactRev @?= "\\u1 -> let w46 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i38, i39] -> [i38 + i39]))) (\\[i40, i41] -> [i40 + i41])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i42, i43, i44, i45] -> [ifH (notB (2 <=. i42 + i44) &&* notB (2 <=. i43 + i45)) 0 1, i42, i43, i44, i45])))))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w46 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))))" printArtifactPretty artifactRev @?= "\\dret u1 -> let w46 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i38, i39] -> [i38 + i39]))) (\\[i40, i41] -> [i40 + i41])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i42, i43, i44, i45] -> [ifH (notB (2 <=. i42 + i44) &&* notB (2 <=. 
i43 + i45)) 0 1, i42, i43, i44, i45])))))) in rfromS (ssum @1 (str (ssum @2 (str (ssum @2 (str (ssum @2 (w46 * sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))))))))))" printArtifactPretty (simplifyArtifact artifactRev) @?= "\\dret u1 -> rfromS (ssum @2 (ssum @2 (sdot1In (stranspose @[0,1,2,5,3,4] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,4,2] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i112, i114] -> [i112 + i114]))) (\\[i40, i41] -> [i40 + i41])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i96, i97, i98, i103, i104] -> [ifH (notB (2 <=. i96 + i103) &&* notB (2 <=. i97 + i104)) 0 1, i96, i97, i103, i104]))) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))) !$ [0]))))" testCNNOPP0b :: Assertion testCNNOPP0b = do resetVarCounter let artifactRev = revArtifactAdapt UseIncomingCotangent conv2dBLaborious (FTKR [2, 2, 2, 2] (FTKScalar @Double)) printArtifactPrimalPretty (simplifyArtifact artifactRev) @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] [5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose 
@[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i108, i110] -> [i108 + i110]))) (\\[i47, i48] -> [i47 + i48])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i49, i50, i51, i52] -> [ifH (notB (2 <=. i49 + i51) &&* notB (2 <=. i50 + i52)) 0 1, i49, i50, i51, i52]))))))))))" printArtifactPrimalPretty artifactRev @?= "\\u1 -> let w53 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i45, i46] -> [i45 + i46]))) (\\[i47, i48] -> [i47 + i48])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i49, i50, i51, i52] -> [ifH (notB (2 <=. i49 + i51) &&* notB (2 <=. i50 + i52)) 0 1, i49, i50, i51, i52])))))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] [5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * w53))))" printArtifactPretty artifactRev @?= "\\dret u1 -> let w59 = sscatter (stranspose @[1,2,4,5,0,3] (ssum @1 (stranspose @[3,0,1,2] (ssum @2 (str (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] 
[5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))))))))) (\\[i55, i56, i57, i58] -> [ifH (notB (2 <=. i55 + i57) &&* notB (2 <=. i56 + i58)) 0 1, i55, i56, i57, i58]) in rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[1,3,4,0,5,2] (w59 !$ [0])) (\\[i60, i61] -> [i60 + i61]))) (\\[i62, i63] -> [i62 + i63])))" printArtifactPretty (simplifyArtifact artifactRev) @?= "\\dret u1 -> rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[0,2,4,5,1,6,3] (sscatter (sdot1In (sconcrete (sfromListLinear [2,2,2,2,2,2,2] [5.0,13.1,-2.0,582934.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,2.0,9.0,0.0,2.99432,6.0,8.0,0.1,-335.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,1.0,-4.0,-0.2,26.0,5.0,13.1,-2.0,582934.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,2.0,9.0,0.0,2.99432,6.0,8.0,0.1,-335.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,1.0,-4.0,-0.2,26.0,5.0,13.1,-2.0,582934.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,2.0,9.0,0.0,2.99432,6.0,8.0,0.1,-335.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,1.0,-4.0,-0.2,26.0,5.0,13.1,-2.0,582934.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,2.0,9.0,0.0,2.99432,6.0,8.0,0.1,-335.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,1.0,-4.0,-0.2,26.0])) (stranspose @[4,2,3,6,7,0,5,1] (sreshape @[2,2,2,2,1,2,2,2] (stranspose 
@[1,2,3,4,0] (sreplicate @8 (sfromR dret)))) !$ [0])) (\\[i55, i56, i57, i58] -> [ifH (notB (2 <=. i55 + i57) &&* notB (2 <=. i56 + i58)) 0 1, i55, i56, i57, i58])) !$ [0]) (\\[i60, i61] -> [i60 + i61]))) (\\[i62, i63] -> [i62 + i63])))"

-- | Golden test for the reverse-derivative artifact of 'conv2dUnpaddedL'
-- (defined elsewhere) applied to the two components of a product input,
-- both of shape [2,2,2,2] (Double).  The artifact is obtained by
-- interpreting the AST function @f@ ('forwardPassByInterpretation').  The
-- test compares the pretty-printed primal term (simplified and
-- unsimplified) and the simplified gradient term against expected strings.
testCNNOPP1e :: Assertion
testCNNOPP1e = do
  -- NOTE(review): presumably resets fresh-variable numbering so that the
  -- variable names inside the expected strings are reproducible.
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
             (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan (TKR 4 Double)
      f v = conv2dUnpaddedL (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
                       (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i93, i95] -> [i93 + i95]))) (\\[i30, i31] -> [i30 + i31])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i32, i33, i34, i35] -> [ifH (notB (2 <=. i32 + i34) &&* notB (2 <=. i33 + i35)) 0 1, i32, i33, i34, i35])))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w36 = str (sreplicate @2 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i28, i29] -> [i28 + i29]))) (\\[i30, i31] -> [i30 + i31])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i32, i33, i34, i35] -> [ifH (notB (2 <=. i32 + i34) &&* notB (2 <=. i33 + i35)) 0 1, i32, i33, i34, i35])))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w36 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [2,2,2,2] STKScalar)) (let w38 = sreshape @[2,2,2,2,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @2 (ssum @2 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @2 (stranspose @[4,0,1,5,2,3] (sgather (sfromVector (fromList [stranspose @[3,0,5,1,2,4] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i124, i126] -> [i124 + i126]))) (\\[i30, i31] -> [i30 + i31])), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i32, i33, i34, i35] -> [ifH (notB (2 <=. i32 + i34) &&* notB (2 <=. i33 + i35)) 0 1, i32, i33, i34, i35]))))) (stranspose @[2,3,1,4,5,6,0] w38)))) (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[0,2,4,5,1,6,3] (sscatter (sdot1In (stranspose @[2,3,5,6,0,4,1] (sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1)))))))) (stranspose @[2,3,5,6,0,4,1] w38)) (\\[i39, i40, i41, i42] -> [ifH (notB (2 <=. i39 + i41) &&* notB (2 <=. i40 + i42)) 0 1, i39, i40, i41, i42])) !$ [0]) (\\[i44, i45] -> [i44 + i45]))) (\\[i46, i47] -> [i46 + i47]))))"

-- | Golden test: builds the AST of 'maxPool2dUnpadded2' applied to a
-- constant [1,1,2,2] tensor filled with 1.  It compares the pretty-printed
-- term after 'simplifyInlineContract' and the term as built.
testCNNOPP2 :: Assertion
testCNNOPP2 = do
  resetVarCounter
  let t = maxPool2dUnpadded2
            (rconcrete $ Nested.rreplicateScal (1 :$: 1 :$: 2 :$: 2 :$: ZSR) 1)
  printAstPretty (simplifyInlineContract t)
    @?= "rfromS (sreplicate @2 (sreplicate @2 (stranspose @[2,3,1,0] (sappend (sreplicate @1 (sgather (sreplicate @1 (stranspose @[2,0,1] (sgather (sconcrete (sfromListLinear [2,2] [1.0,1.0,1.0,1.0])) (\\[i68, i69] -> [i69 + i68])))) (\\[i44, i35, i8] -> [i8, i8, i8, 2 * i44 + i35]))) (sconcrete (sreplicate [1,2,2,2] 0.0))) !$ [0, 0])))"
  printAstPretty t
    @?= "rfromS (sreplicate @2 (sreplicate @2 (let u36 = let u41 = sgather (sgather (sreplicate @1 (let w32 = sgather (stranspose @[3,2,0,1] (sgather (sconcrete (sfromListLinear [2,3,2] [1.0,1.0,0.0,0.0,0.0,0.0,1.0,1.0,0.0,0.0,0.0,0.0])) (\\[i26, i15] -> [i26 + i15]))) (\\[i22, i16] -> [i22 + i16]) in stranspose @[1,2,3,0] (sappend (sreplicate @1 (stranspose @[2,0,4,1,3] w32 !$ [0])) (sconcrete (sreplicate [2,2,2,2,2] 0.0))))) (\\[i20] -> [i20, i20, i20, 0])) (\\[i44, i39, i35, i8] -> [2 * i39 + i8, i39, 2 * i44 + i35]) in str (sappend (sreplicate @1 (str u41 !$ [0])) (sconcrete (sreplicate [1,2,2,2] 0.0))) in stranspose @[2,3,0,1] u36 !$ [0, 0])))"

-- | Golden test for the reverse-derivative artifact of 'maxPool2dUnpadded2'
-- at input shape [1,1,2,2] (Double), built with 'revArtifactAdapt'.  It
-- compares the primal and gradient terms, with and without
-- 'simplifyArtifact'.
testCNNOPP2b :: Assertion
testCNNOPP2b = do
  resetVarCounter
  let artifactRev =
        revArtifactAdapt UseIncomingCotangent maxPool2dUnpadded2
                         (FTKR [1, 1, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (sreplicate @2 (sreplicate @2 (stranspose @[2,3,1,0] (sappend (sreplicate @1 (sgather (sreplicate @1 (stranspose @[2,0,1] (sgather (sfromR u1 !$ [0, 0]) (\\[i92, i93] -> [i93 + i92])))) (\\[i94, i95, i96] -> [i96, i96, i96, 2 * i94 + i95]))) (sconcrete (sreplicate [1,2,2,2] 0.0))) !$ [0, 0])))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (sreplicate @2 (sreplicate @2 (stranspose @[2,3,1,0] (sappend (sreplicate @1
(sgather (sreplicate @1 (stranspose @[2,0,1] (sgather (sfromR u1 !$ [0, 0]) (\\[i92, i93] -> [i93 + i92])))) (\\[i94, i95, i96] -> [i96, i96, i96, 2 * i94 + i95]))) (sconcrete (sreplicate [1,2,2,2] 0.0))) !$ [0, 0])))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let u98 = stranspose @[3,2,0,1] (soneHot (ssum @2 (ssum @2 (sfromR dret))) [0, 0]) in rfromS (soneHot (sscatter (stranspose @[1,2,0] (ssum @1 (sscatter (ssum @1 (sslice (SNat @0) (SNat @1) u98)) (\\[i99, i100, i101] -> [i101, i101, i101, 2 * i99 + i100])))) (\\[i102, i103] -> [i103 + i102])) [0, 0])"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (sreplicate @1 (sreplicate @1 (sscatter (sscatter (stranspose @[3,2,0,1] (soneHot (ssum @2 (ssum @2 (sfromR dret))) [0, 0]) !$ [0]) (\\[i99, i100, i101] -> [i101, i101, 2 * i99 + i100, i101]) !$ [0]) (\\[i102, i103] -> [i103 + i102]))))"

-- | An 'rbuild' of shape [2,2,2,2].  The element at @[_, _, iBh, iBw]@ is
-- 'rmaximum2' of the [1,1,2,2] slice of @conv2dUnpadded2 a@ that 'slicez2'
-- takes at base index @[iBw, 1, 2 * iBh, 2 * iBw]@.  The first two result
-- indices are ignored.
--
-- NOTE(review): the index arithmetic does not look like a real max-pool
-- window.  It appears to be a deliberately irregular fixture.  The golden
-- strings of testCNNOPP2 / testCNNOPP2b depend on its exact construction,
-- so do not "fix" it without regenerating them.
maxPool2dUnpadded2
  :: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double)
  => target (TKR 4 r) -> target (TKR 4 r)
maxPool2dUnpadded2 a =
  rbuild [2, 2, 2, 2] $ \case
    [_, _, iBh, iBw] ->
      let arrt = slicez2 (conv2dUnpadded2 a) [iBw, 1, 2 * iBh, 2 * iBw]
      in rmaximum2 arrt
    _ -> error "maxPool2dUnpadded2: impossible pattern needlessly required"

-- | An 'rbuild' of shape [3,3,2,2].  The element at @[iImg, _, iBh, iBw]@
-- is element @[0, iBw, iBw, 0]@ of the 'slicez2' slice of @a@ at base
-- index @[iImg, 0, iBh, iBw]@.  The second result index is ignored.
conv2dUnpadded2
  :: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpadded2 a =
  rbuild [3, 3, 2, 2] $ \case
    [iImg, _, iBh, iBw] ->
      let arrAt = slicez2 a [iImg, 0, iBh, iBw]
      in rindex0 arrAt [0, iBw, iBw, 0]
    _ -> error "conv2dUnpadded2: impossible pattern needlessly required"

-- | Builds a [1,1,2,2] tensor whose element at @ixResult@ is
-- @indexz02 d (ixBase + ixResult)@, with the indices summed component-wise
-- ('ixrZipWith' @(+)@).  The output shape is fixed and does not depend on
-- the shape of @d@.
slicez2
  :: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double, n ~ 4)
  => target (TKR n r) -> IxROf target n -> target (TKR n r)
slicez2 d ixBase =
  rbuild [1, 1, 2, 2] $ \ixResult ->
    indexz02 d (ixrZipWith (+) ixBase ixResult)

-- | Guarded indexing.  Yields @d ! ix@ when the first index component is
-- less than 1 (the condition @1 >. ix!!0@), and 0 otherwise.  Only the
-- first component is tested; the remaining components are not checked
-- here.
indexz02
  :: forall target r n.
     (target ~ AstTensor AstMethodLet FullSpan, r ~ Double, n ~ 4)
  => target (TKR n r) -> IxROf target n -> target (TKR 0 r)
indexz02 d ix = ifH (1 >. (toList ix !! 0)) (d ! ix) (rscalar 0)

-- | Despite its name, this does not compute a maximum.  It let-binds @t0@
-- (via 'tlet') and returns the single element at index @[0, 0, 0, 0]@.
rmaximum2
  :: (target ~ AstTensor AstMethodLet FullSpan, r ~ Double)
  => target (TKR 4 r) -> target (TKR 0 r)
rmaximum2 t0 = tlet t0 $ \t -> rindex0 t [0, 0, 0, 0]

-- NOTE: the block comment below disables testCNNOPP3 (see its TODO).
{- TODO: divergent result; bring back when GHC 9.10 dropped:
testCNNOPP3 :: Assertion
testCNNOPP3 = do
  resetVarCounter
  let blackGlyph :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      blackGlyph = AstFromPrimal $ AstReplicate (SNat @2) knownSTK
                     $ AstReplicate (SNat @2) knownSTK
                     $ AstReplicate (SNat @2) knownSTK
                     $ AstReplicate (SNat @2) knownSTK
                         (rconcrete $ Nested.rscalar 7
                            :: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
      afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      afcnn2T = maxPool2dUnpadded33 $ conv2dUnpadded3 blackGlyph
  printAstPretty (simplifyInlineContract afcnn2T)
    @?= "rfromS (sreplicate @2 (sgather (stranspose @[2,1,0,4,3] (sappend (sreplicate @1 (sgather (sconcrete (sfromListLinear [2] [7.0,0.0])) (\\[i18, i22, i17, i15] -> [ifH (notB (2 <=. remH i22 4 + i18) &&* (notB (2 <=. i22 + i17) &&* notB (2 <=. i22 + i15))) 0 1]))) (sconcrete (sfromListLinear [1,2,2,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])))) (\\[i52, i51] -> [remH i51 4, i52, i52, remH i51 4])))"
  printAstPretty afcnn2T
    @?= "rfromS (let w30 = sgather (sfromVector (fromList [stranspose @[4,0,1,2,5,3] (sgather (stranspose @[1,2,4,5,0,3] (sgather (sappend (sreplicate @1 (sgather (sconcrete (sfromListLinear [2] [7.0,0.0])) (\\[i18, i22, i17, i15] -> [ifH (notB (2 <=. remH i22 4 + i18) &&* (notB (2 <=. i22 + i17) &&* notB (2 <=. i22 + i15))) 0 1]))) (sreplicate @1 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sscalar 0.0))))))) (\\[i43, i38, i29, i7] -> [i43 + i7, i43 + i7, remH i38 4 + i29]))) (\\[i37, i33, i28, i8] -> [i37, i28, i33 + i8, remH i37 4 + i28])), sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sscalar 0.0))))))])) (\\[i46, i41, i36, i32, i27, i26, i24, i23] -> [ifH (notB (2 <=.
remH i36 4 + i27) &&* (notB (2 <=. i46 + i26) &&* (notB (2 <=. i41 + i24) &&* notB (2 <=. i32 + i23)))) 0 1, i41, i36, i32, i27, i24, i23]) in stranspose @[4,5,6,7,0,1,2,3] w30 !$ [0, 0, 0, 0])"
-}

-- | Golden test for the reverse-derivative artifact of
-- @maxPool2dUnpadded33 . conv2dUnpadded3@ at input shape [2,2,2,2]
-- (Double).  It compares the pretty-printed primal and gradient terms,
-- with and without 'simplifyArtifact', against expected strings.
testCNNOPP3b :: Assertion
testCNNOPP3b = do
  resetVarCounter
  let artifactRev =
        revArtifactAdapt UseIncomingCotangent
                         (maxPool2dUnpadded33 . conv2dUnpadded3)
                         (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (stranspose @[2,1,0] (sgather (sgather (sappend (sreplicate @1 (stranspose @[0,4,5,1,2,3] (sgather (sfromVector (fromList [stranspose @[5,2,3,4,0,1] (sreplicate @2 (sreplicate @2 (stranspose @[2,3,1,0] (sgather (stranspose @[3,0,2,1] (sgather (stranspose @[3,0,1,2] (sgather (stranspose @[3,0,2,1] (sfromR u1) !$ [1]) (\\[i191, i193] -> [remH i191 4 + i193]))) (\\[i195, i197] -> [i195 + i197, i195]))) (\\[i128, i129] -> [i128 + i129, i128]))))), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i130, i131, i132, i133] -> [ifH (notB (2 <=. remH i130 4 + i131) &&* (notB (2 <=. i130 + i132) &&* notB (2 <=. i130 + i133))) 0 1, i130, i131, i132, i133])))) (sconcrete (sreplicate [1,2,2,2,2,2,2] 0.0))) (\\[i134, i135, i136, i137] -> [i135, i134, i135, i137, i135, i137, i134])) (\\[i138] -> [remH i138 4])))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (stranspose @[2,1,0] (sgather (sgather (sappend (sreplicate @1 (stranspose @[0,4,5,1,2,3] (sgather (sfromVector (fromList [stranspose @[5,2,3,4,0,1] (sreplicate @2 (sreplicate @2 (stranspose @[2,3,1,0] (sgather (stranspose @[3,0,2,1] (sgather (stranspose @[3,0,1,2] (sgather (stranspose @[3,0,2,1] (sfromR u1) !$ [1]) (\\[i124, i125] -> [remH i124 4 + i125]))) (\\[i126, i127] -> [i126 + i127, i126]))) (\\[i128, i129] -> [i128 + i129, i128]))))), sconcrete (sreplicate [2,2,2,2,2,2] 0.0)])) (\\[i130, i131, i132, i133] -> [ifH (notB (2 <=. remH i130 4 + i131) &&* (notB (2 <=. i130 + i132) &&* notB (2 <=. i130 + i133))) 0 1, i130, i131, i132, i133])))) (sconcrete (sreplicate [1,2,2,2,2,2,2] 0.0))) (\\[i134, i135, i136, i137] -> [i135, i134, i135, i137, i135, i137, i134])) (\\[i138] -> [remH i138 4])))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w145 = sscatter (sscatter (stranspose @[2,1,0] (sfromR dret)) (\\[i140] -> [remH i140 4])) (\\[i141, i142, i143, i144] -> [i142, i141, i142, i144, i142, i144, i141]) ; w150 = sscatter (stranspose @[0,3,4,5,1,2] (ssum @1 (sslice (SNat @0) (SNat @1) w145))) (\\[i146, i147, i148, i149] -> [ifH (notB (2 <=. remH i146 4 + i147) &&* (notB (2 <=. i146 + i148) &&* notB (2 <=. i146 + i149))) 0 1, i146, i147, i148, i149]) in rfromS (stranspose @[1,3,2,0] (soneHot (sscatter (stranspose @[1,2,3,0] (sscatter (stranspose @[1,3,2,0] (sscatter (stranspose @[3,2,0,1] (ssum @2 (ssum @2 (stranspose @[4,5,1,2,3,0] (w150 !$ [0]))))) (\\[i151, i152] -> [i151 + i152, i151]))) (\\[i153, i154] -> [i153 + i154, i153]))) (\\[i155, i156] -> [remH i155 4 + i156])) [1]))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (stranspose @[1,3,2,0] (soneHot (sscatter (stranspose @[1,2,3,0] (sscatter (stranspose @[1,3,2,0] (sscatter (ssum @2 (ssum @2 (stranspose @[0,5,6,1,4,2,3] (sscatter (sscatter (sscatter (stranspose @[2,1,0] (sfromR dret)) (\\[i140] -> [remH i140 4])) (\\[i141, i142, i143, i144] -> [i142, i141, i142, i144, i141, i142, i144]) !$ [0]) (\\[i146, i147, i148, i149] -> [ifH (notB (2 <=. remH i146 4 + i147) &&* (notB (2 <=. i146 + i148) &&* notB (2 <=. i146 + i149))) 0 1, i146, i147, i148, i149])) !$ [0]))) (\\[i151, i152] -> [i151 + i152, i151]))) (\\[i153, i154] -> [i153 + i154, i153]))) (\\[i155, i156] -> [remH i155 4 + i156])) [1]))"

-- | An 'rbuild' of shape [2,2,2,2].  The element at @[aa, bb, iBh, iBw]@
-- is 'rmaximum3' of a [2,2,2,2] 'slicez3' slice of @arr@ at base index
-- @[iBh `quotH` 4, aa, bb, iBw]@.
--
-- NOTE(review): 'slicez3' ignores its result index, so every element of
-- that slice is the same value.  The golden strings of the tests that use
-- this function depend on this exact construction.
maxPool2dUnpadded3
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
maxPool2dUnpadded3 arr =
  rbuild [2, 2, 2, 2] $ \case
    [aa, bb, iBh, iBw] ->
      let arrt = slicez3 [2, 2, 2, 2] arr [iBh `quotH` 4, aa, bb, iBw]
      in rmaximum3 arrt
    _ -> error "maxPool2dUnpadded3: impossible pattern needlessly required"

-- | Like 'maxPool2dUnpadded3', but slices with 'slicez33', which does use
-- the result index.  The base index is @[iBh `remH` 4, aa, bb, iBw]@.
maxPool2dUnpadded33
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
maxPool2dUnpadded33 arr =
  rbuild [2, 2, 2, 2] $ \case
    [aa, bb, iBh, iBw] ->
      let arrt = slicez33 [2, 2, 2, 2] arr [iBh `remH` 4, aa, bb, iBw]
      in rmaximum3 arrt
    _ -> error "maxPool2dUnpadded33: impossible pattern needlessly required"

-- | An 'rbuild' of shape [2,2,2,2].  The element at @[iImg, _, iBh, iBw]@
-- is element @[iBh, iBw, iImg, iBh]@ of the 'slicez33' slice of @arrA@ at
-- base index @[iImg `remH` 4, iImg, iImg, 1]@.  The second result index is
-- ignored.
conv2dUnpadded3
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpadded3 arrA =
  let shB = [2, 2, 2, 2]
  in rbuild shB $ \case
       [iImg, _, iBh, iBw] ->
         let arrAt = slicez33 shB arrA [iImg `remH` 4, iImg, iImg, 1]
         in rindex0 arrAt [iBh, iBw, iImg, iBh]
       _ -> error "conv2dUnpadded3: impossible pattern needlessly required"

-- | Builds a tensor of shape @shOut@, every element of which is
-- @indexz03 d (ixBase + ixBase)@.
--
-- NOTE(review): unlike 'slicez33', the build function ignores its index
-- argument (it is matched as @_@) and adds @ixBase@ to itself, so the
-- result is constant across @shOut@.  This looks deliberate, because the
-- golden strings of the tests using it depend on it.  Confirm before
-- changing.
slicez3
  :: (ADReady target, GoodScalar r, KnownNat n)
  => IShR n -> target (TKR n r) -> IxROf target n -> target (TKR n r)
slicez3 shOut d ixBase =
  rbuild shOut $ \_ -> indexz03 d (ixrZipWith (+) ixBase ixBase)

-- | Builds a tensor of shape @shOut@ whose element at @ixResult@ is
-- @indexz03 d (ixBase + ixResult)@, with the indices summed component-wise.
slicez33
  :: (ADReady target, GoodScalar r, KnownNat n)
  => IShR n -> target (TKR n r) -> IxROf target n -> target (TKR n r)
slicez33 shOut d ixBase =
  rbuild shOut $ \ixResult -> indexz03 d (ixrZipWith (+) ixBase ixResult)

-- | Guarded indexing.  Yields @d ! ix@ when 'within0' accepts @ix@ for the
-- shape of @d@, and 0 otherwise.  'within0' is presumably an in-bounds
-- test; confirm against its definition.
indexz03
  :: forall target r n. (ADReady target, GoodScalar r, KnownNat n)
  => target (TKR n r) -> IxROf target n -> target (TKR 0 r)
indexz03 d ix =
  ifH (within0 @target (rshape @target d) ix) (d ! ix) (rscalar 0)

-- | Despite its name, this does not compute a maximum.  It let-binds @t0@
-- (via 'tlet') and returns the element at index @[0, 0, 0, 0]@.
--
-- NOTE(review): the index literal always has four components, although the
-- rank @n@ is polymorphic.  All callers visible here pass rank-4 tensors.
rmaximum3
  :: (BaseTensor target, LetTensor target, KnownNat n, GoodScalar r)
  => target (TKR n r) -> target (TKR 0 r)
rmaximum3 t0 = tlet t0 $ \t -> rindex0 t [0, 0, 0, 0]

-- | Golden test: builds 'maxPool2dUnpadded4' applied to a constant
-- [3,3,3,3] tensor of 7s (four nested 'AstReplicate's of a scalar).  It
-- compares the pretty-printed term after 'simplifyInlineContract' and the
-- term as built.
testCNNOPP4 :: Assertion
testCNNOPP4 = do
  resetVarCounter
  let blackGlyph :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      blackGlyph = AstFromPrimal $ AstReplicate (SNat @3) knownSTK
                     $ AstReplicate (SNat @3) knownSTK
                     $ AstReplicate (SNat @3) knownSTK
                     $ AstReplicate (SNat @3) knownSTK
                         (rconcrete $ Nested.rscalar 7
                            :: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
      afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      afcnn2T = maxPool2dUnpadded4 blackGlyph
  printAstPretty (simplifyInlineContract afcnn2T)
    @?= "rfromS (str (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sreplicate @1 (stranspose @[2,1,0] (sgather (stranspose @[3,4,5,2,6,1,0] (sgather (stranspose @[4,1,3,0,2] (sgather (stranspose @[3,0,4,1,2] (sgather (sconcrete (sreplicate [2,3,3,3] 7.0)) (\\[i61, i64] -> [i61 + i64]))) (\\[i67, i69] -> [3 + (negate i69 + i67), i69]))) (\\[i71, i73, i76] -> [i71 * i73 + i76])) !$ [1, 0, 0, 0]) (\\[i84] -> [2 * i84]))))))"
    -- TODO: was once "rfromS (sconcrete (sfromListLinear [2,2,2,2] [0.0,0.0,0.0,0.0,7.0,7.0,7.0,7.0,0.0,0.0,0.0,0.0,7.0,7.0,7.0,7.0]))"
  printAstPretty afcnn2T
    @?= "rfromS (let w19 = sgather (sfromVector (fromList [stranspose @[3,0,5,6,1,2,4] (sgather (stranspose @[6,0,3,1,4,5,2] (sgather (stranspose @[3,0,2,1] (sgather (stranspose @[0,2,1] (sgather (sconcrete (sreplicate [2,3,3,3] 7.0)) (\\[i32, i5] -> [i32 + i5]))) (\\[i31, i6] -> [i31, 3 + (negate i31 + i6)]))) (\\[i36, i26, i7] -> [i36 * i26 + i7]))) (\\[i22, i8] -> [2 * i22 + i8])), sconcrete (sreplicate [2,2,2,2,2,2,2,2] 0.0)])) (\\[i28, i21, i15, i12, i9] -> [ifH (notB (2 <=. i28 + i15) &&* (notB (0 <=. negate i28 + i12) &&* notB (3 <=.
2 * i21 + i9))) 0 1, i28, i21, i15, i12, i9]) in stranspose @[2,3,4,7,5,0,6,1] w19 !$ [0, 0, 0, 0])"

-- | Golden test for the reverse-derivative artifact of 'maxPool2dUnpadded4'
-- at input shape [3,3,3,3] (Double).  It compares the primal and gradient
-- terms, with and without 'simplifyArtifact'.
testCNNOPP4b :: Assertion
testCNNOPP4b = do
  resetVarCounter
  let artifactRev =
        revArtifactAdapt UseIncomingCotangent maxPool2dUnpadded4
                         (FTKR [3, 3, 3, 3] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (str (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sreplicate @1 (stranspose @[2,1,0] (sgather (stranspose @[2,3,0,1] (sgather (stranspose @[1,0,3,2] (sreplicate @2 (stranspose @[2,3,0,1] (sreplicate @2 (stranspose @[2,1,0] (sreplicate @2 (sfromR u1 !$ [2, 2]))))))) (\\[i194, i195] -> [i195 * i194, i194, i195]))) (\\[i125] -> [i125, 2 * i125]))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (str (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sreplicate @1 (stranspose @[2,1,0] (sgather (stranspose @[3,5,6,2,4,7,1,0] (sgather (stranspose @[3,4,7,1,5,6,0,2] (sgather (stranspose @[6,0,7,4,3,2,1,5] (sgather (sslice (SNat @1) (SNat @2) (stranspose @[0,2,1] (sfromR u1))) (\\[i115, i116, i117, i118, i119] -> [i115 + i116]))) (\\[i120, i121] -> [3 + (negate i121 + i120), i121]))) (\\[i122, i123, i124] -> [i122, i123, i122 * i123 + i124])) !$ [1, 0, 0, 0]) (\\[i125] -> [i125, 2 * i125]))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> rfromS (stranspose @[0,2,1] (sappend (sconcrete (sreplicate [1,3,3,3] 0.0)) (sappend (sscatter (stranspose @[1,6,5,4,3,7,0,2] (sscatter (stranspose @[6,3,7,0,1,4,5,2] (sscatter (stranspose @[7,6,3,0,4,1,2,5] (soneHot (sscatter (stranspose @[2,1,0] (ssum @1 (sslice (SNat @1) (SNat @1) (str (sfromR dret))))) (\\[i127] -> [i127, 2 * i127])) [1, 0, 0, 0])) (\\[i128, i129, i130] -> [i128, i129, i128 * i129 + i130]))) (\\[i131, i132] -> [3 + (negate i132 + i131), i132]))) (\\[i133, i134, i135, i136, i137] -> [i133 + i134])) (sconcrete (sfromListLinear [0,3,3,3] [])))))"
    -- TODO: was once "\\dret u1 -> rfromS (soneHot (sscatter (ssum @1 (sslice (SNat @1) (SNat @1) (str (sfromR dret)))) (\\[i86, i87, i88] -> [i86 * i87, 2 * i88])) [2, 2])"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (sappend (sconcrete (sreplicate [1,3,3,3] 0.0)) (stranspose @[0,2,1] (sscatter (stranspose @[1,6,5,4,3,7,0,2] (sscatter (stranspose @[6,3,7,0,1,4,5,2] (sscatter (stranspose @[7,6,3,0,4,1,2,5] (soneHot (sscatter (stranspose @[1,3,2,0] (sfromR dret) !$ [1]) (\\[i127] -> [i127, 2 * i127])) [1, 0, 0, 0])) (\\[i128, i129, i130] -> [i128, i129, i128 * i129 + i130]))) (\\[i131, i132] -> [3 + (negate i132 + i131), i132]))) (\\[i133, i134, i135, i136, i137] -> [i133 + i134]))))"
    -- TODO: was once "\\dret u1 -> rfromS (soneHot (sscatter (str (sfromR dret) !$ [1]) (\\[i86, i87, i88] -> [i86 * i87, 2 * i88])) [2, 2])"

-- | Golden test: builds 'conv2dUnpadded4' applied to a constant [6,6,6,6]
-- tensor of 7s.  Both the simplified term and the term as built are
-- expected to print as a constant [1,1,2,2] tensor of 7.0.
testCNNOPP5 :: Assertion
testCNNOPP5 = do
  resetVarCounter
  let blackGlyph :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      blackGlyph = AstFromPrimal $ AstReplicate (SNat @6) knownSTK
                     $ AstReplicate (SNat @6) knownSTK
                     $ AstReplicate (SNat @6) knownSTK
                     $ AstReplicate (SNat @6) knownSTK
                         (rconcrete $ Nested.rscalar 7
                            :: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
      afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      afcnn2T = conv2dUnpadded4 blackGlyph
  printAstPretty (simplifyInlineContract afcnn2T)
    @?= "rfromS (sconcrete (sreplicate [1,1,2,2] 7.0))"
  printAstPretty afcnn2T
    @?= "rfromS (sconcrete (sreplicate [1,1,2,2] 7.0))"

-- | Golden test for the reverse-derivative artifact of 'conv2dUnpadded4'
-- at input shape [5,5,5,5] (Double).
testCNNOPP5b :: Assertion
testCNNOPP5b = do
  resetVarCounter
  let artifactRev =
        revArtifactAdapt UseIncomingCotangent conv2dUnpadded4
                         (FTKR [5, 5, 5, 5] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (sreplicate @1 (sreplicate @1 (str (sslice (SNat @0) (SNat @2) (str (sslice (SNat @0) (SNat @2) (sfromR u1 !$ [0, 0])))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (sreplicate @1 (sreplicate @1 (str (sslice (SNat @0) (SNat @2) (str (sslice (SNat @0) (SNat @2) (sfromR u1 !$ [0, 0])))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 ->
rfromS (soneHot (sappend (sconcrete (sfromListLinear [0,5] [])) (sappend (str (sappend (sconcrete (sfromListLinear [0,2] [])) (sappend (str (ssum @1 (ssum @1 (sfromR dret)))) (sconcrete (sreplicate [3,2] 0.0))))) (sconcrete (sreplicate [3,5] 0.0)))) [0, 0])"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (soneHot (sappend (str (sappend (stranspose @[0,1,3,2] (sfromR dret) !$ [0, 0]) (sconcrete (sreplicate [3,2] 0.0)))) (sconcrete (sreplicate [3,5] 0.0))) [0, 0])"

-- | An 'rbuild' of shape [2,2,2,2].  The element at @[aa, bb, iBh, iBw]@
-- is 'rmaximum3' of the [2,2,2,2] 'slicez4' slice of @arr@ at base index
-- @[bb + 1, 3 - bb, aa * iBh, 2 * iBw]@.
--
-- NOTE(review): the index arithmetic appears to be a deliberately
-- irregular fixture.  testCNNOPP4 / testCNNOPP4b pin its exact output.
maxPool2dUnpadded4
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
maxPool2dUnpadded4 arr =
  rbuild [2, 2, 2, 2] $ \case
    [aa, bb, iBh, iBw] ->
      let arrt = slicez4 [2, 2, 2, 2] arr [bb + 1, 3 - bb, aa * iBh, 2 * iBw]
      in rmaximum3 arrt
    _ -> error "maxPool2dUnpadded4: impossible pattern needlessly required"

-- | An 'rbuild' of shape [1,1,2,2].  The element at @[iImg, _, iBh, iBw]@
-- is element @[0, 0, 0, 0]@ of the 'slicez4' slice of @arrA@ at base index
-- @[iImg, 0, iBh, iBw]@.  That is the guarded element of @arrA@ at
-- @[iImg, 0, iBh, iBw]@ (see 'indexz03').
conv2dUnpadded4
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpadded4 arrA =
  let shB = [1, 1, 2, 2]
  in rbuild shB $ \case
       [iImg, _, iBh, iBw] ->
         let arrAt = slicez4 shB arrA [iImg, 0, iBh, iBw]
         in rindex0 arrAt [0, 0, 0, 0]
       _ -> error "conv2dUnpadded4: impossible pattern needlessly required"

-- | Builds a tensor of shape @shOut@ whose element at @ixResult@ is
-- @indexz03 d (ixBase + ixResult)@, with the indices summed component-wise.
-- Its definition is identical to 'slicez33'.
slicez4
  :: (ADReady target, GoodScalar r, KnownNat n)
  => IShR n -> target (TKR n r) -> IxROf target n -> target (TKR n r)
slicez4 shOut d ixBase =
  rbuild shOut $ \ixResult -> indexz03 d (ixrZipWith (+) ixBase ixResult)

-- | Golden test: builds @maxPool2dUnpadded3 . conv2dUnpadded3z@ applied to
-- a constant [2,2,2,2] tensor of 7s.  The simplified term is expected to be
-- a concrete [2,2,2,2] tensor.
testCNNOPP6 :: Assertion
testCNNOPP6 = do
  resetVarCounter
  let blackGlyph :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      blackGlyph = AstFromPrimal $ AstReplicate (SNat @2) knownSTK
                     $ AstReplicate (SNat @2) knownSTK
                     $ AstReplicate (SNat @2) knownSTK
                     $ AstReplicate (SNat @2) knownSTK
                         (rconcrete $ Nested.rscalar 7
                            :: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
      afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      afcnn2T = maxPool2dUnpadded3 $ conv2dUnpadded3z blackGlyph
  printAstPretty (simplifyInlineContract afcnn2T)
    @?= "rfromS (sconcrete (sfromListLinear [2,2,2,2] [7.0,0.0,7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]))"
  printAstPretty afcnn2T
    @?= "rfromS (stranspose @[1,2,0] (sreplicate @2 (let t30 = sgather (stranspose @[2,1,0] (sgather (str (sgather (sreplicate @2 (str (sreplicate @2 (let m21 = sgather (str (sgather (sconcrete (sreplicate [2,2,2,2] 7.0)) (\\[i9] -> [2 * i9, 2 * i9, 2 * i9]))) (\\[i12] -> [2 * i12]) in sappend (sreplicate @1 (sappend (sreplicate @1 (m21 !$ [0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))))) (\\[i1] -> [2 * i1, 0]))) (\\[i2] -> [2 * i2]))) (\\[i4] -> [2 * i4]) in sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (t30 !$ [0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))))"

-- | Golden test for the reverse-derivative artifact of
-- @maxPool2dUnpadded3 . conv2dUnpadded3z@ at input shape [2,2,2,2]
-- (Double).
testCNNOPP6b :: Assertion
testCNNOPP6b = do
  resetVarCounter
  let artifactRev =
        revArtifactAdapt UseIncomingCotangent
                         (maxPool2dUnpadded3 . conv2dUnpadded3z)
                         (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (stranspose @[1,2,0] (sreplicate @2 (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sfromR u1 !$ [0, 0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (stranspose @[1,2,0] (sreplicate @2 (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sfromR u1 !$ [0, 0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let t34 = ssum @2 (stranspose @[2,0,1] (sfromR dret)) in rfromS (soneHot (ssum @1 (sslice (SNat @0) (SNat @1) (ssum @1 (sslice (SNat @0) (SNat @1) (ssum @1 (sslice (SNat @0) (SNat @1) t34)))))) [0, 0, 0, 0])"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (soneHot (ssum0 (stranspose @[0,1,3,2] (sfromR dret) !$ [0, 0, 0])) [0, 0, 0, 0])"

-- | An 'rbuild' of shape [2,2,2,2].  The element at @[iImg, _, iBh, iBw]@
-- is element @[iBh, iBw, iImg, iBh]@ of the 'slicez3' slice of @arrA@ at
-- base index @[iImg, iImg, iImg, iBw]@.  'slicez3' ignores its result
-- index, so the chosen element does not affect the value.
conv2dUnpadded3z
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpadded3z arrA =
  let shB = [2, 2, 2, 2]
  in rbuild shB $ \case
       [iImg, _, iBh, iBw] ->
         let arrAt = slicez3 shB arrA [iImg, iImg, iImg, iBw]
         in rindex0 arrAt [iBh, iBw, iImg, iBh]
       _ -> error "conv2dUnpadded3z: impossible pattern needlessly required"

-- | Golden test: builds @maxPool2dUnpadded3y . conv2dUnpadded3y@ applied to
-- a constant [2,2,2,2] tensor of 7s.  The simplified term is expected to be
-- a concrete [2,2,2,2] tensor whose only non-zero element is the first.
testCNNOPP7 :: Assertion
testCNNOPP7 = do
  resetVarCounter
  let blackGlyph :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      blackGlyph = AstFromPrimal $ AstReplicate (SNat @2) knownSTK
                     $ AstReplicate (SNat @2) knownSTK
                     $ AstReplicate (SNat @2) knownSTK
                     $ AstReplicate (SNat @2) knownSTK
                         (rconcrete $ Nested.rscalar 7
                            :: AstTensor AstMethodLet PrimalSpan (TKR 0 Double))
      afcnn2T :: AstTensor AstMethodLet FullSpan (TKR 4 Double)
      afcnn2T = maxPool2dUnpadded3y $ conv2dUnpadded3y blackGlyph
  printAstPretty (simplifyInlineContract afcnn2T)
    @?= "rfromS (sconcrete (sfromListLinear [2,2,2,2] [7.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0]))"
  printAstPretty afcnn2T
    @?= "rfromS (let u27 = sgather (stranspose @[3,2,0,1] (sgather (stranspose @[1,2,0] (sgather (sreplicate @2 (stranspose @[1,2,0] (sreplicate @2 (let m21 = sgather (str (sgather (sconcrete (sreplicate [2,2,2,2] 7.0)) (\\[i9] -> [2 * i9, 2 * i9, 2 * i9]))) (\\[i11] -> [2 * i11]) in sappend (sreplicate @1 (sappend (sreplicate @1 (m21 !$ [0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))))) (\\[i1] -> [2 * i1]))) (\\[i31, i3] -> [2 * i3, 2 * i31]))) (\\[i4] -> [2 * i4]) in stranspose @[1,2,0] (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (u27 !$ [0, 0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2] 0.0))))"

-- | Golden test for the reverse-derivative artifact of
-- @maxPool2dUnpadded3y . conv2dUnpadded3y@ at input shape [2,2,2,2]
-- (Double).
testCNNOPP7b :: Assertion
testCNNOPP7b = do
  resetVarCounter
  let artifactRev = revArtifactAdapt
UseIncomingCotangent (maxPool2dUnpadded3y . conv2dUnpadded3y) (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (stranspose @[1,2,0] (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sfromR u1 !$ [0, 0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2] 0.0))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> rfromS (stranspose @[1,2,0] (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sappend (sreplicate @1 (sfromR u1 !$ [0, 0, 0, 0])) (sconcrete (sfromListLinear [1] [0.0])))) (sconcrete (sreplicate [1,2] 0.0)))) (sconcrete (sreplicate [1,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2] 0.0))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> rfromS (soneHot (ssum @1 (sslice (SNat @0) (SNat @1) (ssum @1 (sslice (SNat @0) (SNat @1) (ssum @1 (sslice (SNat @0) (SNat @1) (ssum @1 (sslice (SNat @0) (SNat @1) (stranspose @[2,0,1] (sfromR dret)))))))))) [0, 0, 0, 0])"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (soneHot (sfromR dret !$ [0, 0, 0, 0]) [0, 0, 0, 0])"

-- | An 'rbuild' of shape [2,2,2,2].  The element at @[aa, bb, iBh, iBw]@
-- is 'rmaximum3' of a [2,2,2,2] 'slicez3' slice of @arr@ at base index
-- @[iBh, aa, bb, iBw]@.  'slicez3' ignores its result index, so that slice
-- is constant.
maxPool2dUnpadded3y
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
maxPool2dUnpadded3y arr =
  rbuild [2, 2, 2, 2] $ \case
    [aa, bb, iBh, iBw] ->
      let arrt = slicez3 [2, 2, 2, 2] arr [iBh, aa, bb, iBw]
      in rmaximum3 arrt
    _ -> error "maxPool2dUnpadded3y: impossible pattern needlessly required"

-- | An 'rbuild' of shape [2,2,2,2].  The element at @[iImg, _, iBh, iBw]@
-- is element @[iBh, iBw, iImg, iBh]@ of the 'slicez3' slice of @arrA@ at
-- base index @[iImg, iImg, iImg, iBh]@.  The second result index is
-- ignored.
conv2dUnpadded3y
  :: (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r)
conv2dUnpadded3y arrA =
  let shB = [2, 2, 2, 2]
  in rbuild shB $ \case
       [iImg, _, iBh, iBw] ->
         let arrAt = slicez3 shB arrA [iImg, iImg, iImg, iBh]
         in rindex0 arrAt [iBh, iBw, iImg, iBh]
       _ -> error "conv2dUnpadded3y: impossible pattern needlessly required"

-- TODO: OOMs
-- | Golden test for the reverse-derivative artifact of 'conv2dCPadded'
-- (defined elsewhere) at input shape [2,2,2,2] (Double).  The leading
-- underscore silences GHC's unused-binding warning.  That suggests the
-- test is not registered in 'testTrees' (see the TODO above), but the full
-- list is not visible here, so confirm.
_testPaddedCNNOPP0c :: Assertion
_testPaddedCNNOPP0c = do
  resetVarCounter
  let artifactRev =
        revArtifactAdapt UseIncomingCotangent conv2dCPadded
                         (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[5,0,1,4,2,3] (sgather (sconcrete (sfromListLinear [4,4,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])) (\\[i37, i38, i39, i40] -> [i37 + i39, i38 + i40])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w41 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[5,0,1,4,2,3] (sgather (sconcrete (sfromListLinear [4,4,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])) (\\[i37, i38, i39, i40] -> [i37 + i39, i38 + i40])))))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w41 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w41 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[5,0,1,4,2,3] (sgather (sconcrete (sfromListLinear [4,4,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,13.1,-2.0,582934.0,2.0,9.0,0.0,2.99432,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6.0,8.0,0.1,-335.0,1.0,-4.0,-0.2,26.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])) (\\[i37, i38, i39, i40] -> [i37 + i39, i38 + i40])))))) in rfromS (ssum @1 (str (ssum @2 (str (ssum @2 (str (ssum @2 (w41 * sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> rfromS (ssum @2 (ssum @2 (sdot1In (stranspose @[0,1,2,6,3,4,5] (sgather (sconcrete (sfromListLinear [4,4,2,2] [0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,5.0,-2.0,13.1,582934.0,2.0,0.0,9.0,2.99432,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,6.0,0.1,8.0,-335.0,1.0,-0.2,-4.0,26.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0,0.0])) (\\[i66, i67, i68, i73, i74] -> [i66 + i73, i67 + i74]))) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))) !$ [0]))))"

-- TODO: OOMs
-- | Golden test for the reverse-derivative artifact of 'conv2dBPadded'
-- (defined elsewhere) at input shape [2,2,2,2] (Double).
_testPaddedCNNOPP0b :: Assertion
_testPaddedCNNOPP0b = do
  resetVarCounter
  let artifactRev =
        revArtifactAdapt UseIncomingCotangent conv2dBPadded
                         (FTKR [2, 2, 2, 2] (FTKScalar @Double))
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] [5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (sgather (sappend (sconcrete (sreplicate
[1,2,2,2,2,2,2,2,2,4] 0.0)) (stranspose @[9,2,5,1,3,4,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0)) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,0,7,8,6] (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR u1)) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))))))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0))))))) (\\[i50, i51, i52, i53, i54, i55] -> [i51 + i54, i51, i54, i50, i52, i53, i55, i50, i53, i52 + i55])))))))))" printArtifactPrimalPretty artifactRev @?= "\\u1 -> let w56 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (sgather (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (stranspose @[9,2,5,1,3,4,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0)) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,0,7,8,6] (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR u1)) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))))))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0))))))) (\\[i50, i51, i52, i53, i54, i55] -> [i51 + i54, i51, i54, i50, i52, i53, i55, i50, i53, i52 + i55]))))) in rfromS (ssum @8 
(stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] [5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * w56))))" printArtifactPretty artifactRev @?= "\\dret u1 -> let w64 = sscatter (ssum @1 (stranspose @[3,0,1,2] (ssum @2 (str (sconcrete (sfromListLinear [2,2,2,2,1,2,2,2] [5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,5.0,2.0,6.0,1.0,-2.0,0.0,0.1,-0.2,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0,13.1,9.0,8.0,-4.0,582934.0,2.99432,-335.0,26.0]) * sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret)))))))) (\\[i58, i59, i60, i61, i62, i63] -> [i59 + i62, i59, i62, i58, i60, i61, i63, i58, i61, i60 + i63]) ; u65 = ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (stranspose @[6,1,2,3,4,5,9,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sconcrete (sreplicate 
[1,2,2,2,2,2,2,2,2,2] 0.0)) (sappend (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @3) (stranspose @[9,3,1,4,5,2,6,7,8,0] (sslice (SNat @1) (SNat @3) w64))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)))))))))) in rfromS (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) (sslice (SNat @1) (SNat @3) u65))))))" printArtifactPretty (simplifyArtifact artifactRev) @?= "\\dret u1 -> rfromS (stranspose @[1,2,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @1) (SNat @2) (ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (stranspose @[6,1,2,3,4,5,9,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)) (sappend (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @3) (stranspose @[9,3,1,4,5,2,6,7,8,0] (sslice (SNat @1) (SNat @3) (sscatter (sdot1In (sconcrete (sfromListLinear [2,2,2,2,2,2,2] [5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0,5.0,13.1,2.0,9.0,6.0,8.0,1.0,-4.0,-2.0,582934.0,0.0,2.99432,0.1,-335.0,-0.2,26.0])) (stranspose @[4,0,2,3,5,6,7,1] (sreshape @[2,2,2,2,1,2,2,2] (stranspose 
@[1,2,3,4,0] (sreplicate @8 (sfromR dret)))) !$ [0])) (\\[i58, i59, i60, i61, i62, i63] -> [i59 + i62, i59, i62, i58, i60, i61, i63, i58, i61, i60 + i63]))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)))))))))))))))" -- TODO: OOMs _testPaddedCNNOPP1e :: Assertion _testPaddedCNNOPP1e = do resetVarCounter let f :: AstTensor AstMethodLet FullSpan (TKProduct (TKR 4 Double) (TKR 4 Double)) -> AstTensor AstMethodLet FullSpan (TKR 4 Double) f v = conv2dPadded (tproject1 v) (tproject2 v) ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar) (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar) (artifactRev, _) = revArtifactFromForwardPass UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk printArtifactPrimalPretty (simplifyArtifact artifactRev) @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (sgather (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (stranspose @[9,2,5,1,3,4,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0)) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,0,7,8,6] (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))))))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0))))))) (\\[i44, i75, i45, i46, i76, i47] -> [i75 + i76, i75, i76, i44, i45, i46, i47, i44, i46, i45 + i47]))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))" printArtifactPrimalPretty 
artifactRev @?= "\\u1 -> let u41 = sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0))) ; w48 = str (sreplicate @2 (stranspose @[0,4,1,2,5,3] (sgather (stranspose @[2,3,4,5,6,7,8,0,1] (sgather (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (stranspose @[9,2,5,1,3,4,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0)) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,0,7,8,6] (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 u41)))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0))))))) (\\[i42, i43] -> [i42 + i43, i42, i43]))) (\\[i44, i45, i46, i47] -> [i44, i45, i46, i47, i44, i46, i45 + i47])))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w48 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))" printArtifactPretty (simplifyArtifact artifactRev) @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [2,2,2,2] STKScalar)) (let w50 = sreshape @[2,2,2,2,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @2 (ssum @2 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @2 (sgather (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (stranspose @[9,2,5,1,3,4,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0)) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] 
(sappend (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @2) (stranspose @[9,1,2,3,4,5,0,7,8,6] (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))))))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,3] 0.0))))))) (\\[i44, i84, i45, i46, i85, i47] -> [i84 + i85, i84, i85, i44, i45, i46, i47, i44, i46, i45 + i47])))) (stranspose @[2,3,1,4,5,6,0] w50)))) (stranspose @[1,2,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @1) (SNat @2) (ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (ssum @2 (stranspose @[6,1,2,3,4,5,9,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0)) (sappend (stranspose @[9,1,2,3,4,5,6,7,8,0] (sappend (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0)) (sappend (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @0) (SNat @2) (stranspose @[9,1,2,3,4,5,6,7,8,0] (sslice (SNat @1) (SNat @3) (stranspose @[9,3,1,4,5,2,6,7,8,0] (sslice (SNat @1) (SNat @3) (sscatter (stranspose @[7,8,0,1,2,3,4,5,6] (sscatter (sdot1In (sreplicate @2 (stranspose @[2,3,5,0,4,1] (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))) (stranspose @[0,3,4,6,2,5,1] w50)) (\\[i51, i52, i53, i54] -> [i51, i52, i53, i54, i51, i53, i52 + i54]))) (\\[i55, i56] -> [i55 + i56, i55, i56]))))))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,2,2,2,2,2,2,4] 0.0))))))))))))))))" -- This is fragile due to indexing out of bounds, see above. 
-- | Golden test of the reverse-mode artifact of 'conv2dShrinking' (defined
-- elsewhere) applied to the two components of a product input whose shapes
-- are [2,2,2,2] and [6,2,6,6]. Four pretty-printed views are compared with
-- expected strings: the primal term (a function of @u1@) and the full
-- reverse artifact (a function of the cotangent @dret@ and of @u1@), each
-- both after and before 'simplifyArtifact'.
testPaddedCNNOPP1b :: Assertion
testPaddedCNNOPP1b = do
  -- Presumably resets fresh-variable numbering so that the variable names
  -- (i22, w26, ...) appearing in the golden strings are reproducible.
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
                     (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan (TKR 4 Double)
      f v = conv2dShrinking (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
                       (FTKR (6 :$: 2 :$: 6 :$: 6 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[6,2,4,4,8] (str (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i64, i66] -> [i64 + i66]))) (\\[i24, i25] -> [i24 + i25])))) * sreplicate @6 (str (sreplicate @4 (str (sreplicate @4 (sfromR (tproject1 u1))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w26 = str (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i22, i23] -> [i22 + i23]))) (\\[i24, i25] -> [i24 + i25])))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[6,2,4,4,8] (w26 * sreplicate @6 (str (sreplicate @4 (str (sreplicate @4 (sfromR (tproject1 u1))))))))))"
  printArtifactPretty artifactRev
    @?= "\\dret u1 -> let w26 = str (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i22, i23] -> [i22 + i23]))) (\\[i24, i25] -> [i24 + i25])))) ; w28 = sreshape @[6,2,4,4,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (rfromS (ssum @4 (str (ssum @4 (str (ssum @6 (w26 * w28))))))) (rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[2,5,0,1,3,4] (ssum @2 (str (sreplicate @6 (str (sreplicate @4 (str (sreplicate @4 (sfromR (tproject1 u1)))))) * w28)))) (\\[i29, i30] -> [i29 + i30]))) (\\[i31, i32] -> [i31 + i32]))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [6,2,6,6] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [6,2,6,6] STKScalar)) (let w28 = sreshape @[6,2,4,4,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @4 (ssum @4 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i81, i83] -> [i81 + i83]))) (\\[i24, i25] -> [i24 + i25]))))) (stranspose @[2,3,1,4,5,6,0] w28)))) (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (sdot1In (stranspose @[3,6,0,2,4,5,1] (sreplicate @6 (str (sreplicate @4 (str (sreplicate @4 (sfromR (tproject1 u1)))))))) (stranspose @[3,6,0,2,4,5,1] w28)) (\\[i29, i30] -> [i29 + i30]))) (\\[i31, i32] -> [i31 + i32]))))"

-- | Golden test of the reverse-mode artifact of 'conv2dPaddedLet' applied
-- to the two components of a product input, both of shape [2,2,2,2].
-- Three views are compared with expected strings: the simplified primal
-- term, the unsimplified primal term and the simplified reverse artifact.
-- The unsimplified reverse artifact is not checked here.
testPaddedCNNOPPLet :: Assertion
testPaddedCNNOPPLet = do
  -- Presumably resets fresh-variable numbering so the printed variable
  -- names match the golden strings.
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
                     (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan (TKR 4 Double)
      f v = conv2dPaddedLet (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
                       (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))), sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i36, i37] -> [i36, i37, ifH (notB (2 <=. i37) &&* notB (2 <=. i36)) 0 1])))))) (\\[i81, i83] -> [i81 + i83]))) (\\[i41, i42] -> [i41 + i42])))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let u35 = sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0))) ; u38 = stranspose @[1,2,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] u35, sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i36, i37] -> [i36, i37, ifH (notB (2 <=. i37) &&* notB (2 <=. i36)) 0 1])))))) ; w43 = str (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] u38) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w43 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [2,2,2,2] STKScalar)) (let w45 = sreshape @[2,2,2,2,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @2 (ssum @2 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @2 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))), sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i36, i37] -> [i36, i37, ifH (notB (2 <=. i37) &&* notB (2 <=. i36)) 0 1])))))) (\\[i100, i102] -> [i100 + i102]))) (\\[i41, i42] -> [i41 + i42]))))) (stranspose @[2,3,1,4,5,6,0] w45)))) (stranspose @[3,1,2,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,2,1,4,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,2,1,0] (sappend (sconcrete (sreplicate [1,2,2,4,2] 0.0)) (stranspose @[1,3,4,0,2] (sappend (sconcrete (sreplicate [1,3,2,2,2] 0.0)) (sscatter (sslice (SNat @1) (SNat @3) (stranspose @[3,0,1,2] (sslice (SNat @1) (SNat @3) (sscatter (stranspose @[2,4,1,3,0] (sscatter (sdot1In (stranspose @[3,6,0,2,4,5,1] (sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1)))))))) (stranspose @[3,6,0,2,4,5,1] w45)) (\\[i46, i47] -> [i46 + i47]))) (\\[i48, i49] -> [i48 + i49]))))) (\\[i51, i52] -> [i51, i52, ifH (notB (2 <=. i52) &&* notB (2 <=. i51)) 0 1]))))))))) !$ [0]))"

-- | Zero-padded 2D convolution in which the padded input is bound exactly
-- once with 'tlet', outside the outer 'rbuild'. Every output element
-- therefore refers to the same shared padded term.
--
-- @arrK@ has shape [nCoutK, nCinpK, nKh, nKw] and @arrA@ has shape
-- [nImgs, nCinpA, nAh, nAw]. The channel counts nCinpA and nCinpK must
-- agree, which is checked with 'assert'. The result has shape
-- [nImgs, nCoutK, nAh, nAw].
conv2dPaddedLet
  :: forall target r. (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r) -> target (TKR 4 r)
conv2dPaddedLet arrK arrA =
  let [nImgs, nCinpA, nAh, nAw] = rshape arrA
      [nCoutK, nCinpK, nKh, nKw] = rshape arrK
      -- The padded input is nKh rows and nKw columns larger than arrA.
      shAPadded = [nImgs, nCinpA, nAh + nKh, nAw + nKw]
      -- There are nKh `div` 2 (resp. nKw `div` 2) zero rows (resp. columns)
      -- before the data, and the remaining padding after it. Positions
      -- outside the data range are set to 0 explicitly, so arrA is only
      -- indexed inside its bounds.
      arrAPadded =
        rbuild @4 @0 @(TKScalar r) @target shAPadded $ \case
          [iImg, iCinp, iPh, iPw] ->
            ifH (iPh <. fromIntegral (nKh `div` 2)
                 ||* iPw <. fromIntegral (nKw `div` 2)
                 ||* iPh >=. fromIntegral (nAh + nKh `div` 2)
                 ||* iPw >=. fromIntegral (nAw + nKw `div` 2))
                (rscalar 0)
                (arrA ! [ iImg
                        , iCinp
                        , iPh - fromIntegral (nKh `div` 2)
                        , iPw - fromIntegral (nKw `div` 2) ])
      nCinp = assert (nCinpA == nCinpK `blame` (nCinpA, nCinpK)) nCinpA
      shB = [nImgs, nCoutK, nAh, nAw]
      -- Shape of one input window / one kernel slice.
      shK1 = [1, nCinp, nKh, nKw]
  in tlet arrAPadded $ \arrAPadded2 ->
     rbuild shB $ \case
       [iImg, iCout, iBh, iBw] ->
         -- slicezL (defined elsewhere) presumably takes a shK1-shaped window
         -- at the given offset; rdot0 combines the window and the kernel
         -- slice into the scalar output element.
         let arrAt = slicezL shK1 arrAPadded2 [iImg, 0, iBh, iBw]
             arrKt = slicezL shK1 arrK [iCout, 0, 0, 0]
         in rdot0 arrAt arrKt
       _ -> error "conv2dPaddedLet: impossible pattern needlessly required"

-- | Golden test of the reverse-mode artifact of 'conv2dPaddedLet2' applied
-- to the two components of a product input, both of shape [2,2,2,2].
-- Three views are compared with expected strings: the simplified primal
-- term, the unsimplified primal term and the simplified reverse artifact.
testPaddedCNNOPPLet2 :: Assertion
testPaddedCNNOPPLet2 = do
  -- Presumably resets fresh-variable numbering so the printed variable
  -- names match the golden strings.
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
                     (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan (TKR 4 Double)
      f v = conv2dPaddedLet2 (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
                       (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[3,5,2,0,4,1] (sgather (stranspose @[1,2,0] (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))), sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i53, i54] -> [i53, i54, ifH (notB (2 <=. i54) &&* notB (2 <=. i53)) 0 1])))))))) (\\[i180, i182] -> [i180 + i182]))) (\\[i62, i63] -> [i62, i62 + i63])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR (tproject1 u1))))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let u52 = sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0))) ; u55 = stranspose @[1,2,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] u52, sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i53, i54] -> [i53, i54, ifH (notB (2 <=. i54) &&* notB (2 <=. i53)) 0 1])))))) ; w64 = str (sreplicate @2 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[3,5,2,0,4,1] (sgather (stranspose @[1,2,0] (sgather (stranspose @[1,5,0,2,3,4] (sreplicate @2 (stranspose @[3,0,1,2] (sgather (stranspose @[2,3,0,1] (sreplicate @2 u55)) (\\[i56, i57, i58] -> [i58, i56]))))) (\\[i59] -> [i59, i59]))) (\\[i60, i61] -> [i60, i60 + i61]))) (\\[i62, i63] -> [i62, i62 + i63])))))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w64 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR (tproject1 u1))))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [2,2,2,2] STKScalar)) (let w66 = sreshape @[2,2,2,2,1,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @2 (ssum @2 (sdot1In (stranspose @[5,0,1,4,3,2] (sgather (stranspose @[2,4,0,3,1] (sgather (str (sreplicate @2 (sappend (sconcrete (sreplicate [1,2,4,2] 0.0)) (stranspose @[3,2,0,1] (sappend (sconcrete (sreplicate [1,2,2,3] 0.0)) (stranspose @[0,2,3,1] (sgather (sslice (SNat @1) (SNat @3) (stranspose @[3,0,4,1,2] (sslice (SNat @1) (SNat @3) (stranspose @[3,1,2,4,0] (sfromVector (fromList [stranspose @[1,2,3,0] (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))), sconcrete (sreplicate [2,2,4,4] 0.0)])))))) (\\[i53, i54] -> [i53, i54, ifH (notB (2 <=. i54) &&* notB (2 <=. i53)) 0 1])))))))) (\\[i253, i255] -> [i255 + i253]))) (\\[i222, i223, i229] -> [i222, i222 + i229]))) (stranspose @[4,2,3,1,5,6,7,0] w66 !$ [0])))) (stranspose @[3,1,2,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,2,1,4,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,2,1,0] (sappend (sconcrete (sreplicate [1,2,2,4,2] 0.0)) (stranspose @[1,3,4,0,2] (sappend (sconcrete (sreplicate [1,3,2,2,2] 0.0)) (sscatter (sslice (SNat @1) (SNat @3) (stranspose @[3,0,1,2] (sslice (SNat @1) (SNat @3) (ssum @2 (stranspose @[2,1,3,0] (sscatter (ssum @2 (stranspose @[2,3,4,5,0,1] (sscatter (stranspose @[2,0,1] (sscatter (stranspose @[3,5,2,0,4,1] (sscatter (sdot1In (sreplicate @2 (stranspose @[2,0,1,4,3] (sreplicate @2 (sreplicate @2 (stranspose @[3,2,1,0] (sfromR (tproject1 u1))))))) (stranspose @[4,3,7,0,2,5,6,1] w66 !$ [0])) (\\[i67, i68] -> [i67, i67 + i68]))) (\\[i69, i70] -> [i69, i69 + i70]))) (\\[i71] -> [i71, i71])))) (\\[i72, i73, i74] -> [i74, i72]))))))) (\\[i76, i77] -> [i76, i77, ifH (notB (2 <=. i77) &&* notB (2 <=. i76)) 0 1]))))))))) !$ [0]))"

-- | Zero-padded 2D convolution that differs from 'conv2dPaddedLet' only in
-- where the 'tlet' sits. Here the padded input is bound inside the body of
-- the outer 'rbuild', i.e. once per output element rather than once
-- overall.
--
-- Shapes are the same as for 'conv2dPaddedLet': kernel
-- [nCoutK, nCinpK, nKh, nKw], input [nImgs, nCinpA, nAh, nAw], and result
-- [nImgs, nCoutK, nAh, nAw].
conv2dPaddedLet2
  :: forall target r. (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r) -> target (TKR 4 r)
conv2dPaddedLet2 arrK arrA =
  let [nImgs, nCinpA, nAh, nAw] = rshape arrA
      [nCoutK, nCinpK, nKh, nKw] = rshape arrK
      shAPadded = [nImgs, nCinpA, nAh + nKh, nAw + nKw]
      -- Positions outside the data range are set to 0 explicitly; the data
      -- is shifted by nKh `div` 2 rows and nKw `div` 2 columns.
      arrAPadded =
        rbuild @4 @0 @(TKScalar r) @target shAPadded $ \case
          [iImg, iCinp, iPh, iPw] ->
            ifH (iPh <. fromIntegral (nKh `div` 2)
                 ||* iPw <. fromIntegral (nKw `div` 2)
                 ||* iPh >=. fromIntegral (nAh + nKh `div` 2)
                 ||* iPw >=. fromIntegral (nAw + nKw `div` 2))
                (rscalar 0)
                (arrA ! [ iImg
                        , iCinp
                        , iPh - fromIntegral (nKh `div` 2)
                        , iPw - fromIntegral (nKw `div` 2) ])
      nCinp = assert (nCinpA == nCinpK `blame` (nCinpA, nCinpK)) nCinpA
      shB = [nImgs, nCoutK, nAh, nAw]
      shK1 = [1, nCinp, nKh, nKw]
  in rbuild shB $ \case
       [iImg, iCout, iBh, iBw] ->
         -- The 'tlet' is deliberately inside the per-element body; this is
         -- what distinguishes this variant from conv2dPaddedLet.
         let arrAt = tlet arrAPadded $ \arrAPadded2 ->
               slicezL shK1 arrAPadded2 [iImg, 0, iBh, iBw]
             arrKt = slicezL shK1 arrK [iCout, 0, 0, 0]
         in rdot0 arrAt arrKt
       _ -> error "conv2dPaddedLet2: impossible pattern needlessly required"

-- TODO: OOMs
-- | Golden test of the reverse-mode artifact of 'conv2dPadded2' applied to
-- the two components of a product input, both of shape [2,2,2,2]. It is
-- currently disabled, as the leading underscore and the TODO indicate.
_testPaddedCNNOPP2 :: Assertion
_testPaddedCNNOPP2 = do
  resetVarCounter
  let f :: AstTensor AstMethodLet FullSpan
                     (TKProduct (TKR 4 Double) (TKR 4 Double))
        -> AstTensor AstMethodLet FullSpan (TKR 4 Double)
      f v = conv2dPadded2 (tproject1 v) (tproject2 v)
      ftk = FTKProduct (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
                       (FTKR (2 :$: 2 :$: 2 :$: 2 :$: ZSR) FTKScalar)
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk
  printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u1 -> rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (str (sreplicate @2 (stranspose @[5,0,1,4,2,3] (sgather (sappend (sconcrete (sreplicate [1,4,2,2] 0.0)) (sappend (stranspose @[3,0,2,1] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,4,2,2] 0.0)))) (\\[i24, i25, i27, i28] -> [i25 + i28, i24 + i27])))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPrimalPretty artifactRev
    @?= "\\u1 -> let w29 = str (sreplicate @2 (sgather (sappend (sconcrete (sreplicate [1,2,2,4] 0.0)) (sappend (stranspose @[3,1,2,0] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,2,2,4] 0.0)))) (\\[i23, i24, i25, i26, i27, i28] -> [i25 + i28, i23, i26, i24 + i27]))) in rfromS (ssum @8 (stranspose @[4,0,1,2,3] (sreshape @[2,2,2,2,8] (w29 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (sfromR (tproject1 u1))))))))))"
  printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [2,2,2,2] FTKScalar)) ConvSX))) (STKProduct (STKS [2,2,2,2] STKScalar) (STKS [2,2,2,2] STKScalar)) (let w31 = sreshape @[2,2,2,2,2,2,2] (stranspose @[1,2,3,4,0] (sreplicate @8 (sfromR dret))) in tpair (ssum @2 (ssum @2 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @2 (stranspose @[5,0,1,4,2,3] (sgather (sappend (sconcrete (sreplicate [1,4,2,2] 0.0)) (sappend (stranspose @[3,0,2,1] (sappend (sconcrete (sreplicate [1,2,2,2] 0.0)) (sappend (stranspose @[2,0,1] (sfromR (tproject2 u1))) (sconcrete (sreplicate [1,2,2,2] 0.0))))) (sconcrete (sreplicate [1,4,2,2] 0.0)))) (\\[i24, i25, i27, i28] -> [i25 + i28, i24 + i27]))))) (stranspose @[2,3,1,4,5,6,0] w31)))) (stranspose @[1,2,0] (sslice (SNat @1) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @1) (SNat @2) (sscatter (sdot1In (sreplicate @2 (sreplicate @2 (sreplicate @2 (stranspose @[1,2,3,0] (sfromR (tproject1 u1)))))) (stranspose @[0,2,3,4,5,6,1] w31)) (\\[i32, i33, i34, i35, i36, i37] -> [i34 + i37, i32, i35, i33 + i36])))))))"

-- | Padded 2D convolution whose padding step has no explicit zero test.
-- Every position of the enlarged [nImgs, nCinpA, nAh + nKh, nAw + nKw]
-- tensor indexes arrA at an offset of -(nKh `div` 2) and -(nKw `div` 2).
-- Border positions therefore use indices that are negative or at least
-- nAh / nAw.
--
-- NOTE(review): this presumably relies on the library's semantics for
-- out-of-range indexing with '!' (the zero padding materialised in the
-- golden strings of _testPaddedCNNOPP2 suggests out-of-range reads yield
-- 0). Confirm before depending on it.
conv2dPadded2
  :: forall target r. (ADReady target, GoodScalar r)
  => target (TKR 4 r) -> target (TKR 4 r) -> target (TKR 4 r)
conv2dPadded2 arrK arrA =
  let [nImgs, nCinpA, nAh, nAw] = rshape arrA
      [nCoutK, nCinpK, nKh, nKw] = rshape arrK
      shAPadded = [nImgs, nCinpA, nAh + nKh, nAw + nKw]
      arrAPadded =
        rbuild @4 @0 @(TKScalar r) @target shAPadded $ \case
          [iImg, iCinp, iPh, iPw] ->
            arrA ! [ iImg
                   , iCinp
                   , iPh - fromIntegral (nKh `div` 2)
                   , iPw - fromIntegral (nKw `div` 2) ]
      nCinp = assert (nCinpA == nCinpK `blame` (nCinpA, nCinpK)) nCinpA
      shB = [nImgs, nCoutK, nAh, nAw]
      shK1 = [1, nCinp, nKh, nKw]
  in rbuild shB $ \case
       [iImg, iCout, iBh, iBw] ->
         let arrAt = slicezL shK1 arrAPadded [iImg, 0, iBh, iBw]
             arrKt = slicezL shK1 arrK [iCout, 0, 0, 0]
         in rdot0 arrAt arrKt
       _ -> error "conv2dPadded2: impossible pattern needlessly required"

-- * Non-laborious CNN PP tests

-- Convolution differentiated wrt the kernel.
--
-- The data argument of 'conv2dUnpadded' (shape [7,5,7,7]) is a free AST
-- variable. It has a large id, presumably chosen so it cannot clash with
-- freshly generated ones, and is bound in @env@ to a dual number whose
-- delta is 'DeltaZero', so it is presumably treated as a constant. Only the
-- kernel (shape [5,5,5,5]) is the input of the artifact. The free variable
-- is printed as @u0@, so the "\\u0 -> " prefix is prepended to each
-- printout to mirror the expected strings.
testCNNOPP0cW :: Assertion
testCNNOPP0cW = do
  resetVarCounter
  let ftk = FTKR (7 :$: 5 :$: 7 :$: 7 :$: ZSR) (FTKScalar @Double)
      varName = mkAstVarName ftk Nothing . intToAstVarId $ 100000000
      var = AstVar varName
      ftk2 = FTKR (5 :$: 5 :$: 5 :$: 5 :$: ZSR) (FTKScalar @Double)
      -- flip puts the differentiated argument first (the kernel) and fixes
      -- the data argument to the free variable.
      f = simplifyInline . flip conv2dUnpadded var
      env = extendEnv varName (dDnotShared (AstRaw var) (DeltaZero ftk)) emptyEnv
      (artifactRev, _) =
        revArtifactFromForwardPass
          UseIncomingCotangent (forwardPassByInterpretation f env) ftk2
  "\\u0 -> " ++ printArtifactPrimalPretty (simplifyArtifact artifactRev)
    @?= "\\u0 -> \\u1 -> rfromS (ssum @125 (stranspose @[4,0,1,2,3] (sreshape @[7,5,7,7,125] (str (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u0)) (\\[i81, i83] -> [i81 + i83]))) (\\[i41, i42] -> [i41 + i42])))) * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR u1)))))))))"
  "\\u0 -> " ++ printArtifactPrimalPretty artifactRev
    @?= "\\u0 -> \\u1 -> let w43 = str (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u0)) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) in rfromS (ssum @125 (stranspose @[4,0,1,2,3] (sreshape @[7,5,7,7,125] (w43 * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR u1)))))))))"
  "\\u0 -> " ++ printArtifactPretty artifactRev
    @?= "\\u0 -> \\dret u1 -> let w43 = str (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u0)) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) ; w45 = sreshape @[7,5,7,7,5,5,5] (stranspose @[1,2,3,4,0] (sreplicate @125 (sfromR dret))) in rfromS (ssum @7 (str (ssum @7 (str (ssum @7 (w43 * w45))))))"
  "\\u0 -> " ++ printArtifactPretty (simplifyArtifact artifactRev)
    @?= "\\u0 -> \\dret u1 -> rfromS (ssum @7 (ssum @7 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u0)) (\\[i98, i100] -> [i98 + i100]))) (\\[i41, i42] -> [i41 + i42]))))) (stranspose @[2,3,1,4,5,6,0] (sreshape @[7,5,7,7,5,5,5] (stranspose @[1,2,3,4,0] (sreplicate @125 (sfromR dret))))))))"

-- Convolution differentiated wrt the data.
testCNNOPP0bW :: Assertion testCNNOPP0bW = do resetVarCounter let ftk = FTKR (5 :$: 5 :$: 5 :$: 5 :$: ZSR) (FTKScalar @Double) varName = mkAstVarName ftk Nothing . intToAstVarId $ 100000000 var = AstVar varName ftk2 = FTKR (7 :$: 5 :$: 7 :$: 7 :$: ZSR) (FTKScalar @Double) f = simplifyInline . conv2dUnpadded var env = extendEnv varName (dDnotShared (AstRaw var) (DeltaZero ftk)) emptyEnv (artifactRev, _) = revArtifactFromForwardPass UseIncomingCotangent (forwardPassByInterpretation f env) ftk2 "\\u0 -> " ++ printArtifactPrimalPretty (simplifyArtifact artifactRev) @?= "\\u0 -> \\u1 -> rfromS (ssum @125 (stranspose @[4,0,1,2,3] (sreshape @[7,5,7,7,125] (str (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i64, i66] -> [i64 + i66]))) (\\[i41, i42] -> [i41 + i42])))) * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR u0)))))))))" "\\u0 -> " ++ printArtifactPrimalPretty artifactRev @?= "\\u0 -> \\u1 -> let w43 = str (sreplicate @5 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) in rfromS (ssum @125 (stranspose @[4,0,1,2,3] (sreshape @[7,5,7,7,125] (w43 * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR u0)))))))))" "\\u0 -> " ++ printArtifactPretty artifactRev @?= "\\u0 -> \\dret u1 -> let w45 = sreshape @[7,5,7,7,5,5,5] (stranspose @[1,2,3,4,0] (sreplicate @125 (sfromR dret))) in rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[2,5,0,1,3,4] (ssum @5 (str (sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR u0))))) * w45)))) (\\[i46, i47] -> [i46 + i47]))) (\\[i48, i49] -> [i48 + i49])))" "\\u0 -> " ++ printArtifactPretty (simplifyArtifact artifactRev) @?= "\\u0 -> \\dret u1 -> rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (sdot1In (stranspose @[3,6,0,2,4,5,1] (sreplicate @7 (str 
(sreplicate @7 (str (sreplicate @7 (sfromR u0))))))) (stranspose @[3,6,0,2,4,5,1] (sreshape @[7,5,7,7,5,5,5] (stranspose @[1,2,3,4,0] (sreplicate @125 (sfromR dret)))))) (\\[i46, i47] -> [i46 + i47]))) (\\[i48, i49] -> [i48 + i49])))" testCNNOPP1bW :: Assertion testCNNOPP1bW = do resetVarCounter let f :: AstTensor AstMethodLet FullSpan (TKProduct (TKR 4 Double) (TKR 4 Double)) -> AstTensor AstMethodLet FullSpan (TKR 4 Double) f v = simplifyInline $ conv2dUnpadded (tproject1 v) (tproject2 v) ftk = FTKProduct (FTKR (7 :$: 7 :$: 7 :$: 7 :$: ZSR) FTKScalar) (FTKR (7 :$: 7 :$: 7 :$: 7 :$: ZSR) FTKScalar) (artifactRev, _) = revArtifactFromForwardPass UseIncomingCotangent (forwardPassByInterpretation f emptyEnv) ftk printArtifactPrimalPretty (simplifyArtifact artifactRev) @?= "\\u1 -> rfromS (ssum @343 (stranspose @[4,0,1,2,3] (sreshape @[7,7,7,7,343] (str (sreplicate @7 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i81, i83] -> [i81 + i83]))) (\\[i41, i42] -> [i41 + i42])))) * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR (tproject1 u1))))))))))" printArtifactPrimalPretty artifactRev @?= "\\u1 -> let w43 = str (sreplicate @7 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) in rfromS (ssum @343 (stranspose @[4,0,1,2,3] (sreshape @[7,7,7,7,343] (w43 * sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR (tproject1 u1))))))))))" printArtifactPretty artifactRev @?= "\\dret u1 -> let w43 = str (sreplicate @7 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i39, i40] -> [i39 + i40]))) (\\[i41, i42] -> [i41 + i42])))) ; w45 = sreshape @[7,7,7,7,7,7,7] (stranspose @[1,2,3,4,0] (sreplicate @343 (sfromR dret))) in tpair (rfromS (ssum @7 (str (ssum @7 (str (ssum @7 (w43 * w45))))))) 
(rfromS (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (stranspose @[2,5,0,1,3,4] (ssum @7 (str (sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR (tproject1 u1)))))) * w45)))) (\\[i46, i47] -> [i46 + i47]))) (\\[i48, i49] -> [i48 + i49]))))" printArtifactPretty (simplifyArtifact artifactRev) @?= "\\dret u1 -> tconvert (ConvT2 (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [7,7,7,7] FTKScalar)) ConvSX)) (ConvCmp (ConvXR STKScalar) (ConvCmp (ConvXX' (FTKX [7,7,7,7] FTKScalar)) ConvSX))) (STKProduct (STKS [7,7,7,7] STKScalar) (STKS [7,7,7,7] STKScalar)) (let w45 = sreshape @[7,7,7,7,7,7,7] (stranspose @[1,2,3,4,0] (sreplicate @343 (sfromR dret))) in tpair (ssum @7 (ssum @7 (sdot1In (stranspose @[2,3,0,4,5,6,1] (sreplicate @7 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (stranspose @[2,0,1] (sfromR (tproject2 u1))) (\\[i98, i100] -> [i98 + i100]))) (\\[i41, i42] -> [i41 + i42]))))) (stranspose @[2,3,1,4,5,6,0] w45)))) (stranspose @[1,2,0] (sscatter (stranspose @[2,4,1,3,0] (sscatter (sdot1In (stranspose @[3,6,0,2,4,5,1] (sreplicate @7 (str (sreplicate @7 (str (sreplicate @7 (sfromR (tproject1 u1)))))))) (stranspose @[3,6,0,2,4,5,1] w45)) (\\[i46, i47] -> [i46 + i47]))) (\\[i48, i49] -> [i48 + i49]))))" testCNNOPP4bW :: Assertion testCNNOPP4bW = do resetVarCounter let !artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded 4 2) (FTKR [7, 7, 7, 7] (FTKScalar @Double)) !artSimp = simplifyArtifact artifactRev let ftk1 = FTKR (7 :$: 7 :$: 7 :$: 7 :$: ZSR) (FTKScalar @Double) ftkDt = FTKR (7 :$: 7 :$: 3 :$: 3 :$: ZSR) (FTKScalar @Double) env = extendEnv (artVarDtRev artSimp) (tconcrete ftkDt (treplTarget 7 ftkDt)) $ extendEnv (artVarDomainRev artSimp) (tconcrete ftk1 (treplTarget 42 ftk1)) emptyEnv interpretAstPrimal @Concrete env (artPrimalRev artifactRev) @?= interpretAstPrimal @Concrete env (artPrimalRev artSimp) interpretAstPrimal @Concrete env (artDerivativeRev artifactRev) @?= 
interpretAstPrimal @Concrete env (artDerivativeRev artSimp) printArtifactPrimalPretty (simplifyArtifact artifactRev) @?= "\\u1 -> rfromS (let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i137, i138] -> [2 * i137 + i138]))) (\\[i45, i46] -> [2 * i45 + i46]))) in sgather w47 (\\[i48, i49, i50, i51] -> [i48, i49, i50, i51, kfromS (smaxIndex (w47 !$ [i48, i49, i50, i51]))]))" printArtifactPrimalPretty artifactRev @?= "\\u1 -> let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) in rfromS (sgather w47 (\\[i48, i49, i50, i51] -> [i48, i49, i50, i51, kfromS (smaxIndex (w47 !$ [i48, i49, i50, i51]))]))" printArtifactPretty artifactRev @?= "\\dret u1 -> let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) in rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR dret) (\\[i53, i54, i55, i56] -> [i53, i54, i55, i56, kfromS (smaxIndex (w47 !$ [i53, i54, i55, i56]))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))" -- The remH comes from the indexing of reshape rule and it looks terrible, -- but w42 looks even worse, depending on available primitives, -- so the rule is probably fine. 
printArtifactPretty (simplifyArtifact artifactRev) @?= "\\dret u1 -> rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR dret) (\\[i53, i54, i55, i56] -> [i53, i54, i55, i56, kfromS (smaxIndex (sgather (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i159, i160] -> [2 * i159 + i160]))) (\\[i45, i46] -> [2 * i45 + i46])) (\\[i151] -> [remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 16) 3, remH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 4, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 1008) 7, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 144) 7, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 48) 3, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i151) 4) 4])))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))" printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev) @?= "rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR u52) (\\[i53, i54, i55, i56] -> [i53, i54, i55, i56, kfromS (smaxIndex (sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) !$ [i53, i54, i55, i56]))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))" testCNNOPP4bD :: Assertion testCNNOPP4bD = do resetVarCounter setTotalSharing True let !artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded 4 2) (FTKR [7, 7, 7, 7] (FTKScalar @Double)) !artSimp = simplifyArtifact artifactRev setTotalSharing False let ftk1 = FTKR (7 :$: 7 :$: 7 :$: 7 :$: ZSR) (FTKScalar @Double) ftkDt = FTKR (7 :$: 7 :$: 3 :$: 3 :$: ZSR) (FTKScalar 
@Double) env = extendEnv (artVarDtRev artSimp) (tconcrete ftkDt (treplTarget 7 ftkDt)) $ extendEnv (artVarDomainRev artSimp) (tconcrete ftk1 (treplTarget 42 ftk1)) emptyEnv interpretAstPrimal @Concrete env (artPrimalRev artifactRev) @?= interpretAstPrimal @Concrete env (artPrimalRev artSimp) interpretAstPrimal @Concrete env (artDerivativeRev artifactRev) @?= interpretAstPrimal @Concrete env (artDerivativeRev artSimp) printArtifactPrimalPretty artSimp @?= "\\u1 -> rfromS (let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i98, i99] -> [2 * i98 + i99]))) (\\[i45, i46] -> [2 * i45 + i46]))) in sgather w47 (\\[i48, i49, i50, i51] -> [i48, i49, i50, i51, kfromS (smaxIndex (w47 !$ [i48, i49, i50, i51]))]))" printArtifactPrimalPretty artifactRev @?= "\\u1 -> let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) in rfromS (sgather w47 (\\[i48, i49, i50, i51] -> [i48, i49, i50, i51, kfromS (smaxIndex (w47 !$ [i48, i49, i50, i51]))]))" printArtifactPretty artifactRev @?= "\\dret u1 -> let w47 = sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) in rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR dret) (\\[i53, i54, i55, i56] -> [i53, i54, i55, i56, kfromS (smaxIndex (w47 !$ [i53, i54, i55, i56]))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))" printArtifactPretty artSimp @?= "\\dret u1 -> rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR dret) (\\[i53, i54, i55, i56] -> [i53, 
i54, i55, i56, kfromS (smaxIndex (sgather (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i81, i82] -> [2 * i81 + i82]))) (\\[i45, i46] -> [2 * i45 + i46])) (\\[i73] -> [remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 16) 3, remH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 4, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 1008) 7, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 144) 7, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 48) 3, remH (quotH ((((1008 * i53 + 144 * i54) + 48 * i55) + 16 * i56) + i73) 4) 4])))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))" printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev) @?= "rfromS (stranspose @[1,2,0] (sscatter (stranspose @[3,4,1,2,0] (sscatter (stranspose @[3,5,0,1,2,4] (sreshape @[7,7,3,3,4,4] (sscatter (sfromR u52) (\\[i53, i54, i55, i56] -> [i53, i54, i55, i56, kfromS (smaxIndex (sreshape @[7,7,3,3,16] (stranspose @[2,3,4,0,5,1] (sgather (stranspose @[4,2,3,0,1] (sgather (stranspose @[2,0,1] (sfromR u1)) (\\[i43, i44] -> [2 * i43 + i44]))) (\\[i45, i46] -> [2 * i45 + i46]))) !$ [i53, i54, i55, i56]))])))) (\\[i57, i58] -> [2 * i57 + i58]))) (\\[i59, i60] -> [2 * i59 + i60])))" testCNNOPP5aW :: Assertion testCNNOPP5aW = do resetVarCounter let artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded 4 2 . 
conv2dC) (FTKR [7, 2, 7, 7] (FTKScalar @Double)) printArtifactPrimalPretty (simplifyArtifact artifactRev) @?= "\\u1 -> rfromS (let t49 = sreshape @[2,7,16] (ssum @98 (stranspose @[4,1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i109, i111] -> [i109 + i111]))) (\\[i46, i47] -> [i46 + i47])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) (sconcrete (sreplicate [2,2,7,2,98] 0.0)))) (sconcrete (sreplicate [2,2,7,4,98] 0.0))))) in stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (sgather t49 (\\[i50, i51] -> [i50, i51, kfromS (smaxIndex (t49 !$ [i50, i51]))]))))))" printArtifactPrimalPretty artifactRev @?= "\\u1 -> let w48 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i44, i45] -> [i44 + i45]))) (\\[i46, i47] -> [i46 + i47])))))) ; t49 = sreshape @[2,7,16] (ssum @98 (stranspose @[4,1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] (sreshape @[2,7,2,2,98] (w48 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) (sconcrete (sreplicate [2,2,7,2,98] 0.0)))) (sconcrete (sreplicate [2,2,7,4,98] 0.0))))) in rfromS (stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (sgather t49 (\\[i50, i51] -> [i50, i51, kfromS (smaxIndex (t49 !$ [i50, i51]))]))))))" printArtifactPretty artifactRev @?= "\\dret u1 -> let w48 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose 
@[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i44, i45] -> [i44 + i45]))) (\\[i46, i47] -> [i46 + i47])))))) ; t49 = sreshape @[2,7,16] (ssum @98 (stranspose @[4,1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] (sreshape @[2,7,2,2,98] (w48 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) (sconcrete (sreplicate [2,2,7,2,98] 0.0)))) (sconcrete (sreplicate [2,2,7,4,98] 0.0))))) ; w55 = stranspose @[4,1,2,3,0] (sreplicate @98 (sreshape @[2,7,4,4] (sscatter (ssum @1 (stranspose @[2,0,1] (ssum @1 (stranspose @[2,0,1] (sfromR dret))))) (\\[i53, i54] -> [i53, i54, kfromS (smaxIndex (t49 !$ [i53, i54]))])))) in rfromS (ssum @1 (str (ssum @2 (str (ssum @2 (str (ssum @2 (w48 * sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) w55))))))))))))" printArtifactPretty (simplifyArtifact artifactRev) @?= "\\dret u1 -> rfromS (let w48 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i155, i157] -> [i155 + i157]))) (\\[i46, i47] -> [i46 + i47]) in ssum @2 (ssum @2 (sdot1In (stranspose @[1,6,0,4,5,3,2] (sreplicate @7 (stranspose @[3,2,1,4,5,0] w48))) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[4,1,2,3,0] (sreplicate @98 (sreshape @[2,7,4,4] (sscatter (stranspose @[2,3,0,1] (sfromR dret) !$ [0, 0]) (\\[i53, i54] -> [i53, i54, kfromS (smaxIndex (ssum @98 (str (sgather (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w48)))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate 
@2 (str (sreplicate @1 (sfromR u1)))))))))) (sconcrete (sreplicate [2,2,7,2,98] 0.0)))) (sconcrete (sreplicate [2,2,7,4,98] 0.0))) (\\[i129] -> [remH ((112 * i53 + 16 * i54) + i129) 4, remH (quotH ((112 * i53 + 16 * i54) + i129) 112) 2, remH (quotH ((112 * i53 + 16 * i54) + i129) 16) 7, remH (quotH ((112 * i53 + 16 * i54) + i129) 4) 4])))))])))))))))) !$ [0]))))" printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev) @?= "rfromS (let w48 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i44, i45] -> [i44 + i45]))) (\\[i46, i47] -> [i46 + i47]) in ssum @2 (ssum @2 (sdot1In (stranspose @[4,2,3,0,5,6,7,1] (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w48)))) !$ [0]) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[4,1,2,3,0] (sreplicate @98 (sreshape @[2,7,4,4] (sscatter (stranspose @[2,3,0,1] (sfromR u52) !$ [0, 0]) (\\[i53, i54] -> [i53, i54, kfromS (smaxIndex (sreshape @[2,7,16] (ssum @98 (stranspose @[4,1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w48)))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) (sconcrete (sreplicate [2,2,7,2,98] 0.0)))) (sconcrete (sreplicate [2,2,7,4,98] 0.0))))) !$ [i53, i54]))])))))))))) !$ [0]))))" testCNNOPP5bW :: Assertion testCNNOPP5bW = do resetVarCounter let artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded 4 2 . 
relu) (FTKR [7, 2, 7, 7] (FTKScalar @Double)) printArtifactPrimalPretty (simplifyArtifact artifactRev) @?= "\\u1 -> rfromS (let w75 = sreshape @[7,2,3,3,16] (stranspose @[4,5,0,1,2,3] (sgather (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i61, i62, i63, i64] -> [ifH (sscalar -0.0 <=. negate (sfromR u1 !$ [i63, i64, i61, i62])) 0 1])) (\\[i65, i66, i67, i68] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i65, i67]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i66, i68])])) * stranspose @[4,5,0,1,2,3] (sgather (stranspose @[2,3,0,1] (sfromR u1)) (\\[i69, i70, i71, i72] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i69, i71]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i70, i72])]))) in sgather w75 (\\[i76, i77, i78, i79] -> [i76, i77, i78, i79, kfromS (smaxIndex (w75 !$ [i76, i77, i78, i79]))]))" printArtifactPrimalPretty artifactRev @?= "\\u1 -> let m59 = str (sreplicate @4 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @4)) ; m60 = str (sreplicate @4 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @4)) ; w73 = stranspose @[4,5,0,1,2,3] (sgather (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i61, i62, i63, i64] -> [ifH (sscalar -0.0 <=. 
negate (sfromR u1 !$ [i63, i64, i61, i62])) 0 1])) (\\[i65, i66, i67, i68] -> [kfromS (m59 !$ [i65, i67]), kfromS (m60 !$ [i66, i68])])) ; w74 = stranspose @[4,5,0,1,2,3] (sgather (stranspose @[2,3,0,1] (sfromR u1)) (\\[i69, i70, i71, i72] -> [kfromS (m59 !$ [i69, i71]), kfromS (m60 !$ [i70, i72])])) ; w75 = sreshape @[7,2,3,3,16] (w73 * w74) in rfromS (sgather w75 (\\[i76, i77, i78, i79] -> [i76, i77, i78, i79, kfromS (smaxIndex (w75 !$ [i76, i77, i78, i79]))]))" printArtifactPretty artifactRev @?= "\\dret u1 -> let m59 = str (sreplicate @4 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @4)) ; m60 = str (sreplicate @4 (sconcrete (sreplicate [3] 2) * siota (SNat @3))) + sreplicate @3 (siota (SNat @4)) ; w73 = stranspose @[4,5,0,1,2,3] (sgather (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i61, i62, i63, i64] -> [ifH (sscalar -0.0 <=. negate (sfromR u1 !$ [i63, i64, i61, i62])) 0 1])) (\\[i65, i66, i67, i68] -> [kfromS (m59 !$ [i65, i67]), kfromS (m60 !$ [i66, i68])])) ; w74 = stranspose @[4,5,0,1,2,3] (sgather (stranspose @[2,3,0,1] (sfromR u1)) (\\[i69, i70, i71, i72] -> [kfromS (m59 !$ [i69, i71]), kfromS (m60 !$ [i70, i72])])) ; w75 = sreshape @[7,2,3,3,16] (w73 * w74) in rfromS (stranspose @[2,3,0,1] (sscatter (stranspose @[2,3,4,5,0,1] (w73 * sreshape @[7,2,3,3,4,4] (sscatter (sfromR dret) (\\[i81, i82, i83, i84] -> [i81, i82, i83, i84, kfromS (smaxIndex (w75 !$ [i81, i82, i83, i84]))])))) (\\[i85, i86, i87, i88] -> [kfromS (m59 !$ [i85, i87]), kfromS (m60 !$ [i86, i88])])))" printArtifactPretty (simplifyArtifact artifactRev) @?= "\\dret u1 -> rfromS (let w73 = sgather (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i61, i62, i63, i64] -> [ifH (sscalar -0.0 <=. 
negate (sfromR u1 !$ [i63, i64, i61, i62])) 0 1])) (\\[i65, i66, i67, i68] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i65, i67]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i66, i68])]) in stranspose @[2,3,0,1] (sscatter (w73 * stranspose @[2,3,4,5,0,1] (sreshape @[7,2,3,3,4,4] (sscatter (sfromR dret) (\\[i81, i82, i83, i84] -> [i81, i82, i83, i84, kfromS (smaxIndex (sgather (stranspose @[4,5,0,1,2,3] w73 * stranspose @[4,5,0,1,2,3] (sgather (stranspose @[2,3,0,1] (sfromR u1)) (\\[i69, i70, i71, i72] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i69, i71]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i70, i72])]))) (\\[i118] -> [remH (quotH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 288) 7, remH (quotH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 144) 2, remH (quotH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 48) 3, remH (quotH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 16) 3, remH (quotH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 4) 4, remH ((((288 * i81 + 144 * i82) + 48 * i83) + 16 * i84) + i118) 4])))])))) (\\[i85, i86, i87, i88] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i85, i87]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i86, i88])])))" printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev) @?= "rfromS (let w73 = sgather (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i61, i62, i63, i64] -> [ifH (sscalar -0.0 <=. 
negate (sfromR u1 !$ [i63, i64, i61, i62])) 0 1])) (\\[i65, i66, i67, i68] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i65, i67]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i66, i68])]) in stranspose @[2,3,0,1] (sscatter (w73 * stranspose @[2,3,4,5,0,1] (sreshape @[7,2,3,3,4,4] (sscatter (sfromR u80) (\\[i81, i82, i83, i84] -> [i81, i82, i83, i84, kfromS (smaxIndex (sreshape @[7,2,3,3,16] (stranspose @[4,5,0,1,2,3] w73 * stranspose @[4,5,0,1,2,3] (sgather (stranspose @[2,3,0,1] (sfromR u1)) (\\[i69, i70, i71, i72] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i69, i71]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i70, i72])]))) !$ [i81, i82, i83, i84]))])))) (\\[i85, i86, i87, i88] -> [kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i85, i87]), kfromS (sconcrete (sfromListLinear [3,4] [0,1,2,3,2,3,4,5,4,5,6,7]) !$ [i86, i88])])))" testCNNOPP5cW :: Assertion testCNNOPP5cW = do resetVarCounter let artifactRev = revArtifactAdapt UseIncomingCotangent (relu . conv2dC) (FTKR [7, 2, 7, 7] (FTKScalar @Double)) printArtifactPrimalPretty (simplifyArtifact artifactRev) @?= "\\u1 -> rfromS (let u45 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i114, i116] -> [i114 + i116]))) (\\[i42, i43] -> [i42 + i43])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) in sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i46, i47, i48, i49] -> [ifH (sscalar -0.0 <=. 
negate (u45 !$ [i46, i47, i48, i49])) 0 1]) * u45)" printArtifactPrimalPretty artifactRev @?= "\\u1 -> let w44 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i40, i41] -> [i40 + i41]))) (\\[i42, i43] -> [i42 + i43])))))) ; u45 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (w44 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u50 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i46, i47, i48, i49] -> [ifH (sscalar -0.0 <=. negate (u45 !$ [i46, i47, i48, i49])) 0 1]) in rfromS (u50 * u45)" printArtifactPretty artifactRev @?= "\\dret u1 -> let w44 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i40, i41] -> [i40 + i41]))) (\\[i42, i43] -> [i42 + i43])))))) ; u45 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (w44 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u50 = sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i46, i47, i48, i49] -> [ifH (sscalar -0.0 <=. 
negate (u45 !$ [i46, i47, i48, i49])) 0 1]) in rfromS (ssum @1 (str (ssum @2 (str (ssum @2 (str (ssum @2 (w44 * sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (u50 * sfromR dret)))))))))))" printArtifactPretty (simplifyArtifact artifactRev) @?= "\\dret u1 -> rfromS (let w44 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i169, i171] -> [i169 + i171]))) (\\[i42, i43] -> [i42 + i43]) in ssum @2 (ssum @2 (sdot1In (stranspose @[1,6,0,4,5,3,2] (sreplicate @7 (stranspose @[3,2,1,4,5,0] w44))) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i46, i47, i48, i49] -> [ifH (sscalar -0.0 <=. negate (ssum0 (sgather (stranspose @[6,1,0,2,4,5,3] (sreplicate @7 (stranspose @[3,2,1,4,5,0] w44)) * sreplicate @2 (sreplicate @2 (str (sreplicate @2 (sfromR u1))))) (\\[i132] -> [remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 98) 2, remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 196) 2, remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 392) 7, remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 2744) 2, remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 49) 2, remH (quotH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 7) 7, remH ((((2744 * i46 + 392 * i47) + 196 * i48) + 98 * i49) + i132) 7])))) 0 1]) * sfromR dret)))) !$ [0]))))" printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev) @?= "rfromS (let w44 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i40, i41] -> [i40 + i41]))) (\\[i42, i43] -> [i42 + i43]) in ssum @2 (ssum @2 (sdot1In (stranspose 
@[4,2,3,0,5,6,7,1] (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w44)))) !$ [0]) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i46, i47, i48, i49] -> [ifH (sscalar -0.0 <=. negate (ssum0 (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w44)))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))) !$ [i46, i47, i48, i49]))) 0 1]) * sfromR u51)))) !$ [0]))))" testCNNOPP5dW :: Assertion testCNNOPP5dW = do resetVarCounter let artifactRev = revArtifactAdapt UseIncomingCotangent (maxPool2dUnpadded 4 2 . relu . conv2dC) (FTKR [7, 2, 7, 7] (FTKScalar @Double)) printArtifactPrimalPretty (simplifyArtifact artifactRev) @?= "\\u1 -> rfromS (let u67 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i130, i132] -> [i130 + i132]))) (\\[i64, i65] -> [i64 + i65])))))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; t74 = sreshape @[2,7,16] (stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i68, i69, i70, i71] -> [ifH (sscalar -0.0 <=. 
negate (u67 !$ [i69, i70, i68, i71])) 0 1])) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0))) * stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] u67) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0)))) in stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (sgather t74 (\\[i75, i76] -> [i75, i76, kfromS (smaxIndex (t74 !$ [i75, i76]))]))))))" printArtifactPrimalPretty artifactRev @?= "\\u1 -> let w66 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i62, i63] -> [i62 + i63]))) (\\[i64, i65] -> [i64 + i65])))))) ; u67 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (w66 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u72 = stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i68, i69, i70, i71] -> [ifH (sscalar -0.0 <=. 
negate (u67 !$ [i69, i70, i68, i71])) 0 1])) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0))) ; u73 = stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] u67) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0))) ; t74 = sreshape @[2,7,16] (u72 * u73) in rfromS (stranspose @[1,2,0] (sreplicate @1 (stranspose @[1,2,0] (sreplicate @1 (sgather t74 (\\[i75, i76] -> [i75, i76, kfromS (smaxIndex (t74 !$ [i75, i76]))]))))))" printArtifactPretty artifactRev @?= "\\dret u1 -> let w66 = str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] (sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i62, i63] -> [i62 + i63]))) (\\[i64, i65] -> [i64 + i65])))))) ; u67 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (w66 * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u72 = stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i68, i69, i70, i71] -> [ifH (sscalar -0.0 <=. 
negate (u67 !$ [i69, i70, i68, i71])) 0 1])) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0))) ; u73 = stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] u67) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0))) ; t74 = sreshape @[2,7,16] (u72 * u73) ; u80 = stranspose @[3,0,1,2] (u72 * sreshape @[2,7,4,4] (sscatter (ssum @1 (stranspose @[2,0,1] (ssum @1 (stranspose @[2,0,1] (sfromR dret))))) (\\[i78, i79] -> [i78, i79, kfromS (smaxIndex (t74 !$ [i78, i79]))]))) in rfromS (ssum @1 (str (ssum @2 (str (ssum @2 (str (ssum @2 (w66 * sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) u80))))))))))))))" printArtifactPretty (simplifyArtifact artifactRev) @?= "\\dret u1 -> rfromS (let w66 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i172, i174] -> [i172 + i174]))) (\\[i64, i65] -> [i64 + i65]) ; u67 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w66)))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u72 = sappend (stranspose @[3,1,2,0] (sappend (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i68, i69, i70, i71] -> [ifH (sscalar -0.0 <=. 
negate (u67 !$ [i69, i70, i68, i71])) 0 1])) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0)) in ssum @2 (ssum @2 (sdot1In (stranspose @[1,6,0,4,5,3,2] (sreplicate @7 (stranspose @[3,2,1,4,5,0] w66))) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) u72))) * stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,0,1,2] (sreshape @[2,7,4,4] (sscatter (stranspose @[2,3,0,1] (sfromR dret) !$ [0, 0]) (\\[i78, i79] -> [i78, i79, kfromS (smaxIndex (sgather (stranspose @[1,2,3,0] u72 * stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] u67) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0)))) (\\[i146] -> [remH (quotH ((112 * i78 + 16 * i79) + i146) 112) 2, remH (quotH ((112 * i78 + 16 * i79) + i146) 16) 7, remH (quotH ((112 * i78 + 16 * i79) + i146) 4) 4, remH ((112 * i78 + 16 * i79) + i146) 4])))]))))))))))) !$ [0]))))" printAstPretty (simplifyInlineContractNoExpand $ artDerivativeRev artifactRev) @?= "rfromS (let w66 = sgather (stranspose @[4,2,0,3,1] (sgather (sconcrete (sfromListLinear [2,2,2,2] [5.0,2.0,-2.0,0.0,13.1,9.0,582934.0,2.99432,6.0,1.0,0.1,-0.2,8.0,-4.0,-335.0,26.0])) (\\[i62, i63] -> [i62 + i63]))) (\\[i64, i65] -> [i64 + i65]) ; u67 = ssum @98 (stranspose @[4,0,1,2,3] (sreshape @[2,7,2,2,98] (str (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w66)))) * sreplicate @2 (str (sreplicate @2 (str (sreplicate @2 (str (sreplicate @1 (sfromR u1)))))))))) ; u72 = sappend (stranspose @[3,1,2,0] (sappend (sgather (sconcrete (sfromListLinear [2] [0.0,1.0])) (\\[i68, i69, i70, i71] -> [ifH (sscalar -0.0 <=. 
negate (u67 !$ [i69, i70, i68, i71])) 0 1])) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0)) in ssum @2 (ssum @2 (sdot1In (stranspose @[4,2,3,0,5,6,7,1] (sreplicate @7 (stranspose @[1,2,3,0] (sreplicate @1 (stranspose @[2,3,0,4,5,1] w66)))) !$ [0]) (stranspose @[4,2,3,1,5,6,7,0] (sreshape @[2,7,2,2,1,2,7,7] (stranspose @[1,2,3,4,0] (sreplicate @98 (stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) u72))) * stranspose @[1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,1,2,0] (sslice (SNat @0) (SNat @2) (stranspose @[3,0,1,2] (sreshape @[2,7,4,4] (sscatter (stranspose @[2,3,0,1] (sfromR u77) !$ [0, 0]) (\\[i78, i79] -> [i78, i79, kfromS (smaxIndex (sreshape @[2,7,16] (stranspose @[1,2,3,0] u72 * stranspose @[1,2,3,0] (sappend (stranspose @[3,1,2,0] (sappend (stranspose @[2,0,1] u67) (sconcrete (sreplicate [2,2,7,2] 0.0)))) (sconcrete (sreplicate [2,2,7,4] 0.0)))) !$ [i78, i79]))]))))))))))) !$ [0]))))"