💣 Machine learning which might blow up in your face 💣

Remove primes on shape instantiations

Add singletons for Shape and remove hacks on recurrent nets

Add Recurrent Nets

+1782 -300
+1 -1
README.md
````diff
···
 To perform back propagation, one can call the eponymous function
 ```haskell
 backPropagate :: forall input target shapes layers. (Head shapes ~ input, Last shapes ~ target)
-              => Network layers shapes -> S' input -> S' target -> Gradients layers
+              => Network layers shapes -> S input -> S target -> Gradients layers
 ```
 which takes a network, appropriate input and target data, and returns the
 back propagated gradients for the network. The shapes of the gradients are
````
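In practice the rename only touches the shape type at call sites. A minimal sketch of one training step under the new `S` type; the two-layer stack here is illustrative and not part of this change:

```haskell
{-# LANGUAGE DataKinds #-}

import Control.Monad.Random ( evalRandIO )
import Grenade

-- An illustrative network: 2 inputs, 4 hidden units, 1 output.
type Net = Network '[ FullyConnected 2 4, Logit, FullyConnected 4 1, Logit ]
                   '[ 'D1 2, 'D1 4, 'D1 4, 'D1 1, 'D1 1 ]

-- backPropagate gives the gradients; train folds in the update step.
step :: LearningParameters -> Net -> S ('D1 2) -> S ('D1 1) -> Net
step params net input target = train params net input target

main :: IO ()
main = do
  net0 <- evalRandIO randomNetwork
  print (step (LearningParameters 0.01 0.9 0.0005) net0 (S1D 0.5) (S1D 1))
```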
+13
cbits/gradient_decent.c
```diff
+#include "gradient_decent.h"
+
+void decend_cpu(int len, double rate, double momentum, double regulariser,
+                const double* weights,
+                const double* gradient,
+                const double* last,
+                double* outputWeights, double* outputMomentum) {
+
+  /* Both output buffers hold exactly len doubles, so the loop must
+     stop strictly below len. */
+  for (int i = 0; i < len; i++) {
+    outputMomentum[i] = momentum * last[i] - rate * gradient[i];
+    outputWeights[i] = weights[i] + outputMomentum[i] - (rate * regulariser) * weights[i];
+  }
+}
```
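For reference, the update rule this kernel applies at each index, written as a pure Haskell sketch (the name `descendStep` is illustrative, not part of the library):

```haskell
-- Plain SGD with momentum and L2 weight decay, one weight at a time;
-- the C loop above applies exactly this at every index.
descendStep :: Double -> Double -> Double   -- rate, momentum, regulariser
            -> Double -> Double -> Double   -- weight, gradient, previous momentum
            -> (Double, Double)             -- (new weight, new momentum)
descendStep rate momentum regulariser w g m =
  let m' = momentum * m - rate * g
      w' = w + m' - rate * regulariser * w
  in  (w', m')
```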
+9
cbits/gradient_decent.h
```diff
+#include <stdio.h>
+#include <stdint.h>
+
+void decend_cpu(int len, double rate, double momentum, double regulariser,
+                const double* weights,
+                const double* gradient,
+                const double* last,
+                double* outputWeights, double* outputMomentum);
```
+4 -11
cbits/im2col.c
```diff
 #include "im2col.h"
 
-void im2col_cpu(const double* data_im, int dataOffset, const int channels,
+void im2col_cpu(const double* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
     double* data_col) {
 
-  data_im += dataOffset;
   const int channel_size = height * width;
 
   for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) {
···
 }
 
-void col2im_cpu(const double* data_col, int dataOffset, const int channels,
+void col2im_cpu(const double* data_col, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
     double* data_im) {
 
   memset(data_im, 0, height * width * channels * sizeof(double));
-  data_col += dataOffset;
 
   const int channel_size = height * width;
···
 
 inline double max ( double a, double b ) { return a > b ? a : b; }
 
-void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels,
+void pool_forwards_cpu(const double* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
     double* data_pooled) {
 
-  data_im += dataOffset;
-
   const int channel_size = height * width;
 
   for (int channel = 0; channel < channels; channel++) {
···
   }
 }
 
-void pool_backwards_cpu(const double* data_im, int data_im_offset,
-    const double* data_pooled, int data_pooled_offset,
+void pool_backwards_cpu(const double* data_im, const double* data_pooled,
     const int channels, const int height, const int width, const int kernel_h,
     const int kernel_w, const int stride_h, const int stride_w,
     double* data_backgrad ) {
 
-  data_im += data_im_offset;
-  data_pooled += data_pooled_offset;
   memset(data_backgrad, 0, height * width * channels * sizeof(double));
 
   const int channel_size = height * width;
```
+4 -5
cbits/im2col.h
```diff
···
 #include <stdint.h>
 #include <string.h>
 
-void im2col_cpu(const double* data_im, int dataOffset, const int channels,
+void im2col_cpu(const double* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
     double* data_col);
 
-void col2im_cpu(const double* data_col, int dataOffset, const int channels,
+void col2im_cpu(const double* data_col, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
     double* data_im);
 
-void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels,
+void pool_forwards_cpu(const double* data_im, const int channels,
     const int height, const int width, const int kernel_h, const int kernel_w,
     const int stride_h, const int stride_w,
     double* data_pooled);
 
-void pool_backwards_cpu(const double* data_im, int data_im_offset,
-    const double* data_pooled, int data_pooled_offset,
+void pool_backwards_cpu(const double* data_im, const double* data_pooled,
     const int channels, const int height, const int width, const int kernel_h,
     const int kernel_w, const int stride_h, const int stride_w,
     double* data_backgrad );
```
+71
grenade.cabal
```diff
···
      base >= 4.8 && < 5
    , bytestring == 0.10.*
    , async
+   , containers
+   , deepseq
    , either == 4.4.*
    , exceptions == 0.8.*
    , hmatrix
···
    , mtl >= 2.2.1 && < 2.3
    , parallel == 3.2.*
    , primitive
+   , reflection
    , text == 1.2.*
    , transformers
    , singletons
+   , vector
 
  ghc-options:
    -Wall
···
 
    Grenade.Layers.Internal.Convolution
    Grenade.Layers.Internal.Pooling
+   Grenade.Layers.Internal.Update
+
+   Grenade.Recurrent
+
+   Grenade.Recurrent.Core.Network
+   Grenade.Recurrent.Core.Runner
+
+   Grenade.Recurrent.Layers.BasicRecurrent
+   Grenade.Recurrent.Layers.LSTM
+   Grenade.Recurrent.Layers.Trivial
+
+   Grenade.Utils.OneHot
 
  includes:  cbits/im2col.h
+           cbits/gradient_decent.h
  c-sources: cbits/im2col.c
+           cbits/gradient_decent.c
 
  cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
···
    , transformers
    , singletons
    , MonadRandom
+   , vector
+
+executable recurrent
+  ghc-options: -Wall -threaded -O2
+  main-is: main/recurrent.hs
+  build-depends: base
+    , grenade
+    , attoparsec
+    , either
+    , optparse-applicative == 0.12.*
+    , text == 1.2.*
+    , mtl >= 2.2.1 && < 2.3
+    , hmatrix >= 0.18 && < 0.19
+    , transformers
+    , singletons
+    , MonadRandom
+
+
+executable shakespeare
+  ghc-options: -Wall -threaded -O2
+  main-is: main/shakespeare.hs
+  build-depends: base
+    , grenade
+    , attoparsec
+    , either
+    , optparse-applicative == 0.12.*
+    , text == 1.2.*
+    , mtl >= 2.2.1 && < 2.3
+    , hmatrix >= 0.18 && < 0.19
+    , transformers
+    , singletons
+    , vector
+    , MonadRandom
+    , containers
 
 
 test-suite test
···
    , quickcheck-instances == 0.3.*
    , MonadRandom
    , random
+   , ad
+   , reflection
 
 
 benchmark bench
···
    , criterion == 1.1.*
    , grenade
    , hmatrix
+
+benchmark bench-lstm
+  type: exitcode-stdio-1.0
+
+  main-is: bench-lstm.hs
+
+  ghc-options: -Wall -threaded -O2
+
+  hs-source-dirs:
+    bench
+
+  build-depends:
+      base >= 3 && < 5
+    , bytestring == 0.10.*
+    , criterion == 1.1.*
+    , grenade
+    , hmatrix
```
+4 -2
mafia
```diff
 #!/bin/sh -eu
 
+: ${MAFIA_HOME:=$HOME/.mafia}
+
 fetch_latest () {
   if [ -z ${MAFIA_TEST_MODE+x} ]; then
     TZ=$(date +"%T")
···
   # If we can't find the mafia version, then we need to upgrade the script.
   run_upgrade
 else
-  MAFIA_BIN=$HOME/.ambiata/mafia/bin
+  MAFIA_BIN=$MAFIA_HOME/bin
   MAFIA_FILE=mafia-$MAFIA_VERSION
   MAFIA_PATH=$MAFIA_BIN/$MAFIA_FILE
···
   upgrade) shift; run_upgrade "$@" ;;
   *) exec_mafia "$@"
 esac
-# Version: a1b39ee8ac1969ed2e891b9062d079be75863e99
+# Version: 3044e63eb472fb9e16926d4ab2ca9dd9e255829c
```
+10 -10
main/feedforward.hs
```diff
···
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TupleSections #-}
 {-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE FlexibleContexts #-}
-
 import Control.Monad
 import Control.Monad.Random
+import Data.List ( foldl' )
+
 import GHC.TypeLits
 
 import qualified Numeric.LinearAlgebra.Static as SA
···
 netTest rate n = do
     inps <- replicateM n $ do
       s <- getRandom
-      return $ S1D' $ SA.randomVector s SA.Uniform * 2 - 1
-    let outs = flip map inps $ \(S1D' v) ->
+      return $ S1D $ SA.randomVector s SA.Uniform * 2 - 1
+    let outs = flip map inps $ \(S1D v) ->
                  if v `inCircle` (fromRational 0.33, 0.33) || v `inCircle` (fromRational (-0.33), 0.33)
-                   then S1D' $ fromRational 1
-                   else S1D' $ fromRational 0
+                   then S1D $ fromRational 1
+                   else S1D $ fromRational 0
     net0 <- randomNet
 
-    let trained = foldl trainEach net0 (zip inps outs)
+    let trained = foldl' trainEach net0 (zip inps outs)
     let testIns = [ [ (x,y) | x <- [0..50] ]
                             | y <- [0..20] ]
 
-    let outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet trained (S1D' $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
+    let outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet trained (S1D $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
     return $ unlines outMat
 
   where
···
       | n' <= 0.8 = '='
       | otherwise = '#'
 
-    normx :: S' ('D1 1) -> Double
-    normx (S1D' r) = SA.mean r
+    normx :: S ('D1 1) -> Double
+    normx (S1D r) = SA.mean r
 
 data FeedForwardOpts = FeedForwardOpts Int LearningParameters
```
+13 -12
main/mnist.hs
```diff
···
 {-# LANGUAGE TupleSections #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE FlexibleContexts #-}
-
 import Control.Applicative
 import Control.Monad
 import Control.Monad.Random
-import Control.Monad.Trans.Class
 import Control.Monad.Trans.Except
 
 import qualified Data.Attoparsec.Text as A
+import Data.List ( foldl' )
 import qualified Data.Text as T
 import qualified Data.Text.IO as T
+import qualified Data.Vector.Storable as V
 
-import Numeric.LinearAlgebra (maxIndex)
+import Numeric.LinearAlgebra ( maxIndex )
 import qualified Numeric.LinearAlgebra.Static as SA
 
 import Options.Applicative
 
 import Grenade
+import Grenade.Utils.OneHot
 
 -- The definition of our convolutional neural network.
 -- In the type signature, we have a type level list of shapes which are passed between the layers.
···
     trainEach rate' !network (i, o) = train rate' network i o
 
     runIteration trainRows validateRows net i = do
-      let trained' = foldl (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows
+      let trained' = foldl' (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows
       let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows
-      let res' = fmap (\(S1D' label, S1D' prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
+      let res' = fmap (\(S1D label, S1D prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
       print trained'
       putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res')
       return trained'
···
 mnist' :: Parser MnistOpts
 mnist' = MnistOpts <$> argument str (metavar "TRAIN")
                    <*> argument str (metavar "VALIDATE")
-                   <*> option auto (long "iterations" <> short 'i' <> value 10)
+                   <*> option auto (long "iterations" <> short 'i' <> value 15)
                    <*> (LearningParameters
                        <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
                        <*> option auto (long "momentum" <> value 0.9)
···
     Right () -> pure ()
     Left err -> putStrLn err
 
-readMNIST :: FilePath -> ExceptT String IO [(S' ('D2 28 28), S' ('D1 10))]
+readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))]
 readMNIST mnist = ExceptT $ do
   mnistdata <- T.readFile mnist
   return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata)
 
-parseMNIST :: A.Parser (S' ('D2 28 28), S' ('D1 10))
+parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10))
 parseMNIST = do
-  lab    <- A.decimal
-  pixels <- many (A.char ',' >> A.double)
-  let lab' = replicate lab 0 ++ [1] ++ replicate (9 - lab) 0
-  return (S2D' $ SA.fromList pixels, S1D' $ SA.fromList lab')
+  Just lab <- oneHot <$> A.decimal
+  pixels   <- many (A.char ',' >> A.double)
+  image    <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels)
+  return (image, lab)
```
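The parser now leans on two safe conversions instead of hand-rolled lists. A rough sketch of the shapes involved; `fromStorable`'s signature appears in the Shape.hs diff below, while `oneHot`'s is assumed from its use here:

```haskell
{-# LANGUAGE DataKinds #-}

import qualified Data.Vector.Storable as V
import Grenade
import Grenade.Utils.OneHot ( oneHot )

-- A 2x2 "image" built the same way parseMNIST now builds its 28x28 one;
-- Nothing when the vector length doesn't match the shape.
tinyImage :: Maybe (S ('D2 2 2))
tinyImage = fromStorable (V.fromList [0.1, 0.2, 0.3, 0.4])

-- A label in [0..9] as a one-hot vector; oneHot is assumed to return
-- Nothing for an out-of-range index, as its use in parseMNIST implies.
tinyLabel :: Maybe (S ('D1 10))
tinyLabel = oneHot 3
```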
+87
main/recurrent.hs
```diff
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeFamilies #-}
+
+import Control.Monad ( foldM )
+import Control.Monad.Random ( MonadRandom, getRandomR )
+
+import Data.List ( cycle, unfoldr )
+import qualified Numeric.LinearAlgebra.Static as SA
+
+import Options.Applicative
+
+import Grenade
+import Grenade.Recurrent
+
+-- The definition of our simple recurrent network.
+-- This file just trains a network to generate a repeating sequence
+-- of 0 0 1.
+--
+-- The F and R types are tagging types to ensure that the runner and
+-- creation function know how to treat the layers.
+type F = FeedForward
+type R = Recurrent
+
+type RecNet = RecurrentNetwork '[ R (LSTM 1 4), R (LSTM 4 1), F Trivial]
+                               '[ 'D1 1, 'D1 4, 'D1 1, 'D1 1 ]
+
+type RecInput = RecurrentInputs '[ R (LSTM 1 4), R (LSTM 4 1), F Trivial]
+
+randomNet :: MonadRandom m => m (RecNet, RecInput)
+randomNet = randomRecurrent
+
+netTest :: MonadRandom m => RecNet -> RecInput -> LearningParameters -> Int -> m (RecNet, RecInput)
+netTest net0 i0 rate iterations =
+  foldM trainIteration (net0,i0) [1..iterations]
+    where
+  trainingCycle = cycle [c 0, c 0, c 1]
+
+  trainIteration (net, io) _ = do
+    dropping <- getRandomR (0, 2)
+    count    <- getRandomR (5, 30)
+    let t = drop dropping trainingCycle
+    let example = ((,Nothing) <$> take count t) ++ [(t !! count, Just $ t !! (count + 1))]
+    return $ trainEach net io example
+
+  trainEach !nt !io !ex = trainRecurrent rate nt io ex
+
+data FeedForwardOpts = FeedForwardOpts Int LearningParameters
+
+feedForward' :: Parser FeedForwardOpts
+feedForward' = FeedForwardOpts <$> option auto (long "examples" <> short 'e' <> value 20000)
+                               <*> (LearningParameters
+                                   <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
+                                   <*> option auto (long "momentum" <> value 0.9)
+                                   <*> option auto (long "l2" <> value 0.0005)
+                                   )
+
+generateRecurrent :: RecNet -> RecInput -> S ('D1 1) -> [Int]
+generateRecurrent n s i =
+  unfoldr go (s, i)
+    where
+  go (x, y) =
+    do let (ns, o) = runRecurrent n x y
+           o'      = heat o
+       Just (o', (ns, fromIntegral o'))
+
+  heat :: S ('D1 1) -> Int
+  heat x = case x of
+    (S1D v) -> round (SA.mean v)
+
+main :: IO ()
+main = do
+  FeedForwardOpts examples rate <- execParser (info (feedForward' <**> helper) idm)
+  putStrLn "Training network..."
+
+  (net0, i0) <- randomNet
+  (trained, bestInput) <- netTest net0 i0 rate examples
+
+  let results = generateRecurrent trained bestInput (c 1)
+
+  print . take 50 . drop 100 $ results
+
+c :: Double -> S ('D1 1)
+c = S1D . SA.konst
```
+156
main/shakespeare.hs
```diff
+{-# LANGUAGE BangPatterns #-}
+{-# LANGUAGE RecordWildCards #-}
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TupleSections #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE LambdaCase #-}
+
+import Control.Monad.Random
+import Control.Monad.Trans.Except
+
+import Data.Char ( isUpper, toUpper, toLower )
+import Data.List ( unfoldr, foldl' )
+import Data.Maybe ( fromMaybe )
+
+import qualified Data.Vector as V
+import Data.Vector ( Vector )
+
+import qualified Data.Map as M
+import Data.Proxy ( Proxy (..) )
+
+import Data.Singletons.Prelude
+import GHC.TypeLits
+
+import Numeric.LinearAlgebra.Static ( konst )
+
+import Options.Applicative
+
+import Grenade
+import Grenade.Recurrent
+import Grenade.Utils.OneHot
+
+-- The definition of our natural language recurrent network.
+-- This network is able to learn and generate simple words in
+-- about an hour.
+--
+-- This is a first class recurrent net, although it's similar to
+-- an unrolled graph.
+--
+-- The F and R types are tagging types to ensure that the runner and
+-- creation function know how to treat the layers.
+--
+-- As an example, here's a short sequence generated.
+--
+-- > the see and and the sir, and and the make and the make and go the make and go the make and the
+--
+type F = FeedForward
+type R = Recurrent
+
+-- The definition of our network
+type Shakespeare = RecurrentNetwork '[ R (LSTM 40 40), F (FullyConnected 40 40), F Logit]
+                                    '[ 'D1 40, 'D1 40, 'D1 40, 'D1 40 ]
+
+-- The definition of the "sideways" input, which the network is fed recurrently.
+type Shakespearian = RecurrentInputs '[ R (LSTM 40 40), F (FullyConnected 40 40), F Logit]
+
+randomNet :: MonadRandom m => m (Shakespeare, Shakespearian)
+randomNet = randomRecurrent
+
+-- | Load the data files and prepare a map of characters to a compressed int representation.
+loadShakespeare :: FilePath -> ExceptT String IO (Vector Int, M.Map Char Int, Vector Char)
+loadShakespeare path = do
+  contents <- lift $ readFile path
+  let annotated = annotateCapitals contents
+  (m,cs) <- ExceptT . return . note "Couldn't fit data in hotMap" $ hotMap (Proxy :: Proxy 40) annotated
+  hot    <- ExceptT . return . note "Couldn't generate hot values" $ traverse (`M.lookup` m) annotated
+  return (V.fromList hot, m, cs)
+
+trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian)
+trainSlice !rate !net !recIns input offset size =
+  let e = fmap (x . oneHot) . V.toList $ V.slice offset size input
+  in case reverse e of
+    (o : l : xs) ->
+      let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs)
+      in  trainRecurrent rate net recIns examples
+    _ -> error "Not enough input"
+    where
+  x = fromMaybe (error "Hot variable didn't fit.")
+
+runShakespeare :: ShakespeareOpts -> ExceptT String IO ()
+runShakespeare ShakespeareOpts {..} = do
+  (shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile
+  (net0, i0) <- lift randomNet
+  lift $ foldM_ (\(!net, !io) size -> do
+    xs <- take (iterations `div` 15) <$> getRandomRs (0, length shakespeare - size - 1)
+    let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice rate n i shakespeare offset size) (net, io) xs
+    let results = take 100 $ generateParagraph trained bestInput oneHotMap oneHotDictionary ( S1D $ konst 0)
+    putStrLn ("TRAINING STEP WITH SIZE: " ++ show size)
+    putStrLn (unAnnotateCapitals results)
+    return (trained, bestInput)
+    ) (net0, i0) [10,10,15,15,20,20,25,25,30,30,35,35,40,40,50 :: Int]
+
+generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes ~ 'D1 n, KnownNat n, Ord a)
+                  => RecurrentNetwork layers shapes
+                  -> RecurrentInputs layers
+                  -> M.Map a Int
+                  -> Vector a
+                  -> S ('D1 n)
+                  -> [a]
+generateParagraph n s hotmap hotdict i =
+  unfoldr go (s, i)
+    where
+  go (x, y) =
+    do let (ns, o) = runRecurrent n x y
+       un <- unHot hotdict o
+       re <- makeHot hotmap un
+       Just (un, (ns, re))
+
+data ShakespeareOpts = ShakespeareOpts {
+    trainingFile :: FilePath
+  , iterations   :: Int
+  , rate         :: LearningParameters
+  }
+
+shakespeare' :: Parser ShakespeareOpts
+shakespeare' = ShakespeareOpts <$> argument str (metavar "TRAIN")
+                               <*> option auto (long "examples" <> short 'e' <> value 1000000)
+                               <*> (LearningParameters
+                                   <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
+                                   <*> option auto (long "momentum" <> value 0.95)
+                                   <*> option auto (long "l2" <> value 0.000001)
+                                   )
+
+main :: IO ()
+main = do
+  shopts <- execParser (info (shakespeare' <**> helper) idm)
+  res <- runExceptT $ runShakespeare shopts
+  case res of
+    Right () -> pure ()
+    Left err -> putStrLn err
+
+
+-- Replace capitals with an annotation and the lower case letter
+-- http://fastml.com/one-weird-trick-for-training-char-rnns/
+annotateCapitals :: String -> String
+annotateCapitals (x : rest)
+  | isUpper x
+  = '^' : toLower x : annotateCapitals rest
+  | otherwise
+  = x : annotateCapitals rest
+annotateCapitals []
+  = []
+
+unAnnotateCapitals :: String -> String
+unAnnotateCapitals ('^' : x : rest)
+  = toUpper x : unAnnotateCapitals rest
+unAnnotateCapitals (x : rest)
+  = x : unAnnotateCapitals rest
+unAnnotateCapitals []
+  = []
+
+-- | Tag the 'Nothing' value of a 'Maybe'
+note :: a -> Maybe b -> Either a b
+note a = maybe (Left a) Right
```
+30 -15
src/Grenade/Core/Network.hs
```diff
···
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
-{-# LANGUAGE KindSignatures #-}
-{-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE PolyKinds #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE LambdaCase #-}
+{-|
+Module      : Grenade.Core.Network
+Description : Core definition of a simple neural network
+Copyright   : (c) Huw Campbell, 2016-2017
+License     : BSD2
+Stability   : experimental
+
+This module defines the core data type for the simplest
+neural network we support.
 
+-}
 module Grenade.Core.Network (
     Layer (..)
   , Network (..)
···
   ) where
 
 import Control.Monad.Random (MonadRandom)
-
+import Data.List ( foldl' )
+import Data.Singletons
 
 import Grenade.Core.Shape
 
+-- | Learning parameters for stochastic gradient descent.
 data LearningParameters = LearningParameters {
     learningRate :: Double
   , learningMomentum :: Double
···
 -- | Class for updating a layer. All layers implement this, and it is
 --   shape independent.
 class Show x => UpdateLayer x where
+  {-# MINIMAL runUpdate, createRandom #-}
   -- | The type for the gradient for this layer.
   --   Unit if there isn't a gradient to pass back.
   type Gradient x :: *
   -- | Update a layer with its gradient and learning parameters
   runUpdate       :: LearningParameters -> x -> Gradient x -> x
+
   -- | Create a random layer, many layers will use pure
   createRandom    :: MonadRandom m => m x
 
+  -- | Update a layer with many Gradients
+  runUpdates      :: LearningParameters -> x -> [Gradient x] -> x
+  runUpdates rate = foldl' (runUpdate rate)
+
 -- | Class for a layer. All layers implement this, however, they don't
 --   need to implement it for all shapes, only ones which are appropriate.
 class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
   -- | Used in training and scoring. Take the input from the previous
   --   layer, and give the output from this layer.
-  runForwards  :: x -> S' i -> S' o
+  runForwards  :: x -> S i -> S o
   -- | Back propagate a step. Takes the current layer, the input that the
   --   layer gave from the input and the back propagated derivatives from
   --   the layer above.
   --   Returns the gradient layer and the derivatives to push back further.
-  runBackwards :: x -> S' i -> S' o -> (Gradient x, S' i)
+  runBackwards :: x -> S i -> S o -> (Gradient x, S i)
 
 -- | Type of a network.
-- The [*] type specifies the types of the layers. This is needed for parallel
-- running and being all the gradients beck together.
+--
+--   The [*] type specifies the types of the layers.
+--
 --   The [Shape] type specifies the shapes of data passed between the layers.
+--
-- Could be considered to be a heterogeneous list of layers which are able to
+--   Can be considered to be a heterogeneous list of layers which are able to
 --   transform the data shapes of the network.
 data Network :: [*] -> [Shape] -> * where
-    O     :: Layer x i o => !x -> Network '[x] '[i, o]
-    (:~>) :: Layer x i h => !x -> !(Network xs (h ': hs)) -> Network (x ': xs) (i ': h ': hs)
+    O     :: (SingI i, SingI o, Layer x i o) => !x -> Network '[x] '[i, o]
+    (:~>) :: (SingI i, SingI h, Layer x i h) => !x -> !(Network xs (h ': hs)) -> Network (x ': xs) (i ': h ': hs)
 infixr 5 :~>
 
 instance Show (Network l h) where
···
     OG    :: UpdateLayer x => Gradient x -> Gradients '[x]
     (:/>) :: UpdateLayer x => Gradient x -> Gradients xs -> Gradients (x ': xs)
 
-
 -- | A network can easily be created by hand with (:~>), but an easy way to initialise a random
 --   network is with the randomNetwork.
 class CreatableNetwork (xs :: [*]) (ss :: [Shape]) where
   -- | Create a network of the types requested
   randomNetwork :: MonadRandom m => m (Network xs ss)
 
-instance Layer x i o => CreatableNetwork (x ': '[]) (i ': o ': '[]) where
+instance (SingI i, SingI o, Layer x i o) => CreatableNetwork (x ': '[]) (i ': o ': '[]) where
   randomNetwork = O <$> createRandom
 
-instance (Layer x i o, CreatableNetwork xs (o ': r ': rs)) => CreatableNetwork (x ': xs) (i ': o ': r ': rs) where
+instance (SingI i, SingI o, Layer x i o, CreatableNetwork xs (o ': r ': rs)) => CreatableNetwork (x ': xs) (i ': o ': r ': rs) where
  randomNetwork = (:~>) <$> createRandom <*> randomNetwork
```
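The new `runUpdates` default is handy when gradients have been accumulated over a batch. A small sketch against one example layer (the sizes are arbitrary):

```haskell
{-# LANGUAGE DataKinds #-}

import Grenade

-- Fold a batch of gradients into a single layer using the class default,
-- which is just a strict left fold of runUpdate.
applyBatch :: LearningParameters
           -> FullyConnected 10 5
           -> [Gradient (FullyConnected 10 5)]
           -> FullyConnected 10 5
applyBatch = runUpdates
```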
+30 -20
src/Grenade/Core/Runner.hs
```diff
···
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
+{-|
+Module      : Grenade.Core.Runner
+Description : Back propagation and training functions for a network
+Copyright   : (c) Huw Campbell, 2016-2017
+License     : BSD2
+Stability   : experimental
+
+This module defines simple back propagation and training functions
+for a network.
+-}
 module Grenade.Core.Runner (
     train
   , backPropagate
···
 import Grenade.Core.Network
 import Grenade.Core.Shape
 
--- | Drive and network and collect its back propogated gradients.
-backPropagate :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
-              => Network layers shapes -> S' input -> S' output -> Gradients layers
+-- | Perform reverse automatic differentiation on the network
+--   for the current input and expected output.
+--
+--   /Note:/ The loss function pushed backwards is appropriate
+--   for both regression and classification as a squared loss
+--   or log-loss respectively. Other loss functions are not yet
+--   implemented.
+backPropagate :: forall shapes layers.
+                 Network layers shapes -> S (Head shapes) -> S (Last shapes) -> Gradients layers
 backPropagate network input target =
     fst $ go input network
   where
-    go  :: forall j js sublayers. (Head js ~ j, Last js ~ output)
-        => S' j                 -- ^ input vector
+    go  :: forall js sublayers. (Last js ~ Last shapes)
+        => S (Head js)          -- ^ input vector
         -> Network sublayers js -- ^ network to train
-        -> (Gradients sublayers, S' j)
+        -> (Gradients sublayers, S (Head js))
     -- handle input from the beginning, feeding upwards.
     go !x (layer :~> n)
         = let y = runForwards layer x
···
 
           in (OG layer', dWs)
 
--- | Update a network with new weights after training with an instance.
-train :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
-      => LearningParameters    -- ^ learning rate
-      -> Network layers shapes -- ^ network to train
-      -> S' input -> S' output -- ^ target vector
-      -> Network layers shapes
-train rate network input output =
-    let grads = backPropagate network input output
-    in  applyUpdate rate network grads
-
+-- | Apply one step of stochastic gradient descent across the network.
 applyUpdate :: LearningParameters -> Network ls ss -> Gradients ls -> Network ls ss
 applyUpdate rate (O layer) (OG gradient)
   = O (runUpdate rate layer gradient)
···
 applyUpdate _ _ _
   = error "Impossible for the gradients of a network to have a different length to the network"
 
--- | Just forwards propagation with no training.
-runNet :: Network layers hs
-       -> S' (Head hs)         -- ^ input vector
-       -> S' (Last hs)         -- ^ target vector
+-- | Update a network with new weights after training with an instance.
+train :: LearningParameters -> Network layers shapes -> S (Head shapes) -> S (Last shapes) -> Network layers shapes
+train rate network input output =
+    let grads = backPropagate network input output
+    in  applyUpdate rate network grads
+
+-- | Run the network with input and return its output.
+runNet :: Network layers shapes -> S (Head shapes) -> S (Last shapes)
 runNet (layer :~> n) !x = let y = runForwards layer x in runNet n y
 runNet (O layer) !x     = runForwards layer x
```
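With `Head` and `Last` now inlined into the signatures, a whole-epoch fold writes generically over any network. A sketch, assuming the `Head`/`Last` type families come from singletons as they do in these modules:

```haskell
import Data.List ( foldl' )
import Data.Singletons.Prelude.List ( Head, Last )
import Grenade

-- Strictly fold train across a list of (input, target) examples,
-- just as the example mains in this change now do.
epoch :: LearningParameters
      -> Network layers shapes
      -> [(S (Head shapes), S (Last shapes))]
      -> Network layers shapes
epoch params = foldl' (\net (x, y) -> train params net x y)
```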
+140 -44
src/Grenade/Core/Shape.hs
```diff
···
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE KindSignatures #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE StandaloneDeriving #-}
 {-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE RankNTypes #-}
 
-- Ghc 8.0 gives a warning on `(+) _ _ = error ...` but ghc 7.10 fails to
+-- Ghc 8.0 gives a warning on `n2 _ _ = error ...` but ghc 7.10 fails to
 -- compile without this default pattern.
 {-# OPTIONS_GHC -fno-warn-overlapping-patterns #-}
 
+{-|
+Module      : Grenade.Core.Shape
+Description : Core definition of the Shapes of data we understand
+Copyright   : (c) Huw Campbell, 2016-2017
+License     : BSD2
+Stability   : experimental
+
+This module defines the core data types for the shapes of data that
+are understood by Grenade.
+-}
 module Grenade.Core.Shape (
     Shape (..)
-  , S' (..)
+  , S (..)
+  , randomOfShape
+  , fromStorable
   ) where
 
+import Control.DeepSeq (NFData (..))
+import Control.Monad.Random ( MonadRandom, getRandom )
+
+import Data.Singletons
 import Data.Singletons.TypeLits
+import Data.Vector.Storable ( Vector )
+import qualified Data.Vector.Storable as V
+
 import GHC.TypeLits
 
+import qualified Numeric.LinearAlgebra.Static as H
 import Numeric.LinearAlgebra.Static
-
+import qualified Numeric.LinearAlgebra as NLA
 
 -- | The current shapes we accept.
 --   at the moment this is just one, two, and three dimensional
 --   Vectors/Matrices.
-data Shape =
-    D1 Nat
+data Shape
+  = D1 Nat
   | D2 Nat Nat
   | D3 Nat Nat Nat
 
-instance Num (S' x) where
-  (+) (S1D' x) (S1D' y) = S1D' (x + y)
-  (+) (S2D' x) (S2D' y) = S2D' (x + y)
-  (+) (S3D' x) (S3D' y) = S3D' (x + y)
-  (+) _ _ = error "Impossible to have different constructors for the same shaped network"
-
-  (-) (S1D' x) (S1D' y) = S1D' (x - y)
-  (-) (S2D' x) (S2D' y) = S2D' (x - y)
-  (-) (S3D' x) (S3D' y) = S3D' (x - y)
-  (-) _ _ = error "Impossible to have different constructors for the same shaped network"
-
-  (*) (S1D' x) (S1D' y) = S1D' (x * y)
-  (*) (S2D' x) (S2D' y) = S2D' (x * y)
-  (*) (S3D' x) (S3D' y) = S3D' (x * y)
-  (*) _ _ = error "Impossible to have different constructors for the same shaped network"
-
-  abs (S1D' x) = S1D' (abs x)
-  abs (S2D' x) = S2D' (abs x)
-  abs (S3D' x) = S3D' (abs x)
-
-  signum (S1D' x) = S1D' (signum x)
-  signum (S2D' x) = S2D' (signum x)
-  signum (S3D' x) = S3D' (signum x)
-
-  fromInteger _ = error "Unimplemented: fromInteger on Shape"
-
--- | Given a Shape n, these are the possible data structures with that shape.
---   All shapes are held in contiguous memory.
---   3D is held in a matrix (usually row oriented) which has height depth * rows.
-data S' (n :: Shape) where
-  S1D' :: ( KnownNat o )                      => R o -> S' ('D1 o)
-  S2D' :: ( KnownNat rows, KnownNat columns ) => L rows columns -> S' ('D2 rows columns)
-  S3D' :: ( KnownNat rows
-          , KnownNat columns
-          , KnownNat depth
-          , KnownNat (rows * depth)) => L (rows * depth) columns -> S' ('D3 rows columns depth)
-
-instance Show (S' n) where
-  show (S1D' a) = "S1D' " ++ show a
-  show (S2D' a) = "S2D' " ++ show a
-  show (S3D' a) = "S3D' " ++ show a
+-- | Given a Shape n, these are the possible data structures with that shape.
+--   All shapes are held in contiguous memory.
+--   3D is held in a matrix (usually row oriented) which has height depth * rows.
+data S (n :: Shape) where
+  S1D :: ( KnownNat o )                      => R o -> S ('D1 o)
+  S2D :: ( KnownNat rows, KnownNat columns ) => L rows columns -> S ('D2 rows columns)
+  S3D :: ( KnownNat rows
+         , KnownNat columns
+         , KnownNat depth
+         , KnownNat (rows * depth)) => L (rows * depth) columns -> S ('D3 rows columns depth)
+
+deriving instance Show (S n)
+
+instance SingI x => Num (S x) where
+  (+) = n2 (+)
+  (-) = n2 (-)
+  (*) = n2 (*)
+  abs = n1 abs
+  signum = n1 signum
+  fromInteger x = case (sing :: Sing x) of
+    D1Sing -> S1D (konst $ fromInteger x)
+    D2Sing -> S2D (konst $ fromInteger x)
+    D3Sing -> S3D (konst $ fromInteger x)
+
+instance SingI x => Fractional (S x) where
+  (/) = n2 (/)
+  recip = n1 recip
+  fromRational x = case (sing :: Sing x) of
+    D1Sing -> S1D (konst $ fromRational x)
+    D2Sing -> S2D (konst $ fromRational x)
+    D3Sing -> S3D (konst $ fromRational x)
+
+instance SingI x => Floating (S x) where
+  pi = case (sing :: Sing x) of
+    D1Sing -> S1D (konst pi)
+    D2Sing -> S2D (konst pi)
+    D3Sing -> S3D (konst pi)
+  exp = n1 exp
+  log = n1 log
+  sqrt = n1 sqrt
+  (**) = n2 (**)
+  logBase = n2 logBase
+  sin = n1 sin
+  cos = n1 cos
+  tan = n1 tan
+  asin = n1 asin
+  acos = n1 acos
+  atan = n1 atan
+  sinh = n1 sinh
+  cosh = n1 cosh
+  tanh = n1 tanh
+  asinh = n1 asinh
+  acosh = n1 acosh
+  atanh = n1 atanh
+
+-- Singletons
+-- These could probably be derived with template haskell, but this seems
+-- clear and makes adding the KnownNat constraints simple.
+data instance Sing (n :: Shape) where
+  D1Sing :: KnownNat a => Sing ('D1 a)
+  D2Sing :: (KnownNat a, KnownNat b) => Sing ('D2 a b)
+  D3Sing :: (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => Sing ('D3 a b c)
+
+instance KnownNat a => SingI ('D1 a) where
+  sing = D1Sing
+instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where
+  sing = D2Sing
+instance (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => SingI ('D3 a b c) where
+  sing = D3Sing
+
+--
+-- I haven't made shapes strict, as sometimes they're not needed
+-- (the last input gradient back for instance)
+--
+instance NFData (S x) where
+  rnf (S1D x) = rnf x
+  rnf (S2D x) = rnf x
+  rnf (S3D x) = rnf x
+
+-- | Generate random data of the desired shape
+randomOfShape :: forall x m. ( MonadRandom m, SingI x ) => m (S x)
+randomOfShape = do
+  seed :: Int <- getRandom
+  return $ case (sing :: Sing x) of
+    D1Sing -> S1D (randomVector seed Uniform * 2 - 1)
+    D2Sing -> S2D (uniformSample seed (-1) 1)
+    D3Sing -> S3D (uniformSample seed (-1) 1)
+
+-- | Generate a shape from a Storable Vector.
+--
+--   Returns Nothing if the vector is of the wrong size.
+fromStorable :: forall x. SingI x => Vector Double -> Maybe (S x)
+fromStorable xs = case sing :: Sing x of
+  D1Sing -> S1D <$> H.create xs
+  D2Sing -> S2D <$> mkL xs
+  D3Sing -> S3D <$> mkL xs
+  where
+    mkL :: forall rows columns. (KnownNat rows, KnownNat columns)
+        => Vector Double -> Maybe (L rows columns)
+    mkL v =
+      let rows    = fromIntegral $ natVal (Proxy :: Proxy rows)
+          columns = fromIntegral $ natVal (Proxy :: Proxy columns)
+      in  if rows * columns == V.length v
+            then H.create $ NLA.reshape columns v
+            else Nothing
+
+-- Helper function for creating the number instances
+n1 :: ( forall a. Floating a => a -> a ) -> S x -> S x
+n1 f (S1D x) = S1D (f x)
+n1 f (S2D x) = S2D (f x)
+n1 f (S3D x) = S3D (f x)
+
+-- Helper function for creating the number instances
+n2 :: ( forall a. Floating a => a -> a -> a ) -> S x -> S x -> S x
+n2 f (S1D x) (S1D y) = S1D (f x y)
+n2 f (S2D x) (S2D y) = S2D (f x y)
+n2 f (S3D x) (S3D y) = S3D (f x y)
+n2 _ _ _ = error "Impossible to have different constructors for the same shaped network"
```
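A small sketch of what the reworked shape module buys: values can be built from storable vectors, and the `SingI`-driven instances make numeric literals work at any shape. The names below are illustrative only:

```haskell
{-# LANGUAGE DataKinds #-}

import qualified Data.Vector.Storable as V
import Grenade.Core.Shape

-- Reshape six doubles into a static 2x3 matrix; Nothing on a size mismatch.
m :: Maybe (S ('D2 2 3))
m = fromStorable (V.fromList [1,2,3,4,5,6])

-- fromRational now dispatches on the shape singleton, so literals
-- work at any shape: subtracting a constant is a one-liner.
centre :: S ('D2 2 3) -> S ('D2 2 3)
centre x = x - 0.5
```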
+37
src/Grenade/Graph/GraphNetwork.hs
```diff
+{-# LANGUAGE DataKinds #-}
+{-# LANGUAGE GADTs #-}
+{-# LANGUAGE KindSignatures #-}
+{-# LANGUAGE ScopedTypeVariables #-}
+{-# LANGUAGE TypeOperators #-}
+{-# LANGUAGE TypeFamilies #-}
+{-# LANGUAGE PolyKinds #-}
+{-# LANGUAGE MultiParamTypeClasses #-}
+{-# LANGUAGE FlexibleContexts #-}
+{-# LANGUAGE FlexibleInstances #-}
+{-# LANGUAGE LambdaCase #-}
+
+module Grenade.Graph.Network (
+    Layer (..)
+  , UpdateLayer (..)
+  ) where
+
+import Control.Monad.Random (MonadRandom)
+import Data.Singletons
+import Data.Singletons.Prelude
+
+import GHC.TypeLits
+
+import Grenade.Core.Shape
+import Grenade.Core.Network ( UpdateLayer (..), Layer (..) )
+
+-- | Type of a DAG network
+
+data Fin :: Nat -> * where
+  Fin0 :: Fin (n + 1)
+  FinS :: Fin n -> Fin (n + 1)
+
+data Edge :: Nat -> * where
+  Edge :: Shape -> Fin n -> Edge n
+
+data Node a n where
+  Node :: a -> [Edge n] -> Node a n
```
+26 -32
src/Grenade/Layers/Convolution.hs
```diff
···
-{-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
 {-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE GADTs #-}
 {-# LANGUAGE TypeOperators #-}
···
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleInstances #-}
 {-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE PolyKinds #-}
-{-# LANGUAGE PatternGuards #-}
-
 module Grenade.Layers.Convolution (
     Convolution (..)
   , Convolution' (..)
···
 import Grenade.Core.Network
 import Grenade.Core.Shape
 import Grenade.Layers.Internal.Convolution
+import Grenade.Layers.Internal.Update
 
 -- | A convolution layer for a neural network.
 --   This uses the im2col convolution trick popularised by Caffe, which essentially turns the
···
 --   `out = (in - kernel) / stride + 1` for both dimensions.
 --
 --   One probably shouldn't build their own layer, but rather use the randomConvolution function.
-data Convolution :: Nat -- ^ Number of channels, for the first layer this could be RGB for instance.
-                 -> Nat -- ^ Number of filters, this is the number of channels output by the layer.
-                 -> Nat -- ^ The number of rows in the kernel filter
-                 -> Nat -- ^ The number of column in the kernel filter
-                 -> Nat -- ^ The row stride of the convolution filter
-                 -> Nat -- ^ The columns stride of the convolution filter
+data Convolution :: Nat -- Number of channels, for the first layer this could be RGB for instance.
+                 -> Nat -- Number of filters, this is the number of channels output by the layer.
+                 -> Nat -- The number of rows in the kernel filter
+                 -> Nat -- The number of columns in the kernel filter
+                 -> Nat -- The row stride of the convolution filter
+                 -> Nat -- The column stride of the convolution filter
                  -> * where
   Convolution :: ( KnownNat channels
                  , KnownNat filters
···
                  , KnownNat strideColumns
                  , KnownNat kernelFlattened
                  , kernelFlattened ~ (kernelRows * kernelColumns * channels))
-              => !(L kernelFlattened filters) -- ^ The kernel filter weights
-              -> !(L kernelFlattened filters) -- ^ The last kernel update (or momentum)
+              => !(L kernelFlattened filters) -- The kernel filter weights
+              -> !(L kernelFlattened filters) -- The last kernel update (or momentum)
               -> Convolution channels filters kernelRows kernelColumns strideRows strideColumns
 
-data Convolution' :: Nat -- ^ Number of channels, for the first layer this could be RGB for instance.
-                  -> Nat -- ^ Number of filters, this is the number of channels output by the layer.
-                  -> Nat -- ^ The number of rows in the kernel filter
-                  -> Nat -- ^ The number of column in the kernel filter
-                  -> Nat -- ^ The row stride of the convolution filter
-                  -> Nat -- ^ The columns stride of the convolution filter
+data Convolution' :: Nat -- Number of channels, for the first layer this could be RGB for instance.
+                  -> Nat -- Number of filters, this is the number of channels output by the layer.
+                  -> Nat -- The number of rows in the kernel filter
+                  -> Nat -- The number of columns in the kernel filter
+                  -> Nat -- The row stride of the convolution filter
+                  -> Nat -- The column stride of the convolution filter
                   -> * where
   Convolution' :: ( KnownNat channels
                   , KnownNat filters
···
                   , KnownNat strideColumns
                   , KnownNat kernelFlattened
                   , kernelFlattened ~ (kernelRows * kernelColumns * channels))
-               => !(L kernelFlattened filters) -- ^ The kernel filter gradient
+               => !(L kernelFlattened filters) -- The kernel filter gradient
                -> Convolution' channels filters kernelRows kernelColumns strideRows strideColumns
 
 instance Show (Convolution c f k k' s s') where
···
                  , kernelFlattened ~ (kernelRows * kernelColumns * channels))
               => m (Convolution channels filters kernelRows kernelColumns strideRows strideColumns)
 randomConvolution = do
-    s :: Int <- getRandom
+    s <- getRandom
     let wN = uniformSample s (-1) 1
         mm = konst 0
     return $ Convolution wN mm
···
          ) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where
   type Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols) = (Convolution' channels filters kernelRows kernelCols strideRows strideCols)
   runUpdate LearningParameters {..} (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
-    let newMomentum = konst learningMomentum * oldMomentum - konst learningRate * kernelGradient
-        regulariser = konst (learningRegulariser * learningRate) * oldKernel
-        newKernel   = oldKernel + newMomentum - regulariser
+    let (newKernel, newMomentum) = decendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
     in Convolution newKernel newMomentum
 
   createRandom = randomConvolution
···
          , KnownNat (kernelRows * kernelCols * 1)
          , KnownNat (outputRows * filters)
          ) => Layer (Convolution 1 filters kernelRows kernelCols strideRows strideCols) ('D2 inputRows inputCols) ('D3 outputRows outputCols filters) where
-  runForwards (Convolution kernel _) (S2D' input) =
+  runForwards (Convolution kernel _) (S2D input) =
     let ex = extract input
         ek = extract kernel
         kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
···
         mt = c LA.<> ek
         r  = col2vid 1 1 1 1 ox oy mt
         rs = fromJust . create $ r
-    in S3D' rs
+    in S3D rs
 
-  runBackwards (Convolution kernel _) (S2D' input) (S3D' dEdy) =
+  runBackwards (Convolution kernel _) (S2D input) (S3D dEdy) =
     let ex = extract input
         ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
         iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
···
         dW = vs LA.<> tr ek
 
         xW = col2im kx ky sx sy ix iy dW
-    in (Convolution' kN, S2D' . fromJust . create $ xW)
+    in (Convolution' kN, S2D . fromJust . create $ xW)
 
 
 -- | A three dimensional image (or 2d with many channels) can have
···
          , KnownNat (kernelRows * kernelCols * channels)
          , KnownNat (outputRows * filters)
          ) => Layer (Convolution channels filters kernelRows kernelCols strideRows strideCols) ('D3 inputRows inputCols channels) ('D3 outputRows outputCols filters) where
-  runForwards (Convolution kernel _) (S3D' input) =
+  runForwards (Convolution kernel _) (S3D input) =
     let ex = extract input
         ek = extract kernel
         ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
···
         mt = c LA.<> ek
         r  = col2vid 1 1 1 1 ox oy mt
         rs = fromJust . create $ r
-    in S3D' rs
-  runBackwards (Convolution kernel _) (S3D' input) (S3D' dEdy) =
+    in S3D rs
+  runBackwards (Convolution kernel _) (S3D input) (S3D dEdy) =
     let ex = extract input
         ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
         iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
···
         dW = vs LA.<> tr ek
 
         xW = col2vid kx ky sx sy ix iy dW
-    in (Convolution' kN, S3D' . fromJust . create $ xW)
+    in (Convolution' kN, S3D . fromJust . create $ xW)
```
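The documented size rule, `out = (in - kernel) / stride + 1`, pins down the shape list of any network built with this layer. An illustrative type (the layers after the convolution are just one plausible continuation, not part of this change):

```haskell
{-# LANGUAGE DataKinds #-}

import Grenade

-- A 28x28 single-channel image through 10 filters of 5x5 at stride 1:
-- (28 - 5) / 1 + 1 = 24, so the convolution emits a 24x24x10 volume,
-- which flattens to 24 * 24 * 10 = 5760 values.
type TinyConvNet =
  Network '[ Convolution 1 10 5 5 1 1, FlattenLayer, FullyConnected 5760 10, Logit ]
          '[ 'D2 28 28, 'D3 24 24 10, 'D1 5760, 'D1 10, 'D1 10 ]
```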
+4 -8
src/Grenade/Layers/Crop.hs
```diff
···
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE PolyKinds #-}
-
 module Grenade.Layers.Crop (
     Crop (..)
   ) where
···
          , (inputRows - cropTop - cropBottom) ~ outputRows
          , (inputColumns - cropLeft - cropRight) ~ outputColumns
          ) => Layer (Crop cropLeft cropTop cropRight cropBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
-  runForwards Crop (S2D' input) =
+  runForwards Crop (S2D input) =
     let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
         cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
         nrows = fromIntegral $ natVal (Proxy :: Proxy outputRows)
         ncols = fromIntegral $ natVal (Proxy :: Proxy outputColumns)
         m     = extract input
         r     = subMatrix (cropt, cropl) (nrows, ncols) m
-    in  S2D' . fromJust . create $ r
-  runBackwards _ _ (S2D' dEdy) =
+    in  S2D . fromJust . create $ r
+  runBackwards _ _ (S2D dEdy) =
     let cropl = fromIntegral $ natVal (Proxy :: Proxy cropLeft)
         cropt = fromIntegral $ natVal (Proxy :: Proxy cropTop)
         cropr = fromIntegral $ natVal (Proxy :: Proxy cropRight)
         cropb = fromIntegral $ natVal (Proxy :: Proxy cropBottom)
         eo    = extract dEdy
         vs    = diagBlock [konst 0 (cropt,cropl), eo, konst 0 (cropb,cropr)]
-    in  ((), S2D' . fromJust . create $ vs)
+    in  ((), S2D . fromJust . create $ vs)
```
+4 -9
src/Grenade/Layers/Dropout.hs
```diff
···
 {-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-{-# LANGUAGE LambdaCase #-}
-
 module Grenade.Layers.Dropout (
     Dropout (..)
   , randomDropout
···
   return $ Dropout xs
 
 instance (KnownNat i) => Layer (Dropout i) ('D1 i) ('D1 i) where
-  runForwards (Dropout drops) (S1D' x) = S1D' $ x * drops
-  runForwards (Pass rate) (S1D' x)= S1D' $ dvmap (* (1 - rate)) x
-  runBackwards (Dropout drops) _ (S1D' x) = ((), S1D' $ x * drops)
-  runBackwards (Pass rate) _ (S1D' x) = ((), S1D' $ dvmap (* (1 - rate)) x)
+  runForwards (Dropout drops) (S1D x) = S1D $ x * drops
+  runForwards (Pass rate) (S1D x) = S1D $ dvmap (* (1 - rate)) x
+  runBackwards (Dropout drops) _ (S1D x) = ((), S1D $ x * drops)
+  runBackwards (Pass rate) _ (S1D x) = ((), S1D $ dvmap (* (1 - rate)) x)
```
+18 -11
src/Grenade/Layers/Flatten.hs
```diff
···
-{-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
-{-# LANGUAGE StandaloneDeriving #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
-{-# LANGUAGE FlexibleInstances #-}
-
 module Grenade.Layers.Flatten (
     FlattenLayer (..)
   ) where
···
 import GHC.TypeLits
 
 import Numeric.LinearAlgebra.Static
-import Numeric.LinearAlgebra.Data as LA (flatten, toList)
+import Numeric.LinearAlgebra.Data as LA ( flatten )
 
 import Grenade.Core.Shape
 import Grenade.Core.Network
 
+-- | Flatten Layer
+--
+--   Flattens input down to D1 from either 2D or 3D data.
+--
+--   Can also be used to turn a 3D image with only one channel into a 2D image.
 data FlattenLayer = FlattenLayer
   deriving Show
···
   runUpdate _ _ _ = FlattenLayer
   createRandom = return FlattenLayer
 
-
 instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * y)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
-  runForwards _ (S2D' y) = S1D' . fromList . toList . flatten . extract $ y
-  runBackwards _ _ (S1D' y) = ((), S2D' . fromList . toList . unwrap $ y)
+  runForwards _ (S2D y) = fromJust' . fromStorable . flatten . extract $ y
+  runBackwards _ _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)
 
 instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer FlattenLayer ('D3 x y z) ('D1 a) where
-  runForwards _ (S3D' y) = S1D' . fromList . toList . flatten . extract $ y
-  runBackwards _ _ (S1D' y) = ((), S3D' . fromList . toList . unwrap $ y)
+  runForwards _ (S3D y) = fromJust' . fromStorable . flatten . extract $ y
+  runBackwards _ _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)
+
+instance (KnownNat y, KnownNat x, KnownNat z, z ~ 1) => Layer FlattenLayer ('D3 x y z) ('D2 x y) where
+  runForwards _ (S3D y) = S2D y
+  runBackwards _ _ (S2D y) = ((), S3D y)
+
+fromJust' :: Maybe x -> x
+fromJust' (Just x) = x
+fromJust' Nothing  = error $ "FlattenLayer error: data shape couldn't be converted."
```
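The new third instance means a one-channel volume can be squeezed back to a matrix without touching the underlying data. An illustrative one-layer network type:

```haskell
{-# LANGUAGE DataKinds #-}

import Grenade

-- A single FlattenLayer dropping the trivial channel dimension.
type Squeeze = Network '[ FlattenLayer ] '[ 'D3 28 28 1, 'D2 28 28 ]
```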
+9 -13
src/Grenade/Layers/FullyConnected.hs
```diff
···
 {-# LANGUAGE DataKinds #-}
-{-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE RecordWildCards #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
-{-# LANGUAGE FlexibleInstances #-}
-
 module Grenade.Layers.FullyConnected (
     FullyConnected (..)
   , randomFullyConnected
···
 import Grenade.Core.Network
 import Grenade.Core.Shape
 
+import Grenade.Layers.Internal.Update
+
 -- | A basic fully connected (or inner product) neural network layer.
 data FullyConnected i o = FullyConnected
                         !(R o)   -- Bias neuron weights
···
   type Gradient (FullyConnected i o) = (FullyConnected' i o)
 
   runUpdate LearningParameters {..} (FullyConnected oldBias oldBiasMomentum oldActivations oldMomentum) (FullyConnected' biasGradient activationGradient) =
-    let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
-        newBias         = oldBias + newBiasMomentum
-        newMomentum     = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
-        regulariser     = konst (learningRegulariser * learningRate) * oldActivations
-        newActivations  = oldActivations + newMomentum - regulariser
+    let (newBias, newBiasMomentum)    = decendVector learningRate learningMomentum learningRegulariser oldBias biasGradient oldBiasMomentum
+        (newActivations, newMomentum) = decendMatrix learningRate learningMomentum learningRegulariser oldActivations activationGradient oldMomentum
     in FullyConnected newBias newBiasMomentum newActivations newMomentum
 
   createRandom = randomFullyConnected
 
 instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where
   -- Do a matrix vector multiplication and return the result.
-  runForwards (FullyConnected wB _ wN _) (S1D' v) = S1D' (wB + wN #> v)
+  runForwards (FullyConnected wB _ wN _) (S1D v) = S1D (wB + wN #> v)
 
   -- Run a backpropagation step for a fully connected layer.
-  runBackwards (FullyConnected _ _ wN _) (S1D' x) (S1D' dEdy) =
+  runBackwards (FullyConnected _ _ wN _) (S1D x) (S1D dEdy) =
     let wB'  = dEdy
         mm'  = dEdy `outer` x
         -- calculate derivatives for next step
         dWs  = tr wN #> dEdy
-    in  (FullyConnected' wB' mm', S1D' dWs)
+    in  (FullyConnected' wB' mm', S1D dWs)
 
 randomFullyConnected :: (MonadRandom m, KnownNat i, KnownNat o)
                      => m (FullyConnected i o)
 randomFullyConnected = do
-    s1 :: Int <- getRandom
-    s2 :: Int <- getRandom
+    s1 <- getRandom
+    s2 <- getRandom
     let wB = randomVector s1 Uniform * 2 - 1
         wN = uniformSample s2 (-1) 1
         bm = konst 0
```
+2 -7
src/Grenade/Layers/Fuse.hs
```diff
···
-{-# LANGUAGE BangPatterns #-}
 {-# LANGUAGE DataKinds #-}
 {-# LANGUAGE GADTs #-}
-{-# LANGUAGE KindSignatures #-}
 {-# LANGUAGE ScopedTypeVariables #-}
 {-# LANGUAGE TypeOperators #-}
 {-# LANGUAGE TypeFamilies #-}
-{-# LANGUAGE PolyKinds #-}
 {-# LANGUAGE MultiParamTypeClasses #-}
 {-# LANGUAGE FlexibleContexts #-}
 {-# LANGUAGE FlexibleInstances #-}
-
-
 module Grenade.Layers.Fuse (
     Fuse (..)
   ) where
···
 
 instance (Layer x i h, Layer y h o) => Layer (Fuse x y i h o) i o where
   runForwards (x :$$ y) input =
-    let yInput  :: S' h = runForwards x input
+    let yInput  :: S h  = runForwards x input
     in  runForwards y yInput
 
   runBackwards (x :$$ y) input backGradient =
-    let yInput  :: S' h    = runForwards x input
+    let yInput  :: S h     = runForwards x input
         (y', yGrad)        = runBackwards y yInput backGradient
         (x', xGrad)        = runBackwards x input yGrad
     in  ((x', y'), xGrad)
```
+13 -11
src/Grenade/Layers/Internal/Convolution.hs
```diff
···
   , vid2col
   ) where
 
-import Foreign ( mallocForeignPtrArray0, withForeignPtr )
+import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )
+
+import Foreign ( mallocForeignPtrArray, withForeignPtr )
 import Foreign.Ptr ( Ptr )
 
 import Numeric.LinearAlgebra ( Matrix, flatten, rows, cols )
···
 col2im_c channels height width kernelRows kernelColumns strideRows strideColumns dataCol =
   let vec = flatten dataCol
   in unsafePerformIO $ do
-    outPtr <- mallocForeignPtrArray0 (height * width * channels)
-    let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
+    outPtr <- mallocForeignPtrArray (height * width * channels)
+    let (inPtr, _) = U.unsafeToForeignPtr0 vec
 
     withForeignPtr inPtr $ \inPtr' ->
       withForeignPtr outPtr $ \outPtr' ->
-        col2im_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
+        col2im_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
 
-    let matVec = U.unsafeFromForeignPtr outPtr 0 (height * width * channels)
+    let matVec = U.unsafeFromForeignPtr0 outPtr (height * width * channels)
     return $ U.matrixFromVector U.RowMajor (height * channels) width matVec
 
 foreign import ccall unsafe
     col2im_cpu
-      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
+      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
 
 vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
 vid2col kernelRows kernelColumns strideRows strideColumns height width dataVid =
···
       kernelSize = kernelRows * kernelColumns
       numberOfPatches = rowOut * colOut
   in unsafePerformIO $ do
-    outPtr <- mallocForeignPtrArray0 (numberOfPatches * kernelSize * channels)
-    let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
+    outPtr <- mallocForeignPtrArray (numberOfPatches * kernelSize * channels)
+    let (inPtr, _) = U.unsafeToForeignPtr0 vec
 
     withForeignPtr inPtr $ \inPtr' ->
       withForeignPtr outPtr $ \outPtr' ->
-        im2col_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
+        im2col_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
 
-    let matVec = U.unsafeFromForeignPtr outPtr 0 (numberOfPatches * kernelSize * channels)
+    let matVec = U.unsafeFromForeignPtr0 outPtr (numberOfPatches * kernelSize * channels)
     return $ U.matrixFromVector U.RowMajor numberOfPatches (kernelSize * channels) matVec
 
 foreign import ccall unsafe
     im2col_cpu
-      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
+      :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
```
+14 -12
src/Grenade/Layers/Internal/Pooling.hs
··· 4 4 , poolBackward 5 5 ) where 6 6 7 - import Foreign ( mallocForeignPtrArray0, withForeignPtr ) 7 + import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 ) 8 + 9 + import Foreign ( mallocForeignPtrArray, withForeignPtr ) 8 10 import Foreign.Ptr ( Ptr ) 9 11 10 12 import Numeric.LinearAlgebra ( Matrix , flatten ) ··· 19 21 colOut = (width - kernelColumns) `div` strideColumns + 1 20 22 numberOfPatches = rowOut * colOut 21 23 in unsafePerformIO $ do 22 - outPtr <- mallocForeignPtrArray0 (numberOfPatches * channels) 23 - let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec 24 + outPtr <- mallocForeignPtrArray (numberOfPatches * channels) 25 + let (inPtr, _) = U.unsafeToForeignPtr0 vec 24 26 25 27 withForeignPtr inPtr $ \inPtr' -> 26 28 withForeignPtr outPtr $ \outPtr' -> 27 - pool_forwards_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr' 29 + pool_forwards_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr' 28 30 29 - let matVec = U.unsafeFromForeignPtr outPtr 0 (numberOfPatches * channels) 31 + let matVec = U.unsafeFromForeignPtr0 outPtr (numberOfPatches * channels) 30 32 return $ U.matrixFromVector U.RowMajor (rowOut * channels) colOut matVec 31 33 32 34 foreign import ccall unsafe 33 35 pool_forwards_cpu 34 - :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO () 36 + :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO () 35 37 36 38 poolBackward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double 37 39 poolBackward channels height width kernelRows kernelColumns strideRows strideColumns dataIm dataGrad = 38 40 let vecIm = flatten dataIm 39 41 vecGrad = flatten dataGrad 40 42 in unsafePerformIO $ do 41 - outPtr <- mallocForeignPtrArray0 (height * width * channels) 42 - let (imPtr, imOffset, _) = U.unsafeToForeignPtr vecIm 43 - let (gradPtr, gradOffset, _) = U.unsafeToForeignPtr vecGrad 43 + outPtr <- mallocForeignPtrArray (height * width * channels) 44 + let (imPtr, _) = U.unsafeToForeignPtr0 vecIm 45 + let (gradPtr, _) = U.unsafeToForeignPtr0 vecGrad 44 46 45 47 withForeignPtr imPtr $ \imPtr' -> 46 48 withForeignPtr gradPtr $ \gradPtr' -> 47 49 withForeignPtr outPtr $ \outPtr' -> 48 - pool_backwards_cpu imPtr' imOffset gradPtr' gradOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr' 50 + pool_backwards_cpu imPtr' gradPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr' 49 51 50 - let matVec = U.unsafeFromForeignPtr outPtr 0 (height * width * channels) 52 + let matVec = U.unsafeFromForeignPtr0 outPtr (height * width * channels) 51 53 return $ U.matrixFromVector U.RowMajor (height * channels) width matVec 52 54 53 55 foreign import ccall unsafe 54 56 pool_backwards_cpu 55 - :: Ptr Double -> Int -> Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO () 57 + :: Ptr Double -> Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
+70
src/Grenade/Layers/Internal/Update.hs
··· 1 + {-# LANGUAGE ForeignFunctionInterface #-}
2 + module Grenade.Layers.Internal.Update (
3 +     decendMatrix
4 +   , decendVector
5 +   ) where
6 +
7 + import           Data.Maybe ( fromJust )
8 + import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )
9 +
10 + import           Foreign ( mallocForeignPtrArray, withForeignPtr )
11 + import           Foreign.Ptr ( Ptr )
12 + import           GHC.TypeLits
13 +
14 + import           Numeric.LinearAlgebra ( Vector, flatten )
15 + import           Numeric.LinearAlgebra.Static
16 + import qualified Numeric.LinearAlgebra.Devel as U
17 +
18 + import           System.IO.Unsafe ( unsafePerformIO )
19 +
20 + decendMatrix :: (KnownNat rows, KnownNat columns) => Double -> Double -> Double -> L rows columns -> L rows columns -> L rows columns -> (L rows columns, L rows columns)
21 + decendMatrix rate momentum regulariser weights gradient lastUpdate =
22 +   let (rows, cols) = size weights
23 +       len          = rows * cols
24 +       -- Most gradients arrive in ColumnMajor,
25 +       -- so we'll transpose here before flattening them
26 +       -- into a vector to prevent a copy.
27 +       --
28 +       -- This gives ~15% speed improvement for LSTMs.
29 +       weights'     = flatten . tr . extract $ weights
30 +       gradient'    = flatten . tr . extract $ gradient
31 +       lastUpdate'  = flatten . tr . extract $ lastUpdate
32 +       (vw, vm)     = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
33 +
34 +       -- Note that it's ColumnMajor, as we did a transpose before
35 +       -- using the internal vectors.
36 +       mw           = U.matrixFromVector U.ColumnMajor rows cols vw
37 +       mm           = U.matrixFromVector U.ColumnMajor rows cols vm
38 +   in  (fromJust . create $ mw, fromJust . create $ mm)
39 +
40 + decendVector :: (KnownNat r) => Double -> Double -> Double -> R r -> R r -> R r -> (R r, R r)
41 + decendVector rate momentum regulariser weights gradient lastUpdate =
42 +   let len         = size weights
43 +       weights'    = extract weights
44 +       gradient'   = extract gradient
45 +       lastUpdate' = extract lastUpdate
46 +       (vw, vm)    = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
47 +   in  (fromJust $ create vw, fromJust $ create vm)
48 +
49 + decendUnsafe :: Int -> Double -> Double -> Double -> Vector Double -> Vector Double -> Vector Double -> (Vector Double, Vector Double)
50 + decendUnsafe len rate momentum regulariser weights gradient lastUpdate =
51 +   unsafePerformIO $ do
52 +     outWPtr <- mallocForeignPtrArray len
53 +     outMPtr <- mallocForeignPtrArray len
54 +     let (wPtr, _) = U.unsafeToForeignPtr0 weights
55 +     let (gPtr, _) = U.unsafeToForeignPtr0 gradient
56 +     let (lPtr, _) = U.unsafeToForeignPtr0 lastUpdate
57 +
58 +     withForeignPtr wPtr $ \wPtr' ->
59 +       withForeignPtr gPtr $ \gPtr' ->
60 +         withForeignPtr lPtr $ \lPtr' ->
61 +           withForeignPtr outWPtr $ \outWPtr' ->
62 +             withForeignPtr outMPtr $ \outMPtr' ->
63 +               decend_cpu len rate momentum regulariser wPtr' gPtr' lPtr' outWPtr' outMPtr'
64 +
65 +     return (U.unsafeFromForeignPtr0 outWPtr len, U.unsafeFromForeignPtr0 outMPtr len)
66 +
67 + foreign import ccall unsafe
68 +     decend_cpu
69 +       :: Int -> Double -> Double -> Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> IO ()
70 +
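For reference, a pure sketch of the elementwise rule the `decend_cpu` call is expected to apply, mirroring the momentum update `runUpdate` performs for the other layers in this change (the helper name here is hypothetical):

```haskell
-- Pure reference for one element: momentum descent with L2 regularisation.
decendStep :: Double -> Double -> Double -> Double -> Double -> Double -> (Double, Double)
decendStep rate momentum regulariser w g lastM =
  let m' = momentum * lastM - rate * g        -- new momentum
      w' = w + m' - (rate * regulariser) * w  -- new weight
  in (w', m')
```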
+6 -10
src/Grenade/Layers/Logit.hs
··· 1 1 {-# LANGUAGE DataKinds #-} 2 - {-# LANGUAGE ScopedTypeVariables #-} 3 2 {-# LANGUAGE TypeOperators #-} 4 3 {-# LANGUAGE TypeFamilies #-} 5 4 {-# LANGUAGE MultiParamTypeClasses #-} 6 - {-# LANGUAGE FlexibleInstances #-} 7 - 8 5 module Grenade.Layers.Logit ( 9 6 Logit (..) 10 7 ) where ··· 27 24 createRandom = return Logit 28 25 29 26 instance (KnownNat i) => Layer Logit ('D1 i) ('D1 i) where 30 - runForwards _ (S1D' y) = S1D' (logistic y) 31 - runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (logistic' y * dEdy)) 27 + runForwards _ (S1D y) = S1D (logistic y) 28 + runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (logistic' y * dEdy)) 32 29 33 30 instance (KnownNat i, KnownNat j) => Layer Logit ('D2 i j) ('D2 i j) where 34 - runForwards _ (S2D' y) = S2D' (logistic y) 35 - runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (logistic' y * dEdy)) 31 + runForwards _ (S2D y) = S2D (logistic y) 32 + runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (logistic' y * dEdy)) 36 33 37 34 instance (KnownNat i, KnownNat j, KnownNat k) => Layer Logit ('D3 i j k) ('D3 i j k) where 38 - runForwards _ (S3D' y) = S3D' (logistic y) 39 - runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (logistic' y * dEdy)) 40 - 35 + runForwards _ (S3D y) = S3D (logistic y) 36 + runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (logistic' y * dEdy)) 41 37 42 38 logistic :: Floating a => a -> a 43 39 logistic x = 1 / (1 + exp (-x))
+4 -8
src/Grenade/Layers/Pad.hs
··· 4 4 {-# LANGUAGE TypeOperators #-} 5 5 {-# LANGUAGE TypeFamilies #-} 6 6 {-# LANGUAGE MultiParamTypeClasses #-} 7 - {-# LANGUAGE FlexibleInstances #-} 8 - {-# LANGUAGE FlexibleContexts #-} 9 - {-# LANGUAGE PolyKinds #-} 10 - 11 7 module Grenade.Layers.Pad ( 12 8 Pad (..) 13 9 ) where ··· 50 46 , (inputRows + padTop + padBottom) ~ outputRows 51 47 , (inputColumns + padLeft + padRight) ~ outputColumns 52 48 ) => Layer (Pad padLeft padTop padRight padBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where 53 - runForwards Pad (S2D' input) = 49 + runForwards Pad (S2D input) = 54 50 let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft) 55 51 padt = fromIntegral $ natVal (Proxy :: Proxy padTop) 56 52 padr = fromIntegral $ natVal (Proxy :: Proxy padRight) 57 53 padb = fromIntegral $ natVal (Proxy :: Proxy padBottom) 58 54 m = extract input 59 55 r = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)] 60 - in S2D' . fromJust . create $ r 61 - runBackwards Pad _ (S2D' dEdy) = 56 + in S2D . fromJust . create $ r 57 + runBackwards Pad _ (S2D dEdy) = 62 58 let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft) 63 59 padt = fromIntegral $ natVal (Proxy :: Proxy padTop) 64 60 nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows) 65 61 ncols = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 66 62 m = extract dEdy 67 63 vs = subMatrix (padt, padl) (nrows, ncols) m 68 - in ((), S2D' . fromJust . create $ vs) 64 + in ((), S2D . fromJust . create $ vs)
+8 -12
src/Grenade/Layers/Pooling.hs
··· 1 - {-# LANGUAGE BangPatterns #-} 2 1 {-# LANGUAGE DataKinds #-} 3 2 {-# LANGUAGE ScopedTypeVariables #-} 4 3 {-# LANGUAGE StandaloneDeriving #-} ··· 6 5 {-# LANGUAGE TypeOperators #-} 7 6 {-# LANGUAGE TypeFamilies #-} 8 7 {-# LANGUAGE MultiParamTypeClasses #-} 9 - {-# LANGUAGE FlexibleInstances #-} 10 8 {-# LANGUAGE FlexibleContexts #-} 11 - {-# LANGUAGE PolyKinds #-} 12 - 13 9 module Grenade.Layers.Pooling ( 14 10 Pooling (..) 15 11 ) where ··· 55 51 , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows) 56 52 , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns) 57 53 ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where 58 - runForwards Pooling (S2D' input) = 54 + runForwards Pooling (S2D input) = 59 55 let height = fromIntegral $ natVal (Proxy :: Proxy inputRows) 60 56 width = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 61 57 kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) ··· 65 61 ex = extract input 66 62 r = poolForward 1 height width kx ky sx sy ex 67 63 rs = fromJust . create $ r 68 - in S2D' $ rs 69 - runBackwards Pooling (S2D' input) (S2D' dEdy) = 64 + in S2D $ rs 65 + runBackwards Pooling (S2D input) (S2D dEdy) = 70 66 let height = fromIntegral $ natVal (Proxy :: Proxy inputRows) 71 67 width = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 72 68 kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) ··· 76 72 ex = extract input 77 73 eo = extract dEdy 78 74 vs = poolBackward 1 height width kx ky sx sy ex eo 79 - in ((), S2D' . fromJust . create $ vs) 75 + in ((), S2D . fromJust . create $ vs) 80 76 81 77 82 78 -- | A three dimensional image can be pooled on each layer. ··· 93 89 , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows) 94 90 , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns) 95 91 ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where 96 - runForwards Pooling (S3D' input) = 92 + runForwards Pooling (S3D input) = 97 93 let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows) 98 94 iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 99 95 kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) ··· 104 100 ex = extract input 105 101 r = poolForward ch ix iy kx ky sx sy ex 106 102 rs = fromJust . create $ r 107 - in S3D' rs 108 - runBackwards Pooling (S3D' input) (S3D' dEdy) = 103 + in S3D rs 104 + runBackwards Pooling (S3D input) (S3D dEdy) = 109 105 let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows) 110 106 iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns) 111 107 kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows) ··· 116 112 ex = extract input 117 113 eo = extract dEdy 118 114 vs = poolBackward ch ix iy kx ky sx sy ex eo 119 - in ((), S3D' . fromJust . create $ vs) 115 + in ((), S3D . fromJust . create $ vs)
+6 -9
src/Grenade/Layers/Relu.hs
··· 1 1 {-# LANGUAGE DataKinds #-} 2 - {-# LANGUAGE ScopedTypeVariables #-} 3 2 {-# LANGUAGE TypeOperators #-} 4 3 {-# LANGUAGE TypeFamilies #-} 5 4 {-# LANGUAGE MultiParamTypeClasses #-} 6 - {-# LANGUAGE FlexibleInstances #-} 7 - 8 5 module Grenade.Layers.Relu ( 9 6 Relu (..) 10 7 ) where ··· 27 24 createRandom = return Relu 28 25 29 26 instance ( KnownNat i) => Layer Relu ('D1 i) ('D1 i) where 30 - runForwards _ (S1D' y) = S1D' (relu y) 27 + runForwards _ (S1D y) = S1D (relu y) 31 28 where 32 29 relu = LAS.dvmap (\a -> if a <= 0 then 0 else a) 33 - runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (relu' y * dEdy)) 30 + runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (relu' y * dEdy)) 34 31 where 35 32 relu' = LAS.dvmap (\a -> if a <= 0 then 0 else 1) 36 33 37 34 instance (KnownNat i, KnownNat j) => Layer Relu ('D2 i j) ('D2 i j) where 38 - runForwards _ (S2D' y) = S2D' (relu y) 35 + runForwards _ (S2D y) = S2D (relu y) 39 36 where 40 37 relu = LAS.dmmap (\a -> if a <= 0 then 0 else a) 41 - runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (relu' y * dEdy)) 38 + runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (relu' y * dEdy)) 42 39 where 43 40 relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1) 44 41 45 42 instance (KnownNat i, KnownNat j, KnownNat k) => Layer Relu ('D3 i j k) ('D3 i j k) where 46 - runForwards _ (S3D' y) = S3D' (relu y) 43 + runForwards _ (S3D y) = S3D (relu y) 47 44 where 48 45 relu = LAS.dmmap (\a -> if a <= 0 then 0 else a) 49 - runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (relu' y * dEdy)) 46 + runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (relu' y * dEdy)) 50 47 where 51 48 relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
+6 -9
src/Grenade/Layers/Tanh.hs
··· 1 1 {-# LANGUAGE DataKinds #-} 2 - {-# LANGUAGE ScopedTypeVariables #-} 3 2 {-# LANGUAGE TypeOperators #-} 4 3 {-# LANGUAGE TypeFamilies #-} 5 4 {-# LANGUAGE MultiParamTypeClasses #-} 6 - {-# LANGUAGE FlexibleInstances #-} 7 - 8 5 module Grenade.Layers.Tanh ( 9 6 Tanh (..) 10 7 ) where ··· 24 21 createRandom = return Tanh 25 22 26 23 instance KnownNat i => Layer Tanh ('D1 i) ('D1 i) where 27 - runForwards _ (S1D' y) = S1D' (tanh y) 28 - runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (tanh' y * dEdy)) 24 + runForwards _ (S1D y) = S1D (tanh y) 25 + runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (tanh' y * dEdy)) 29 26 30 27 instance (KnownNat i, KnownNat j) => Layer Tanh ('D2 i j) ('D2 i j) where 31 - runForwards _ (S2D' y) = S2D' (tanh y) 32 - runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (tanh' y * dEdy)) 28 + runForwards _ (S2D y) = S2D (tanh y) 29 + runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (tanh' y * dEdy)) 33 30 34 31 instance (KnownNat i, KnownNat j, KnownNat k) => Layer Tanh ('D3 i j k) ('D3 i j k) where 35 - runForwards _ (S3D' y) = S3D' (tanh y) 36 - runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (tanh' y * dEdy)) 32 + runForwards _ (S3D y) = S3D (tanh y) 33 + runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (tanh' y * dEdy)) 37 34 38 35 tanh' :: (Floating a) => a -> a 39 36 tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
+9
src/Grenade/Recurrent.hs
··· 1 + module Grenade.Recurrent ( 2 + module X 3 + ) where 4 + 5 + import Grenade.Recurrent.Core.Network as X 6 + import Grenade.Recurrent.Core.Runner as X 7 + import Grenade.Recurrent.Layers.BasicRecurrent as X 8 + import Grenade.Recurrent.Layers.LSTM as X 9 + import Grenade.Recurrent.Layers.Trivial as X
+98
src/Grenade/Recurrent/Core/Network.hs
··· 1 + {-# LANGUAGE DataKinds #-}
2 + {-# LANGUAGE GADTs #-}
3 + {-# LANGUAGE TypeOperators #-}
4 + {-# LANGUAGE TypeFamilies #-}
5 + {-# LANGUAGE MultiParamTypeClasses #-}
6 + {-# LANGUAGE FlexibleContexts #-}
7 + {-# LANGUAGE FlexibleInstances #-}
8 + {-# LANGUAGE EmptyDataDecls #-}
9 + module Grenade.Recurrent.Core.Network (
10 +     Recurrent
11 +   , FeedForward
12 +   , RecurrentLayer (..)
13 +   , RecurrentUpdateLayer (..)
14 +   , RecurrentNetwork (..)
15 +   , RecurrentInputs (..)
16 +   , CreatableRecurrent (..)
17 +   ) where
18 +
19 +
20 + import           Control.Monad.Random ( MonadRandom )
21 + import           Data.Singletons ( SingI )
22 +
23 + import           Grenade.Core.Shape
24 + import           Grenade.Core.Network
25 +
26 +
27 + -- | Witness type to indicate we're building up with a normal feed
28 + --   forward layer.
29 + data FeedForward :: * -> *
30 + -- | Witness type to indicate we're building up with a recurrent layer.
31 + data Recurrent :: * -> *
32 +
33 + -- | Class for a recurrent layer.
34 + --   It's quite similar to a normal layer, but additionally takes and
35 + --   returns an extra recurrent data shape.
36 + class UpdateLayer x => RecurrentUpdateLayer x where
37 +   -- | Shape of data that is passed between each subsequent run of the layer
38 +   type RecurrentShape x :: Shape
39 +
40 + class (RecurrentUpdateLayer x, SingI (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where
41 +   -- | Used in training and scoring. Take the input from the previous
42 +   --   layer, and give the output from this layer.
43 +   runRecurrentForwards :: x -> S (RecurrentShape x) -> S i -> (S (RecurrentShape x), S o)
44 +   -- | Back propagate a step. Takes the current layer, the recurrent shape
45 +   --   and input the layer was run with, and the back propagated derivatives
46 +   --   from the layer above.
47 +   --   Returns the gradient layer and the derivatives to push back further.
48 +   runRecurrentBackwards :: x -> S (RecurrentShape x) -> S i -> S (RecurrentShape x) -> S o -> (Gradient x, S (RecurrentShape x), S i)
49 +
50 + data RecurrentNetwork :: [*] -> [Shape] -> * where
51 +   OR     :: (SingI i, SingI o, Layer x i o) => !x -> RecurrentNetwork '[FeedForward x] '[i, o]
52 +   (:~~>) :: (SingI i, Layer x i h) => !x -> !(RecurrentNetwork xs (h ': hs)) -> RecurrentNetwork (FeedForward x ': xs) (i ': h ': hs)
53 +   (:~@>) :: (SingI i, RecurrentLayer x i h) => !x -> !(RecurrentNetwork xs (h ': hs)) -> RecurrentNetwork (Recurrent x ': xs) (i ': h ': hs)
54 + infixr 5 :~~>
55 + infixr 5 :~@>
56 +
57 + instance Show (RecurrentNetwork l h) where
58 +   show (OR a) = "OR " ++ show a
59 +   show (i :~~> o) = show i ++ "\n:~~>\n" ++ show o
60 +   show (i :~@> o) = show i ++ "\n:~@>\n" ++ show o
61 +
62 +
63 + -- | Recurrent inputs (sideways shapes on an imaginary unrolled graph).
64 + --   Parameterised on the layers of a Network.
65 + data RecurrentInputs :: [*] -> * where
66 +   ORS     :: UpdateLayer x
67 +           => () -> RecurrentInputs '[FeedForward x]
68 +   (:~~+>) :: UpdateLayer x
69 +           => () -> !(RecurrentInputs xs) -> RecurrentInputs (FeedForward x ': xs)
70 +   (:~@+>) :: (SingI (RecurrentShape x), RecurrentUpdateLayer x)
71 +           => !(S (RecurrentShape x)) -> !(RecurrentInputs xs) -> RecurrentInputs (Recurrent x ': xs)
72 + infixr 5 :~~+>
73 + infixr 5 :~@+>
74 +
75 + -- | A network can easily be created by hand with (:~~>) and (:~@>), but an easy way to initialise a random
76 + --   recurrent network and a set of random inputs for it is with randomRecurrent.
77 + class CreatableRecurrent (xs :: [*]) (ss :: [Shape]) where 78 + -- | Create a network of the types requested 79 + randomRecurrent :: MonadRandom m => m (RecurrentNetwork xs ss, RecurrentInputs xs) 80 + 81 + instance (SingI i, SingI o, Layer x i o) => CreatableRecurrent (FeedForward x ': '[]) (i ': o ': '[]) where 82 + randomRecurrent = do 83 + thisLayer <- createRandom 84 + return (OR thisLayer, ORS ()) 85 + 86 + instance (SingI i, Layer x i o, CreatableRecurrent xs (o ': r ': rs)) => CreatableRecurrent (FeedForward x ': xs) (i ': o ': r ': rs) where 87 + randomRecurrent = do 88 + thisLayer <- createRandom 89 + (rest, resti) <- randomRecurrent 90 + return (thisLayer :~~> rest, () :~~+> resti) 91 + 92 + instance (SingI i, RecurrentLayer x i o, CreatableRecurrent xs (o ': r ': rs)) => CreatableRecurrent (Recurrent x ': xs) (i ': o ': r ': rs) where 93 + randomRecurrent = do 94 + thisLayer <- createRandom 95 + thisShape <- randomOfShape 96 + (rest, resti) <- randomRecurrent 97 + return (thisLayer :~@> rest, thisShape :~@+> resti) 98 +
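To make the new types concrete, a hedged sketch of a network built from these pieces, assuming the `LSTM` and `Logit` layers in this change (the sizes are illustrative):

```haskell
-- An LSTM followed by an ordinary feed forward Logit output layer.
-- The shape list gives the input, hidden, and output shapes in order.
type MyRecNet = RecurrentNetwork '[ Recurrent (LSTM 40 80), FeedForward Logit ]
                                 '[ 'D1 40, 'D1 80, 'D1 80 ]

-- randomRecurrent initialises both the network and its recurrent inputs.
initialise :: MonadRandom m => m (MyRecNet, RecurrentInputs '[ Recurrent (LSTM 40 80), FeedForward Logit ])
initialise = randomRecurrent
```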
+144
src/Grenade/Recurrent/Core/Runner.hs
··· 1 + {-# LANGUAGE BangPatterns #-}
2 + {-# LANGUAGE GADTs #-}
3 + {-# LANGUAGE DataKinds #-}
4 + {-# LANGUAGE ScopedTypeVariables #-}
5 + {-# LANGUAGE TypeOperators #-}
6 + {-# LANGUAGE TypeFamilies #-}
7 + {-# LANGUAGE FlexibleContexts #-}
8 + {-# LANGUAGE RankNTypes #-}
9 + {-# LANGUAGE RecordWildCards #-}
10 + module Grenade.Recurrent.Core.Runner (
11 +     trainRecurrent
12 +   , runRecurrent
13 +   ) where
14 +
15 + import           Data.Singletons.Prelude
16 + import           Grenade.Core.Network
17 + import           Grenade.Core.Shape
18 +
19 + import           Grenade.Recurrent.Core.Network
20 +
21 + -- | Drive a network and collect its back propagated gradients.
22 + trainRecurrent :: forall shapes layers. SingI (Last shapes)
23 +                => LearningParameters
24 +                -> RecurrentNetwork layers shapes
25 +                -> RecurrentInputs layers
26 +                -> [(S (Head shapes), Maybe (S (Last shapes)))]
27 +                -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
28 + trainRecurrent rate network recinputs examples =
29 +   updateBack $ go inputs network recinputs
30 +   where
31 +     inputs  = fst <$> examples
32 +     targets = snd <$> examples
33 +     updateBack (a,recgrad,_) = (a,updateRecInputs rate recinputs recgrad)
34 +
35 +     go :: forall js sublayers. (Last js ~ Last shapes)
36 +        => [S (Head js)]                 -- ^ input vectors
37 +        -> RecurrentNetwork sublayers js -- ^ network to train
38 +        -> RecurrentInputs sublayers
39 +        -> (RecurrentNetwork sublayers js, RecurrentInputs sublayers, [S (Head js)])
40 +
41 +     -- This is a simple non-recurrent layer; just map it forwards.
42 +     -- Note we're doing training here; we could just return a list of gradients
43 +     -- (and probably will in future).
44 +     go !xs (layer :~~> n) (() :~~+> nIn)
45 +       = let ys = runForwards layer <$> xs
46 +             -- recursively run the rest of the network, and get the gradients from above.
47 +             (newFN, ig, grads) = go ys n nIn
48 +             -- calculate the gradient for this layer to pass down,
49 +             back = uncurry (runBackwards layer) <$> zip (reverse xs) grads
50 +             -- the new trained layer.
51 +             newlayer = runUpdates rate layer (fst <$> back)
52 +
53 +         in (newlayer :~~> newFN, () :~~+> ig, snd <$> back)
54 +
55 +     -- This is a recurrent layer, so we need to do a scan, first input to last, providing
56 +     -- the recurrent shape output to the next layer.
57 +     go !xs (layer :~@> n) (g :~@+> nIn)
58 +       = let ys = scanlFrom layer g xs
59 +
60 +             (newFN, ig, grads) = go (snd <$> ys) n nIn
61 +
62 +             backExamples = zip3 (fst <$> reverse ys) (reverse xs) grads
63 +
64 +             (rg, back) = myscanbackward layer backExamples
65 +             -- the new trained layer.
66 +             newlayer = runUpdates rate layer (fst <$> back)
67 +         in (newlayer :~@> newFN, rg :~@+> ig, snd <$> back)
68 +
69 +     -- Handle the output layer, bouncing the derivatives back down.
70 +     -- We may not have a target for each example, so when we don't, we use a 0 gradient.
71 +     go !xs (OR layer) (ORS ())
72 +       = let ys = runForwards layer <$> xs
73 +             -- calculate this layer's gradients against the targets (zero where a target is absent).
74 +             back = uncurry (runBackwards layer) <$> zip xs (zipWith makeError ys targets)
75 +             -- the new trained layer.
76 +             newlayer = runUpdates rate layer (reverse $ fst <$> back)
77 +         in (OR newlayer, ORS (), reverse (snd <$> back))
78 +
79 +     go _ _ _ =
80 +       error "Impossible for network and recurrent inputs to have different shapes"
81 +
82 +
83 +     makeError :: S (Last shapes) -> Maybe (S (Last shapes)) -> S (Last shapes)
84 +     makeError _ Nothing  = 0
85 +     makeError y (Just t) = y - t
86 +
87 + updateRecInputs :: forall sublayers. 
88 + LearningParameters 89 + -> RecurrentInputs sublayers 90 + -> RecurrentInputs sublayers 91 + -> RecurrentInputs sublayers 92 + 93 + updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys) 94 + = () :~~+> updateRecInputs l xs ys 95 + 96 + updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys) 97 + = (x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys 98 + 99 + updateRecInputs _ (ORS ()) (ORS ()) 100 + = ORS () 101 + updateRecInputs _ _ _ 102 + = error "Impossible for updateRecInputs to have different shapes" 103 + 104 + scanlFrom :: forall x i o. RecurrentLayer x i o 105 + => x -- ^ the layer 106 + -> S (RecurrentShape x) -- ^ place to start 107 + -> [S i] -- ^ list of inputs to scan through 108 + -> [(S (RecurrentShape x), S o)] -- ^ list of scan inputs and outputs 109 + scanlFrom !layer !recShape (x:xs) = 110 + let (lerec, lepush) = runRecurrentForwards layer recShape x 111 + in (recShape, lepush) : scanlFrom layer lerec xs 112 + scanlFrom _ _ [] = [] 113 + 114 + myscanbackward :: forall x i o. RecurrentLayer x i o 115 + => x -- ^ the layer 116 + -> [(S (RecurrentShape x), S i, S o)] -- ^ the list of inputs and output to scan over 117 + -> (S (RecurrentShape x), [(Gradient x, S i)]) -- ^ list of gradients to fold and inputs to backprop 118 + myscanbackward layer = 119 + goX 0 120 + where 121 + goX :: S (RecurrentShape x) -> [(S (RecurrentShape x), S i, S o)] -> (S (RecurrentShape x), [(Gradient x, S i)]) 122 + goX !lastback ((recShape, lastin, backgrad):xs) = 123 + let (layergrad, recgrad, ingrad) = runRecurrentBackwards layer recShape lastin lastback backgrad 124 + (pushedback, ll) = goX recgrad xs 125 + in (pushedback, (layergrad, ingrad) : ll) 126 + goX !lastback [] = (lastback, []) 127 + 128 + -- | Just forwards propagation with no training. 129 + runRecurrent :: RecurrentNetwork layers shapes 130 + -> RecurrentInputs layers -> S (Head shapes) 131 + -> (RecurrentInputs layers, S (Last shapes)) 132 + runRecurrent (layer :~~> n) (() :~~+> nr) !x 133 + = let ys = runForwards layer x 134 + (nr', o) = runRecurrent n nr ys 135 + in (() :~~+> nr', o) 136 + runRecurrent (layer :~@> n) (recin :~@+> nr) !x 137 + = let (recin', y) = runRecurrentForwards layer recin x 138 + (nr', o) = runRecurrent n nr y 139 + in (recin' :~@+> nr', o) 140 + runRecurrent (OR layer) (ORS ()) !x 141 + = (ORS (), runForwards layer x) 142 + 143 + runRecurrent _ _ _ 144 + = error "Impossible for the gradients of a network to have a different length or shape to the network"
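A hedged usage sketch, reusing the hypothetical `MyRecNet` type from the network sketch above and the `LearningParameters` fields used in this module:

```haskell
-- One training pass over a short sequence. Steps without a target
-- (Nothing) contribute a zero gradient at the output layer.
trainOnce :: MyRecNet
          -> RecurrentInputs '[ Recurrent (LSTM 40 80), FeedForward Logit ]
          -> [(S ('D1 40), Maybe (S ('D1 80)))]
          -> (MyRecNet, RecurrentInputs '[ Recurrent (LSTM 40 80), FeedForward Logit ])
trainOnce = trainRecurrent LearningParameters { learningRate        = 0.01
                                              , learningMomentum    = 0.9
                                              , learningRegulariser = 1e-4 }
```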
+92
src/Grenade/Recurrent/Layers/BasicRecurrent.hs
··· 1 + {-# LANGUAGE DataKinds #-}
2 + {-# LANGUAGE GADTs #-}
3 + {-# LANGUAGE RecordWildCards #-}
4 + {-# LANGUAGE TypeOperators #-}
5 + {-# LANGUAGE TypeFamilies #-}
6 + {-# LANGUAGE MultiParamTypeClasses #-}
7 + {-# LANGUAGE FlexibleContexts #-}
8 + {-# LANGUAGE UndecidableInstances #-}
9 + module Grenade.Recurrent.Layers.BasicRecurrent (
10 +     BasicRecurrent (..)
11 +   , randomBasicRecurrent
12 +   ) where
13 +
14 + import           Control.Monad.Random ( MonadRandom, getRandom )
15 +
16 + import           Data.Singletons.TypeLits
17 +
18 + import           Numeric.LinearAlgebra.Static
19 +
20 + import           GHC.TypeLits
21 +
22 + import           Grenade.Core.Network
23 + import           Grenade.Core.Shape
24 + import           Grenade.Recurrent.Core.Network
25 +
26 + data BasicRecurrent :: Nat -- Input layer size
27 +                     -> Nat -- Output layer size
28 +                     -> * where
29 +   BasicRecurrent :: ( KnownNat input
30 +                     , KnownNat output
31 +                     , KnownNat matrixCols
32 +                     , matrixCols ~ (input + output))
33 +                  => !(R output)             -- Bias neuron weights
34 +                  -> !(R output)             -- Bias neuron momentum
35 +                  -> !(L output matrixCols)  -- Activation
36 +                  -> !(L output matrixCols)  -- Momentum
37 +                  -> BasicRecurrent input output
38 +
39 + data BasicRecurrent' :: Nat -- Input layer size
40 +                      -> Nat -- Output layer size
41 +                      -> * where
42 +   BasicRecurrent' :: ( KnownNat input
43 +                      , KnownNat output
44 +                      , KnownNat matrixCols
45 +                      , matrixCols ~ (input + output))
46 +                   => !(R output)            -- Bias neuron gradients
47 +                   -> !(L output matrixCols)
48 +                   -> BasicRecurrent' input output
49 +
50 + instance Show (BasicRecurrent i o) where
51 +   show BasicRecurrent {} = "BasicRecurrent"
52 +
53 + instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where
54 +   type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o)
55 +
56 +   runUpdate LearningParameters {..} (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
57 +     let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
58 +         newBias         = oldBias + newBiasMomentum
59 +         newMomentum     = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
60 +         regulariser     = konst (learningRegulariser * learningRate) * oldActivations
61 +         newActivations  = oldActivations + newMomentum - regulariser
62 +     in BasicRecurrent newBias newBiasMomentum newActivations newMomentum
63 +
64 +   createRandom = randomBasicRecurrent
65 +
66 + instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentUpdateLayer (BasicRecurrent i o) where
67 +   type RecurrentShape (BasicRecurrent i o) = 'D1 o
68 +
69 + instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentLayer (BasicRecurrent i o) ('D1 i) ('D1 o) where
70 +   -- Do a matrix vector multiplication and return the result.
71 +   runRecurrentForwards (BasicRecurrent wB _ wN _) (S1D lastOutput) (S1D thisInput) =
72 +     let thisOutput = S1D $ wB + wN #> (thisInput # lastOutput)
73 +     in (thisOutput, thisOutput)
74 +
75 +   -- Run a backpropagation step for a basic recurrent layer.
76 +   runRecurrentBackwards (BasicRecurrent _ _ wN _) (S1D lastOutput) (S1D thisInput) (S1D dRec) (S1D dEdy) =
77 +     let biasGradient = (dRec + dEdy)
78 +         layerGrad    = (dRec + dEdy) `outer` (thisInput # lastOutput)
79 +         -- calculate derivatives for the next step
80 +         (backGrad, recGrad) = split $ tr wN #> (dRec + dEdy)
81 +     in (BasicRecurrent' biasGradient layerGrad, S1D recGrad, S1D backGrad)
82 +
83 + randomBasicRecurrent :: (MonadRandom m, KnownNat i, KnownNat o, KnownNat x, x ~ (i + o))
84 +                      => m (BasicRecurrent i o)
85 + randomBasicRecurrent = do
86 +   seed1 <- getRandom
87 +   seed2 <- getRandom
88 +   let wB = randomVector seed1 Uniform * 2 - 1
89 +       wN = uniformSample seed2 (-1) 1
90 +       bm = konst 0
91 +       mm = konst 0
92 +   return $ BasicRecurrent wB bm wN mm
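The forward pass concatenates the input with the previous output and applies a single affine map, so one step looks like the following hedged sketch (sizes are illustrative):

```haskell
-- One forward step of a BasicRecurrent 3 4 layer; the output doubles
-- as the recurrent state for the next step.
step :: MonadRandom m => m (S ('D1 4), S ('D1 4))
step = do
  layer <- randomBasicRecurrent
  let x  = S1D (konst 0.5) :: S ('D1 3)  -- current input
      h0 = S1D (konst 0)   :: S ('D1 4)  -- previous output, starting at zero
  return (runRecurrentForwards (layer :: BasicRecurrent 3 4) h0 x)
```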
+244
src/Grenade/Recurrent/Layers/LSTM.hs
··· 1 + {-# LANGUAGE BangPatterns #-} 2 + {-# LANGUAGE DataKinds #-} 3 + {-# LANGUAGE GADTs #-} 4 + {-# LANGUAGE RankNTypes #-} 5 + {-# LANGUAGE RecordWildCards #-} 6 + {-# LANGUAGE TypeOperators #-} 7 + {-# LANGUAGE TypeFamilies #-} 8 + {-# LANGUAGE MultiParamTypeClasses #-} 9 + {-# LANGUAGE FlexibleContexts #-} 10 + {-# LANGUAGE ViewPatterns #-} 11 + module Grenade.Recurrent.Layers.LSTM ( 12 + LSTM (..) 13 + , LSTMWeights (..) 14 + , randomLSTM 15 + ) where 16 + 17 + import Control.Monad.Random ( MonadRandom, getRandom ) 18 + 19 + -- import Data.List ( foldl1' ) 20 + import Data.Singletons.TypeLits 21 + 22 + import Numeric.LinearAlgebra.Static 23 + 24 + import Grenade.Core.Network 25 + import Grenade.Core.Shape 26 + 27 + import Grenade.Layers.Internal.Update 28 + 29 + import Grenade.Recurrent.Core.Network 30 + 31 + -- | Long Short Term Memory Recurrent unit 32 + -- 33 + -- This is a Peephole formulation, so the recurrent shape is 34 + -- just the cell state, the previous output is not held or used 35 + -- at all. 36 + data LSTM :: Nat -> Nat -> * where 37 + LSTM :: ( KnownNat input 38 + , KnownNat output 39 + ) => !(LSTMWeights input output) -- Weights 40 + -> !(LSTMWeights input output) -- Momentums 41 + -> LSTM input output 42 + 43 + data LSTMWeights :: Nat -> Nat -> * where 44 + LSTMWeights :: ( KnownNat input 45 + , KnownNat output 46 + ) => { 47 + lstmWf :: !(L output input) -- Weight Forget (W_f) 48 + , lstmUf :: !(L output output) -- Cell State Forget (U_f) 49 + , lstmBf :: !(R output) -- Bias Forget (b_f) 50 + , lstmWi :: !(L output input) -- Weight Input (W_i) 51 + , lstmUi :: !(L output output) -- Cell State Input (U_i) 52 + , lstmBi :: !(R output) -- Bias Input (b_i) 53 + , lstmWo :: !(L output input) -- Weight Output (W_o) 54 + , lstmUo :: !(L output output) -- Cell State Output (U_o) 55 + , lstmBo :: !(R output) -- Bias Output (b_o) 56 + , lstmWc :: !(L output input) -- Weight Cell (W_c) 57 + , lstmBc :: !(R output) -- Bias Cell (b_c) 58 + } -> LSTMWeights input output 59 + 60 + instance Show (LSTM i o) where 61 + show LSTM {} = "LSTM" 62 + 63 + instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where 64 + -- The gradients are the same shape as the weights and momentum 65 + -- This seems to be a general pattern, maybe it should be enforced. 66 + type Gradient (LSTM i o) = (LSTMWeights i o) 67 + 68 + -- Run the update function for each group matrix/vector of weights, momentums and gradients. 69 + -- Hmm, maybe the function should be used instead of passing in the learning parameters. 70 + runUpdate LearningParameters {..} (LSTM w m) g = 71 + let (wf, wf') = u lstmWf w m g 72 + (uf, uf') = u lstmUf w m g 73 + (bf, bf') = v lstmBf w m g 74 + (wi, wi') = u lstmWi w m g 75 + (ui, ui') = u lstmUi w m g 76 + (bi, bi') = v lstmBi w m g 77 + (wo, wo') = u lstmWo w m g 78 + (uo, uo') = u lstmUo w m g 79 + (bo, bo') = v lstmBo w m g 80 + (wc, wc') = u lstmWc w m g 81 + (bc, bc') = v lstmBc w m g 82 + in LSTM (LSTMWeights wf uf bf wi ui bi wo uo bo wc bc) (LSTMWeights wf' uf' bf' wi' ui' bi' wo' uo' bo' wc' bc') 83 + where 84 + -- Utility function for updating with the momentum, gradients, and weights. 85 + u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix)) 86 + u e (e -> weights) (e -> momentum) (e -> gradient) = 87 + decendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum 88 + 89 + v :: forall x ix. 
(KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
90 +     v e (e -> weights) (e -> momentum) (e -> gradient) =
91 +       decendVector learningRate learningMomentum learningRegulariser weights gradient momentum
92 +
93 +   -- There's a lot of updates here, so to try and minimise the number of data copies
94 +   -- we'll create a mutable bucket for each.
95 +   -- runUpdates rate lstm gs =
96 +   --   let combinedGradient = foldl1' uu gs
97 +   --   in  runUpdate rate lstm combinedGradient
98 +   --   where
99 +   --     uu :: (KnownNat i, KnownNat o) => LSTMWeights i o -> LSTMWeights i o -> LSTMWeights i o
100 +   --     uu a b =
101 +   --       let wf = u lstmWf a b
102 +   --           uf = u lstmUf a b
103 +   --           bf = v lstmBf a b
104 +   --           wi = u lstmWi a b
105 +   --           ui = u lstmUi a b
106 +   --           bi = v lstmBi a b
107 +   --           wo = u lstmWo a b
108 +   --           uo = u lstmUo a b
109 +   --           bo = v lstmBo a b
110 +   --           wc = u lstmWc a b
111 +   --           bc = v lstmBc a b
112 +   --       in LSTMWeights wf uf bf wi ui bi wo uo bo wc bc
113 +   --     u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> L out ix
114 +   --     u e (e -> a) (e -> b) = tr $ tr a + tr b
115 +
116 +   --     v :: forall x ix. (x -> (R ix)) -> x -> x -> R ix
117 +   --     v e (e -> a) (e -> b) = a + b
118 +
119 +   createRandom = randomLSTM
120 +
121 + instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where
122 +   -- The recurrent shape is the same size as the output.
123 +   -- It's actually the cell state however, as this is a peephole variety LSTM.
124 +   type RecurrentShape (LSTM i o) = 'D1 o
125 +
126 + instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where
127 +   -- Forward propagation for the LSTM layer.
128 +   -- The size of the cell state is also the size of the output.
129 +   runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
130 +     let -- Forget state vector
131 +         f_t = sigmoid $ lstmBf + lstmWf #> input + lstmUf #> cell
132 +         -- Input state vector
133 +         i_t = sigmoid $ lstmBi + lstmWi #> input + lstmUi #> cell
134 +         -- Output state vector
135 +         o_t = sigmoid $ lstmBo + lstmWo #> input + lstmUo #> cell
136 +         -- Cell input state vector
137 +         c_x = tanh $ lstmBc + lstmWc #> input
138 +         -- Cell state
139 +         c_t = f_t * cell + i_t * c_x
140 +         -- Output (it's sometimes recommended to use tanh c_t)
141 +         h_t = o_t * c_t
142 +     in (S1D c_t, S1D h_t)
143 +
144 +   -- Run a backpropagation step for an LSTM layer.
145 +   -- We're doing all the derivatives by hand here, so one should
146 +   -- be extra careful when changing this.
147 +   --
148 +   -- There's a test version using the AD library without hmatrix in the test
149 +   -- suite. These should always match.
150 +   runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) (S1D cellGrad) (S1D h_t') =
151 +     -- We're not keeping the Wengert tape during the forward pass,
152 +     -- so we're duplicating some work here.
153 +     --
154 +     -- If I was being generous, I'd call it checkpointing.
155 +     --
156 +     -- Maybe think about better ways to store some intermediate states.
157 +     let -- Forget state vector
158 +         f_s = lstmBf + lstmWf #> input + lstmUf #> cell
159 +         f_t = sigmoid f_s
160 +         -- Input state vector
161 +         i_s = lstmBi + lstmWi #> input + lstmUi #> cell
162 +         i_t = sigmoid i_s
163 +         -- Output state vector
164 +         o_s = lstmBo + lstmWo #> input + lstmUo #> cell
165 +         o_t = sigmoid o_s
166 +         -- Cell input state vector
167 +         c_s = lstmBc + lstmWc #> input
168 +         c_x = tanh c_s
169 +         -- Cell state
170 +         c_t = f_t * cell + i_t * c_x
171 +
172 +         -- Reverse mode AD derivatives
173 +         c_t' = h_t' * o_t + cellGrad
174 +
175 +         f_t' = c_t' * cell
176 +         f_s' = sigmoid' f_s * f_t'
177 +
178 +         o_t' = h_t' * c_t
179 +         o_s' = sigmoid' o_s * o_t'
180 +
181 +         i_t' = c_t' * c_x
182 +         i_s' = sigmoid' i_s * i_t'
183 +
184 +         c_x' = c_t' * i_t
185 +         c_s' = tanh' c_s * c_x'
186 +
187 +         -- The derivatives to pass sideways (recurrent) and downwards
188 +         cell'  = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t
189 +         input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s'
190 +
191 +         -- Calculate the gradient matrices for the input
192 +         lstmWf' = f_s' `outer` input
193 +         lstmWi' = i_s' `outer` input
194 +         lstmWo' = o_s' `outer` input
195 +         lstmWc' = c_s' `outer` input
196 +
197 +         -- Calculate the gradient matrices for the cell
198 +         lstmUf' = f_s' `outer` cell
199 +         lstmUi' = i_s' `outer` cell
200 +         lstmUo' = o_s' `outer` cell
201 +
202 +         -- The biases just get the values, but we'll write it so it's obvious
203 +         lstmBf' = f_s'
204 +         lstmBi' = i_s'
205 +         lstmBo' = o_s'
206 +         lstmBc' = c_s'
207 +
208 +         gradients = LSTMWeights lstmWf' lstmUf' lstmBf' lstmWi' lstmUi' lstmBi' lstmWo' lstmUo' lstmBo' lstmWc' lstmBc'
209 +     in (gradients, S1D cell', S1D input')
210 +
211 + -- | Generate an LSTM layer with random weights.
212 + --   One can also just call createRandom from UpdateLayer.
213 + --
214 + --   Has forget gate biases set to 1 to encourage early learning.
215 + --
216 + --   https://github.com/karpathy/char-rnn/commit/0dfeaa454e687dd0278f036552ea1e48a0a408c9
217 + --
218 + randomLSTM :: forall m i o. (MonadRandom m, KnownNat i, KnownNat o)
219 +            => m (LSTM i o)
220 + randomLSTM = do
221 +   let w = (\s -> uniformSample s (-1) 1 )           <$> getRandom
222 +       u = (\s -> uniformSample s (-1) 1 )           <$> getRandom
223 +       v = (\s -> randomVector s Uniform * 2 - 1)    <$> getRandom
224 +
225 +       w0 = konst 0
226 +       u0 = konst 0
227 +       v0 = konst 0
228 +
229 +   LSTM <$> (LSTMWeights <$> w <*> u <*> pure (konst 1) <*> w <*> u <*> v <*> w <*> u <*> v <*> w <*> v)
230 +        <*> pure (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
231 +
232 + -- | Maths
233 + --
234 + --   TODO: move elsewhere
235 + sigmoid :: Floating a => a -> a
236 + sigmoid x = 1 / (1 + exp (-x))
237 +
238 + sigmoid' :: Floating a => a -> a
239 + sigmoid' x = logix * (1 - logix)
240 +   where
241 +     logix = sigmoid x
242 +
243 + tanh' :: (Floating a) => a -> a
244 + tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
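Since the recurrent shape here is the cell state rather than the previous output, a single step can be sketched as follows (hedged; the sizes mirror the property tests below):

```haskell
-- One forward step of an LSTM 3 2: the cell state c_t is threaded as the
-- recurrent shape, and the visible output is h_t = o_t * c_t.
lstmStep :: MonadRandom m => m (S ('D1 2), S ('D1 2))
lstmStep = do
  lstm <- randomLSTM
  let x  = S1D (konst 1) :: S ('D1 3)  -- input
      c0 = S1D (konst 0) :: S ('D1 2)  -- initial cell state
  return (runRecurrentForwards (lstm :: LSTM 3 2) c0 x)
```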
+23
src/Grenade/Recurrent/Layers/Trivial.hs
··· 1 + {-# LANGUAGE DataKinds #-} 2 + {-# LANGUAGE TypeOperators #-} 3 + {-# LANGUAGE TypeFamilies #-} 4 + {-# LANGUAGE MultiParamTypeClasses #-} 5 + {-# LANGUAGE FlexibleInstances #-} 6 + module Grenade.Recurrent.Layers.Trivial ( 7 + Trivial (..) 8 + ) where 9 + 10 + import Grenade.Core.Network 11 + 12 + -- | A trivial layer. 13 + data Trivial = Trivial 14 + deriving Show 15 + 16 + instance UpdateLayer Trivial where 17 + type Gradient Trivial = () 18 + runUpdate _ _ _ = Trivial 19 + createRandom = return Trivial 20 + 21 + instance (a ~ b) => Layer Trivial a b where 22 + runForwards _ = id 23 + runBackwards _ _ y = ((), y)
+84
src/Grenade/Utils/OneHot.hs
··· 1 + {-# LANGUAGE DataKinds #-}
2 + {-# LANGUAGE GADTs #-}
3 + {-# LANGUAGE TypeFamilies #-}
4 + {-# LANGUAGE TypeOperators #-}
5 + {-# LANGUAGE FlexibleContexts #-}
6 + {-# LANGUAGE ScopedTypeVariables #-}
7 + {-# LANGUAGE RankNTypes #-}
8 +
9 + module Grenade.Utils.OneHot (
10 +     oneHot
11 +   , hotMap
12 +   , makeHot
13 +   , unHot
14 +   ) where
15 +
16 + import           Data.List ( group, sort )
17 +
18 + import           Data.Map ( Map )
19 + import qualified Data.Map as M
20 +
21 + import           Data.Proxy
22 + import           Data.Singletons.TypeLits
23 +
24 + import           Data.Vector ( Vector )
25 + import qualified Data.Vector as V
26 +
27 + import           Numeric.LinearAlgebra ( maxIndex )
28 + import           Numeric.LinearAlgebra.Devel
29 + import           Numeric.LinearAlgebra.Static
30 +
31 + import           Grenade.Core.Shape
32 +
33 + -- | From an int which is hot, create a 1D Shape
34 + --   with one index hot (1) and the rest 0.
35 + --   Returns Nothing if the hot number is larger
36 + --   than the length of the vector.
37 + oneHot :: forall n. (KnownNat n)
38 +        => Int -> Maybe (S ('D1 n))
39 + oneHot hot =
40 +   let len = fromIntegral $ natVal (Proxy :: Proxy n)
41 +   in  if hot < len
42 +       then
43 +         fmap S1D . create $ runSTVector $ do
44 +           vec <- newVector 0 len
45 +           writeVector vec hot 1
46 +           return vec
47 +       else Nothing
48 +
49 + -- | Create a one hot map from a list of values.
50 + --   Returns a map, and the ordered vector for the reverse transformation.
51 + hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Maybe (Map a Int, Vector a)
52 + hotMap n as =
53 +   let len  = fromIntegral $ natVal n
54 +       uniq = [ c | (c:_) <- group $ sort as]
55 +       hotl = length uniq
56 +   in  if hotl <= len
57 +       then
58 +         Just (M.fromList $ zip uniq [0..], V.fromList uniq)
59 +       else Nothing
60 +
61 + -- | From a map and value, create a 1D Shape
62 + --   with one index hot (1) and the rest 0.
63 + --   Returns Nothing if the hot number is larger
64 + --   than the length of the vector or the map
65 + --   doesn't contain the value.
66 + makeHot :: forall a n. (Ord a, KnownNat n)
67 +         => Map a Int -> a -> Maybe (S ('D1 n))
68 + makeHot m x = do
69 +   hot <- M.lookup x m
70 +   let len = fromIntegral $ natVal (Proxy :: Proxy n)
71 +   if hot < len
72 +     then
73 +       fmap S1D . create $ runSTVector $ do
74 +         vec <- newVector 0 len
75 +         writeVector vec hot 1
76 +         return vec
77 +     else Nothing
78 +
79 + unHot :: forall a n. (KnownNat n)
80 +       => Vector a -> (S ('D1 n)) -> Maybe a
81 + unHot v (S1D xs)
82 +   = (V.!?) v
83 +   $ maxIndex (extract xs)
84 +
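A hedged round-trip sketch of these helpers (the values are illustrative):

```haskell
-- Three distinct characters one-hot encoded into a 4-slot shape.
roundTrip :: Maybe (Maybe Char)
roundTrip = do
  (m, v) <- hotMap (Proxy :: Proxy 4) "abca"    -- unique values: "abc"
  hot    <- makeHot m 'b' :: Maybe (S ('D1 4))  -- index 1 is hot
  return (unHot v hot)                          -- expected: Just 'b'
```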
+20 -10
test/Test/Grenade/Layers/Convolution.hs
··· 1 - {-# LANGUAGE TemplateHaskell #-} 1 + {-# LANGUAGE TemplateHaskell #-} 2 2 {-# LANGUAGE DataKinds #-} 3 - {-# LANGUAGE GADTs #-} 4 - {-# LANGUAGE ScopedTypeVariables #-} 5 - {-# LANGUAGE KindSignatures #-} 6 - {-# LANGUAGE ConstraintKinds #-} 3 + {-# LANGUAGE GADTs #-} 4 + {-# LANGUAGE ScopedTypeVariables #-} 5 + {-# LANGUAGE KindSignatures #-} 6 + {-# LANGUAGE ConstraintKinds #-} 7 7 {-# LANGUAGE TypeOperators #-} 8 - 9 8 {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 10 9 module Test.Grenade.Layers.Convolution where 11 10 ··· 30 29 instance Show OpaqueConvolution where 31 30 show (OpaqueConvolution n) = show n 32 31 32 + genConvolution :: ( KnownNat channels 33 + , KnownNat filters 34 + , KnownNat kernelRows 35 + , KnownNat kernelColumns 36 + , KnownNat strideRows 37 + , KnownNat strideColumns 38 + , KnownNat kernelFlattened 39 + , kernelFlattened ~ (kernelRows * kernelColumns * channels) 40 + ) => Jack (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) 41 + genConvolution = Convolution <$> uniformSample <*> uniformSample 42 + 33 43 genOpaqueOpaqueConvolution :: Jack OpaqueConvolution 34 44 genOpaqueOpaqueConvolution = do 35 45 Just channels <- someNatVal <$> choose (1, 10) ··· 46 56 p2 = natDict pkc 47 57 p3 = natDict pch 48 58 in case p1 %* p2 %* p3 of 49 - Dict -> OpaqueConvolution <$> (Convolution <$> uniformSample <*> uniformSample :: Jack (Convolution ch fl kr kc sr sc)) 59 + Dict -> OpaqueConvolution <$> (genConvolution :: Jack (Convolution ch fl kr kc sr sc)) 50 60 51 61 prop_conv_net_witness = 52 62 gamble genOpaqueOpaqueConvolution $ \onet -> ··· 80 90 , (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows))) 81 91 , (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of 82 92 (Dict, Dict, Dict, Dict) -> 83 - gamble (S3D' <$> uniformSample) $ \(input :: S' ('D3 inRows inCols channels)) -> 84 - let output :: S' ('D3 outRows outCols filters) = runForwards convLayer input 85 - backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S' ('D3 inRows inCols channels)) 93 + gamble (S3D <$> uniformSample) $ \(input :: S ('D3 inRows inCols channels)) -> 94 + let output :: S ('D3 outRows outCols filters) = runForwards convLayer input 95 + backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S ('D3 inRows inCols channels)) 86 96 = runBackwards convLayer input output 87 97 in backed `seq` True 88 98 ) :: Property
+4 -4
test/Test/Grenade/Layers/FullyConnected.hs
··· 44 44 prop_fully_connected_forwards :: Property 45 45 prop_fully_connected_forwards = 46 46 gamble genOpaqueFullyConnected $ \(OpaqueFullyConnected (fclayer :: FullyConnected i o)) -> 47 - gamble (S1D' <$> randomVector) $ \(input :: S' ('D1 i)) -> 48 - let output :: S' ('D1 o) = runForwards fclayer input 49 - backed :: (Gradient (FullyConnected i o), S' ('D1 i)) 50 - = runBackwards fclayer input output 47 + gamble (S1D <$> randomVector) $ \(input :: S ('D1 i)) -> 48 + let output :: S ('D1 o) = runForwards fclayer input 49 + backed :: (Gradient (FullyConnected i o), S ('D1 i)) 50 + = runBackwards fclayer input output 51 51 in backed `seq` True 52 52 53 53 return []
+5 -5
test/Test/Grenade/Layers/Pooling.hs
··· 1 - {-# LANGUAGE TemplateHaskell #-} 2 - {-# LANGUAGE DataKinds #-} 3 - {-# LANGUAGE KindSignatures #-} 4 - {-# LANGUAGE GADTs #-} 5 - {-# LANGUAGE ScopedTypeVariables #-} 1 + {-# LANGUAGE TemplateHaskell #-} 2 + {-# LANGUAGE DataKinds #-} 3 + {-# LANGUAGE KindSignatures #-} 4 + {-# LANGUAGE GADTs #-} 5 + {-# LANGUAGE ScopedTypeVariables #-} 6 6 {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 7 7 module Test.Grenade.Layers.Pooling where 8 8
+101
test/Test/Grenade/Recurrent/Layers/LSTM.hs
··· 1 + {-# LANGUAGE TemplateHaskell #-} 2 + {-# LANGUAGE DataKinds #-} 3 + {-# LANGUAGE GADTs #-} 4 + {-# LANGUAGE ScopedTypeVariables #-} 5 + {-# LANGUAGE ConstraintKinds #-} 6 + {-# LANGUAGE TypeOperators #-} 7 + {-# LANGUAGE FlexibleContexts #-} 8 + {-# LANGUAGE RankNTypes #-} 9 + 10 + {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 11 + module Test.Grenade.Recurrent.Layers.LSTM where 12 + 13 + import Disorder.Jack 14 + 15 + import Data.Foldable ( toList ) 16 + import Data.Singletons.TypeLits 17 + 18 + import Grenade 19 + import Grenade.Recurrent 20 + 21 + import qualified Numeric.LinearAlgebra as H 22 + import qualified Numeric.LinearAlgebra.Static as S 23 + 24 + 25 + import qualified Test.Grenade.Recurrent.Layers.LSTM.Reference as Reference 26 + import Test.Jack.Hmatrix 27 + 28 + genLSTM :: forall i o. (KnownNat i, KnownNat o) => Jack (LSTM i o) 29 + genLSTM = do 30 + let w = uniformSample 31 + u = uniformSample 32 + v = randomVector 33 + 34 + w0 = S.konst 0 35 + u0 = S.konst 0 36 + v0 = S.konst 0 37 + 38 + LSTM <$> (LSTMWeights <$> w <*> u <*> v <*> w <*> u <*> v <*> w <*> u <*> v <*> w <*> v) 39 + <*> pure (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0) 40 + 41 + prop_lstm_reference_forwards = 42 + gamble randomVector $ \(input :: S.R 3) -> 43 + gamble randomVector $ \(cell :: S.R 2) -> 44 + gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) -> 45 + let actual = runRecurrentForwards net (S1D cell) (S1D input) 46 + in case actual of 47 + ((S1D cellOut) :: S ('D1 2), (S1D output) :: S ('D1 2)) -> 48 + let cellOut' = Reference.Vector . H.toList . S.extract $ cellOut 49 + output' = Reference.Vector . H.toList . S.extract $ output 50 + refNet = Reference.lstmToReference lstmWeights 51 + refCell = Reference.Vector . H.toList . S.extract $ cell 52 + refInput = Reference.Vector . H.toList . S.extract $ input 53 + (refCO, refO) = Reference.runLSTM refNet refCell refInput 54 + in toList refCO ~~~ toList cellOut' .&&. toList refO ~~~ toList output' 55 + 56 + prop_lstm_reference_backwards = 57 + gamble randomVector $ \(input :: S.R 3) -> 58 + gamble randomVector $ \(cell :: S.R 2) -> 59 + gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) -> 60 + let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2)) 61 + in case actualBacks of 62 + (actualGradients, _, _) -> 63 + let refNet = Reference.lstmToReference lstmWeights 64 + refCell = Reference.Vector . H.toList . S.extract $ cell 65 + refInput = Reference.Vector . H.toList . S.extract $ input 66 + refGradients = Reference.runLSTMback refCell refInput refNet 67 + in toList refGradients ~~~ toList (Reference.lstmToReference actualGradients) 68 + 69 + prop_lstm_reference_backwards_input = 70 + gamble randomVector $ \(input :: S.R 3) -> 71 + gamble randomVector $ \(cell :: S.R 2) -> 72 + gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) -> 73 + let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2)) 74 + in case actualBacks of 75 + (_, _, S1D actualGradients) -> 76 + let refNet = Reference.lstmToReference lstmWeights 77 + refCell = Reference.Vector . H.toList . S.extract $ cell 78 + refInput = Reference.Vector . H.toList . 
S.extract $ input
79 +               refGradients = Reference.runLSTMbackOnInput refCell refNet refInput
80 +           in  toList refGradients ~~~ H.toList (S.extract actualGradients)
81 +
82 + prop_lstm_reference_backwards_cell =
83 +   gamble randomVector $ \(input :: S.R 3) ->
84 +     gamble randomVector $ \(cell :: S.R 2) ->
85 +       gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
86 +         let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
87 +         in  case actualBacks of
88 +           (_, S1D actualGradients, _) ->
89 +             let refNet       = Reference.lstmToReference lstmWeights
90 +                 refCell      = Reference.Vector . H.toList . S.extract $ cell
91 +                 refInput     = Reference.Vector . H.toList . S.extract $ input
92 +                 refGradients = Reference.runLSTMbackOnCell refInput refNet refCell
93 +             in  toList refGradients ~~~ (H.toList . S.extract $ actualGradients)
94 +
95 +
96 + (~~~) as bs = all ((< 1e-8) . abs) (zipWith (-) as bs)
97 + infix 4 ~~~
98 +
99 + return []
100 + tests :: IO Bool
101 + tests = $quickCheckAll
+149
test/Test/Grenade/Recurrent/Layers/LSTM/Reference.hs
··· 1 + {-# LANGUAGE DataKinds #-} 2 + {-# LANGUAGE GADTs #-} 3 + {-# LANGUAGE ScopedTypeVariables #-} 4 + {-# LANGUAGE ConstraintKinds #-} 5 + {-# LANGUAGE TypeOperators #-} 6 + {-# LANGUAGE DeriveFunctor #-} 7 + {-# LANGUAGE DeriveFoldable #-} 8 + {-# LANGUAGE DeriveTraversable #-} 9 + {-# LANGUAGE RecordWildCards #-} 10 + {-# LANGUAGE FlexibleContexts #-} 11 + {-# LANGUAGE RankNTypes #-} 12 + 13 + {-# OPTIONS_GHC -fno-warn-missing-signatures #-} 14 + module Test.Grenade.Recurrent.Layers.LSTM.Reference where 15 + 16 + import Data.Reflection 17 + import Numeric.AD.Mode.Reverse 18 + import Numeric.AD.Internal.Reverse ( Tape ) 19 + 20 + import qualified Grenade.Recurrent.Layers.LSTM as LSTM 21 + import qualified Numeric.LinearAlgebra.Static as S 22 + import qualified Numeric.LinearAlgebra as H 23 + 24 + -- 25 + -- This module contains a set of list only versions of 26 + -- an LSTM layer which can be used with the AD library. 27 + -- 28 + -- Using this, we can check to make sure that our fast 29 + -- back propagation implementation is correct. 30 + -- 31 + 32 + -- | List only matrix deriving functor 33 + data Matrix a = Matrix { 34 + matrixWeights :: [[a]] 35 + } deriving (Functor, Foldable, Traversable, Eq, Show) 36 + 37 + -- | List only vector deriving functor 38 + data Vector a = Vector { 39 + vectorWeights :: [a] 40 + } deriving (Functor, Foldable, Traversable, Eq, Show) 41 + 42 + -- | List only LSTM weights 43 + data RefLSTM a = RefLSTM 44 + { refLstmWf :: Matrix a -- Weight Forget (W_f) 45 + , refLstmUf :: Matrix a -- Cell State Forget (U_f) 46 + , refLstmBf :: Vector a -- Bias Forget (b_f) 47 + , refLstmWi :: Matrix a -- Weight Input (W_i) 48 + , refLstmUi :: Matrix a -- Cell State Input (U_i) 49 + , refLstmBi :: Vector a -- Bias Input (b_i) 50 + , refLstmWo :: Matrix a -- Weight Output (W_o) 51 + , refLstmUo :: Matrix a -- Cell State Output (U_o) 52 + , refLstmBo :: Vector a -- Bias Output (b_o) 53 + , refLstmWc :: Matrix a -- Weight Cell (W_c) 54 + , refLstmBc :: Vector a -- Bias Cell (b_c) 55 + } deriving (Functor, Foldable, Traversable, Eq, Show) 56 + 57 + lstmToReference :: LSTM.LSTMWeights a b -> RefLSTM Double 58 + lstmToReference LSTM.LSTMWeights {..} = 59 + let refLstmWf = Matrix . H.toLists . S.extract $ lstmWf -- Weight Forget (W_f) 60 + refLstmUf = Matrix . H.toLists . S.extract $ lstmUf -- Cell State Forget (U_f) 61 + refLstmBf = Vector . H.toList . S.extract $ lstmBf -- Bias Forget (b_f) 62 + refLstmWi = Matrix . H.toLists . S.extract $ lstmWi -- Weight Input (W_i) 63 + refLstmUi = Matrix . H.toLists . S.extract $ lstmUi -- Cell State Input (U_i) 64 + refLstmBi = Vector . H.toList . S.extract $ lstmBi -- Bias Input (b_i) 65 + refLstmWo = Matrix . H.toLists . S.extract $ lstmWo -- Weight Output (W_o) 66 + refLstmUo = Matrix . H.toLists . S.extract $ lstmUo -- Cell State Output (U_o) 67 + refLstmBo = Vector . H.toList . S.extract $ lstmBo -- Bias Output (b_o) 68 + refLstmWc = Matrix . H.toLists . S.extract $ lstmWc -- Weight Cell (W_c) 69 + refLstmBc = Vector . H.toList . 
S.extract $ lstmBc -- Bias Cell (b_c) 70 + in RefLSTM {..} 71 + 72 + runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a) 73 + runLSTM RefLSTM {..} cell input = 74 + let -- Forget state vector 75 + f_t = sigmoid $ refLstmBf #+ refLstmWf #> input #+ refLstmUf #> cell 76 + -- Input state vector 77 + i_t = sigmoid $ refLstmBi #+ refLstmWi #> input #+ refLstmUi #> cell 78 + -- Output state vector 79 + o_t = sigmoid $ refLstmBo #+ refLstmWo #> input #+ refLstmUo #> cell 80 + -- Cell input state vector 81 + c_x = fmap tanh $ refLstmBc #+ refLstmWc #> input 82 + -- Cell state 83 + c_t = f_t #* cell #+ i_t #* c_x 84 + -- Output (it's sometimes recommended to use tanh c_t) 85 + h_t = o_t #* c_t 86 + in (c_t, h_t) 87 + 88 + runLSTMback :: forall a. Floating a => Vector a -> Vector a -> RefLSTM a -> RefLSTM a 89 + runLSTMback cell input = 90 + grad f 91 + where 92 + f :: forall s. Reifies s Tape => RefLSTM (Reverse s a) -> Reverse s a 93 + f net = 94 + let cell' = fmap auto cell 95 + input' = fmap auto input 96 + (cells, forwarded) = runLSTM net cell' input' 97 + in sum forwarded + sum cells 98 + 99 + runLSTMbackOnInput :: forall a. Floating a => Vector a -> RefLSTM a -> Vector a -> Vector a 100 + runLSTMbackOnInput cell net = 101 + grad f 102 + where 103 + f :: forall s. Reifies s Tape => Vector (Reverse s a) -> Reverse s a 104 + f input = 105 + let cell' = fmap auto cell 106 + net' = fmap auto net 107 + (cells, forwarded) = runLSTM net' cell' input 108 + in sum forwarded + sum cells 109 + 110 + runLSTMbackOnCell :: forall a. Floating a => Vector a -> RefLSTM a -> Vector a -> Vector a 111 + runLSTMbackOnCell input net = 112 + grad f 113 + where 114 + f :: forall s. Reifies s Tape => Vector (Reverse s a) -> Reverse s a 115 + f cell = 116 + let input' = fmap auto input 117 + net' = fmap auto net 118 + (cells, forwarded) = runLSTM net' cell input' 119 + in sum forwarded + sum cells 120 + 121 + -- | Helper to multiply a matrix by a vector 122 + matMult :: Num a => Matrix a -> Vector a -> Vector a 123 + matMult (Matrix m) (Vector v) = Vector result 124 + where 125 + lrs = map length m 126 + l = length v 127 + result = if all (== l) lrs 128 + then map (\r -> sum $ zipWith (*) r v) m 129 + else error $ "Matrix has rows of length " ++ show lrs ++ 130 + " but vector is of length " ++ show l 131 + 132 + (#>) :: Num a => Matrix a -> Vector a -> Vector a 133 + (#>) = matMult 134 + infixr 8 #> 135 + 136 + (#+) :: Num a => Vector a -> Vector a -> Vector a 137 + (#+) (Vector as) (Vector bs) = Vector $ zipWith (+) as bs 138 + infixl 6 #+ 139 + 140 + (#-) :: Num a => Vector a -> Vector a -> Vector a 141 + (#-) (Vector as) (Vector bs) = Vector $ zipWith (-) as bs 142 + infixl 6 #- 143 + 144 + (#*) :: Num a => Vector a -> Vector a -> Vector a 145 + (#*) (Vector as) (Vector bs) = Vector $ zipWith (*) as bs 146 + infixl 7 #* 147 + 148 + sigmoid :: (Functor f, Floating a) => f a -> f a 149 + sigmoid xs = (\x -> 1 / (1 + exp (-x))) <$> xs
+2 -5
test/Test/Jack/Hmatrix.hs
··· 4 4 5 5 module Test.Jack.Hmatrix where 6 6 7 - import Data.Proxy 8 7 import Disorder.Jack 9 8 10 9 import GHC.TypeLits ··· 12 11 import qualified Numeric.LinearAlgebra.Static as HStatic 13 12 14 13 randomVector :: forall n. KnownNat n => Jack (HStatic.R n) 15 - randomVector = HStatic.fromList <$> vectorOf (fromInteger (natVal (Proxy :: Proxy n))) sizedRealFrac 14 + randomVector = (\s -> HStatic.randomVector s HStatic.Uniform * 2 - 1) <$> sizedNat 16 15 17 16 uniformSample :: forall m n. (KnownNat m, KnownNat n) => Jack (HStatic.L m n) 18 - uniformSample = HStatic.fromList 19 - <$> vectorOf (fromInteger (natVal (Proxy :: Proxy m)) * fromInteger (natVal (Proxy :: Proxy n))) 20 - sizedRealFrac 17 + uniformSample = (\s -> HStatic.uniformSample s (-1) 1 ) <$> sizedNat
+8 -5
test/test.hs
··· 1 1 import Disorder.Core.Main 2 2 3 - import qualified Test.Grenade.Layers.Pooling as Test.Grenade.Layers.Pooling 4 - import qualified Test.Grenade.Layers.Convolution as Test.Grenade.Layers.Convolution 5 - import qualified Test.Grenade.Layers.FullyConnected as Test.Grenade.Layers.FullyConnected 3 + import qualified Test.Grenade.Layers.Pooling 4 + import qualified Test.Grenade.Layers.Convolution 5 + import qualified Test.Grenade.Layers.FullyConnected 6 6 7 - import qualified Test.Grenade.Layers.Internal.Convolution as Test.Grenade.Layers.Internal.Convolution 8 - import qualified Test.Grenade.Layers.Internal.Pooling as Test.Grenade.Layers.Internal.Pooling 7 + import qualified Test.Grenade.Layers.Internal.Convolution 8 + import qualified Test.Grenade.Layers.Internal.Pooling 9 9 10 + import qualified Test.Grenade.Recurrent.Layers.LSTM 10 11 11 12 main :: IO () 12 13 main = ··· 17 18 18 19 , Test.Grenade.Layers.Internal.Convolution.tests 19 20 , Test.Grenade.Layers.Internal.Pooling.tests 21 + 22 + , Test.Grenade.Recurrent.Layers.LSTM.tests 20 23 ]