···5252To perform back propagation, one can call the eponymous function
5353```haskell
5454backPropagate :: forall input target shapes layers. (Head shapes ~ input, Last shapes ~ target)
5555- => Network layers shapes -> S' input -> S' target -> Gradients layers
5555+ => Network layers shapes -> S input -> S target -> Gradients layers
5656```
5757which takes a network, appropriate input and target data, and returns the
5858back propagated gradients for the network. The shapes of the gradients are
···11#include "im2col.h"
2233-void im2col_cpu(const double* data_im, int dataOffset, const int channels,
33+void im2col_cpu(const double* data_im, const int channels,
44 const int height, const int width, const int kernel_h, const int kernel_w,
55 const int stride_h, const int stride_w,
66 double* data_col) {
7788- data_im += dataOffset;
98 const int channel_size = height * width;
1091110 for (int fitting_height = 0; fitting_height <= (height - kernel_h); fitting_height += stride_h) {
···2322 }
2423}
25242626-void col2im_cpu(const double* data_col, int dataOffset, const int channels,
2525+void col2im_cpu(const double* data_col, const int channels,
2726 const int height, const int width, const int kernel_h, const int kernel_w,
2827 const int stride_h, const int stride_w,
2928 double* data_im) {
30293130 memset(data_im, 0, height * width * channels * sizeof(double));
3232- data_col += dataOffset;
33313432 const int channel_size = height * width;
3533···50485149inline double max ( double a, double b ) { return a > b ? a : b; }
52505353-void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels,
5151+void pool_forwards_cpu(const double* data_im, const int channels,
5452 const int height, const int width, const int kernel_h, const int kernel_w,
5553 const int stride_h, const int stride_w,
5654 double* data_pooled) {
57555858- data_im += dataOffset;
5959-6056 const int channel_size = height * width;
61576258 for (int channel = 0; channel < channels; channel++) {
···8985 }
9086}
91879292-void pool_backwards_cpu(const double* data_im, int data_im_offset,
9393- const double* data_pooled, int data_pooled_offset,
8888+void pool_backwards_cpu(const double* data_im, const double* data_pooled,
9489 const int channels, const int height, const int width, const int kernel_h,
9590 const int kernel_w, const int stride_h, const int stride_w,
9691 double* data_backgrad ) {
97929898- data_im += data_im_offset;
9999- data_pooled += data_pooled_offset;
10093 memset(data_backgrad, 0, height * width * channels * sizeof(double));
1019410295 const int channel_size = height * width;
+4-5
cbits/im2col.h
···22#include <stdint.h>
33#include <string.h>
4455-void im2col_cpu(const double* data_im, int dataOffset, const int channels,
55+void im2col_cpu(const double* data_im, const int channels,
66 const int height, const int width, const int kernel_h, const int kernel_w,
77 const int stride_h, const int stride_w,
88 double* data_col);
991010-void col2im_cpu(const double* data_col, int dataOffset, const int channels,
1010+void col2im_cpu(const double* data_col, const int channels,
1111 const int height, const int width, const int kernel_h, const int kernel_w,
1212 const int stride_h, const int stride_w,
1313 double* data_im);
14141515-void pool_forwards_cpu(const double* data_im, int dataOffset, const int channels,
1515+void pool_forwards_cpu(const double* data_im, const int channels,
1616 const int height, const int width, const int kernel_h, const int kernel_w,
1717 const int stride_h, const int stride_w,
1818 double* data_pooled);
19192020-void pool_backwards_cpu(const double* data_im, int data_im_offset,
2121- const double* data_pooled, int data_pooled_offset,
2020+void pool_backwards_cpu(const double* data_im, const double* data_pooled,
2221 const int channels, const int height, const int width, const int kernel_h,
2322 const int kernel_w, const int stride_h, const int stride_w,
2423 double* data_backgrad );
···11#!/bin/sh -eu
2233+: ${MAFIA_HOME:=$HOME/.mafia}
44+35fetch_latest () {
46 if [ -z ${MAFIA_TEST_MODE+x} ]; then
57 TZ=$(date +"%T")
···5557 # If we can't find the mafia version, then we need to upgrade the script.
5658 run_upgrade
5759 else
5858- MAFIA_BIN=$HOME/.ambiata/mafia/bin
6060+ MAFIA_BIN=$MAFIA_HOME/bin
5961 MAFIA_FILE=mafia-$MAFIA_VERSION
6062 MAFIA_PATH=$MAFIA_BIN/$MAFIA_FILE
6163···118120upgrade) shift; run_upgrade "$@" ;;
119121*) exec_mafia "$@"
120122esac
121121-# Version: a1b39ee8ac1969ed2e891b9062d079be75863e99
123123+# Version: 3044e63eb472fb9e16926d4ab2ca9dd9e255829c
+10-10
main/feedforward.hs
···44{-# LANGUAGE TypeOperators #-}
55{-# LANGUAGE TupleSections #-}
66{-# LANGUAGE TypeFamilies #-}
77-{-# LANGUAGE FlexibleContexts #-}
88-97import Control.Monad
108import Control.Monad.Random
99+import Data.List ( foldl' )
1010+1111import GHC.TypeLits
12121313import qualified Numeric.LinearAlgebra.Static as SA
···3434netTest rate n = do
3535 inps <- replicateM n $ do
3636 s <- getRandom
3737- return $ S1D' $ SA.randomVector s SA.Uniform * 2 - 1
3838- let outs = flip map inps $ \(S1D' v) ->
3737+ return $ S1D $ SA.randomVector s SA.Uniform * 2 - 1
3838+ let outs = flip map inps $ \(S1D v) ->
3939 if v `inCircle` (fromRational 0.33, 0.33) || v `inCircle` (fromRational (-0.33), 0.33)
4040- then S1D' $ fromRational 1
4141- else S1D' $ fromRational 0
4040+ then S1D $ fromRational 1
4141+ else S1D $ fromRational 0
4242 net0 <- randomNet
43434444- let trained = foldl trainEach net0 (zip inps outs)
4444+ let trained = foldl' trainEach net0 (zip inps outs)
4545 let testIns = [ [ (x,y) | x <- [0..50] ]
4646 | y <- [0..20] ]
47474848- let outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet trained (S1D' $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
4848+ let outMat = fmap (fmap (\(x,y) -> (render . normx) $ runNet trained (S1D $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
4949 return $ unlines outMat
50505151 where
···5959 | n' <= 0.8 = '='
6060 | otherwise = '#'
61616262- normx :: S' ('D1 1) -> Double
6363- normx (S1D' r) = SA.mean r
6262+ normx :: S ('D1 1) -> Double
6363+ normx (S1D r) = SA.mean r
64646565data FeedForwardOpts = FeedForwardOpts Int LearningParameters
6666
+13-12
main/mnist.hs
···55{-# LANGUAGE TupleSections #-}
66{-# LANGUAGE TypeFamilies #-}
77{-# LANGUAGE FlexibleContexts #-}
88-98import Control.Applicative
109import Control.Monad
1110import Control.Monad.Random
1212-import Control.Monad.Trans.Class
1311import Control.Monad.Trans.Except
14121513import qualified Data.Attoparsec.Text as A
1414+import Data.List ( foldl' )
1615import qualified Data.Text as T
1716import qualified Data.Text.IO as T
1717+import qualified Data.Vector.Storable as V
18181919-import Numeric.LinearAlgebra (maxIndex)
1919+import Numeric.LinearAlgebra ( maxIndex )
2020import qualified Numeric.LinearAlgebra.Static as SA
21212222import Options.Applicative
23232424import Grenade
2525+import Grenade.Utils.OneHot
25262627-- The definition of our convolutional neural network.
2728-- In the type signature, we have a type level list of shapes which are passed between the layers.
···4950 trainEach rate' !network (i, o) = train rate' network i o
50515152 runIteration trainRows validateRows net i = do
5252- let trained' = foldl (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows
5353+ let trained' = foldl' (trainEach ( rate { learningRate = learningRate rate * 0.9 ^ i} )) net trainRows
5354 let res = fmap (\(rowP,rowL) -> (rowL,) $ runNet trained' rowP) validateRows
5454- let res' = fmap (\(S1D' label, S1D' prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
5555+ let res' = fmap (\(S1D label, S1D prediction) -> (maxIndex (SA.extract label), maxIndex (SA.extract prediction))) res
5556 print trained'
5657 putStrLn $ "Iteration " ++ show i ++ ": " ++ show (length (filter ((==) <$> fst <*> snd) res')) ++ " of " ++ show (length res')
5758 return trained'
···6162mnist' :: Parser MnistOpts
6263mnist' = MnistOpts <$> argument str (metavar "TRAIN")
6364 <*> argument str (metavar "VALIDATE")
6464- <*> option auto (long "iterations" <> short 'i' <> value 10)
6565+ <*> option auto (long "iterations" <> short 'i' <> value 15)
6566 <*> (LearningParameters
6667 <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
6768 <*> option auto (long "momentum" <> value 0.9)
···7879 Right () -> pure ()
7980 Left err -> putStrLn err
80818181-readMNIST :: FilePath -> ExceptT String IO [(S' ('D2 28 28), S' ('D1 10))]
8282+readMNIST :: FilePath -> ExceptT String IO [(S ('D2 28 28), S ('D1 10))]
8283readMNIST mnist = ExceptT $ do
8384 mnistdata <- T.readFile mnist
8485 return $ traverse (A.parseOnly parseMNIST) (T.lines mnistdata)
85868686-parseMNIST :: A.Parser (S' ('D2 28 28), S' ('D1 10))
8787+parseMNIST :: A.Parser (S ('D2 28 28), S ('D1 10))
8788parseMNIST = do
8888- lab <- A.decimal
8989- pixels <- many (A.char ',' >> A.double)
9090- let lab' = replicate lab 0 ++ [1] ++ replicate (9 - lab) 0
9191- return (S2D' $ SA.fromList pixels, S1D' $ SA.fromList lab')
8989+ Just lab <- oneHot <$> A.decimal
9090+ pixels <- many (A.char ',' >> A.double)
9191+ image <- maybe (fail "Parsed row was of an incorrect size") pure (fromStorable . V.fromList $ pixels)
9292+ return (image, lab)
+87
main/recurrent.hs
···11+{-# LANGUAGE BangPatterns #-}
22+{-# LANGUAGE DataKinds #-}
33+{-# LANGUAGE ScopedTypeVariables #-}
44+{-# LANGUAGE TypeOperators #-}
55+{-# LANGUAGE TupleSections #-}
66+{-# LANGUAGE TypeFamilies #-}
77+88+import Control.Monad ( foldM )
99+import Control.Monad.Random ( MonadRandom, getRandomR )
1010+1111+import Data.List ( cycle, unfoldr )
1212+import qualified Numeric.LinearAlgebra.Static as SA
1313+1414+import Options.Applicative
1515+1616+import Grenade
1717+import Grenade.Recurrent
1818+1919+-- The defininition for our simple recurrent network.
2020+-- This file just trains a network to generate a repeating sequence
2121+-- of 0 0 1.
2222+--
2323+-- The F and R types are Tagging types to ensure that the runner and
2424+-- creation function know how to treat the layers.
2525+type F = FeedForward
2626+type R = Recurrent
2727+2828+type RecNet = RecurrentNetwork '[ R (LSTM 1 4), R (LSTM 4 1), F Trivial]
2929+ '[ 'D1 1, 'D1 4, 'D1 1, 'D1 1 ]
3030+3131+type RecInput = RecurrentInputs '[ R (LSTM 1 4), R (LSTM 4 1), F Trivial]
3232+3333+randomNet :: MonadRandom m => m (RecNet, RecInput)
3434+randomNet = randomRecurrent
3535+3636+netTest :: MonadRandom m => RecNet -> RecInput -> LearningParameters -> Int -> m (RecNet, RecInput)
3737+netTest net0 i0 rate iterations =
3838+ foldM trainIteration (net0,i0) [1..iterations]
3939+ where
4040+ trainingCycle = cycle [c 0, c 0, c 1]
4141+4242+ trainIteration (net, io) _ = do
4343+ dropping <- getRandomR (0, 2)
4444+ count <- getRandomR (5, 30)
4545+ let t = drop dropping trainingCycle
4646+ let example = ((,Nothing) <$> take count t) ++ [(t !! count, Just $ t !! (count + 1))]
4747+ return $ trainEach net io example
4848+4949+ trainEach !nt !io !ex = trainRecurrent rate nt io ex
5050+5151+data FeedForwardOpts = FeedForwardOpts Int LearningParameters
5252+5353+feedForward' :: Parser FeedForwardOpts
5454+feedForward' = FeedForwardOpts <$> option auto (long "examples" <> short 'e' <> value 20000)
5555+ <*> (LearningParameters
5656+ <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
5757+ <*> option auto (long "momentum" <> value 0.9)
5858+ <*> option auto (long "l2" <> value 0.0005)
5959+ )
6060+6161+generateRecurrent :: RecNet -> RecInput -> S ('D1 1) -> [Int]
6262+generateRecurrent n s i =
6363+ unfoldr go (s, i)
6464+ where
6565+ go (x, y) =
6666+ do let (ns, o) = runRecurrent n x y
6767+ o' = heat o
6868+ Just (o', (ns, fromIntegral o'))
6969+7070+ heat :: S ('D1 1) -> Int
7171+ heat x = case x of
7272+ (S1D v) -> round (SA.mean v)
7373+7474+main :: IO ()
7575+main = do
7676+ FeedForwardOpts examples rate <- execParser (info (feedForward' <**> helper) idm)
7777+ putStrLn "Training network..."
7878+7979+ (net0, i0) <- randomNet
8080+ (trained, bestInput) <- netTest net0 i0 rate examples
8181+8282+ let results = generateRecurrent trained bestInput (c 1)
8383+8484+ print . take 50 . drop 100 $ results
8585+8686+c :: Double -> S ('D1 1)
8787+c = S1D . SA.konst
+156
main/shakespeare.hs
···11+{-# LANGUAGE BangPatterns #-}
22+{-# LANGUAGE RecordWildCards #-}
33+{-# LANGUAGE DataKinds #-}
44+{-# LANGUAGE ScopedTypeVariables #-}
55+{-# LANGUAGE TypeOperators #-}
66+{-# LANGUAGE TupleSections #-}
77+{-# LANGUAGE TypeFamilies #-}
88+{-# LANGUAGE LambdaCase #-}
99+1010+import Control.Monad.Random
1111+import Control.Monad.Trans.Except
1212+1313+import Data.Char ( isUpper, toUpper, toLower )
1414+import Data.List ( unfoldr, foldl' )
1515+import Data.Maybe ( fromMaybe )
1616+1717+import qualified Data.Vector as V
1818+import Data.Vector ( Vector )
1919+2020+import qualified Data.Map as M
2121+import Data.Proxy ( Proxy (..) )
2222+2323+2424+import Data.Singletons.Prelude
2525+import GHC.TypeLits
2626+2727+import Numeric.LinearAlgebra.Static ( konst )
2828+2929+import Options.Applicative
3030+3131+import Grenade
3232+import Grenade.Recurrent
3333+import Grenade.Utils.OneHot
3434+3535+-- The defininition for our natural language recurrent network.
3636+-- This network is able to learn and generate simple words in
3737+-- about an hour.
3838+--
3939+-- This is a first class recurrent net, although it's similar to
4040+-- an unrolled graph.
4141+--
4242+-- The F and R types are tagging types to ensure that the runner and
4343+-- creation function know how to treat the layers.
4444+--
4545+-- As an example, here's a short sequence generated.
4646+--
4747+-- > the see and and the sir, and and the make and the make and go the make and go the make and the
4848+--
4949+type F = FeedForward
5050+type R = Recurrent
5151+5252+-- The definition of our network
5353+type Shakespeare = RecurrentNetwork '[ R (LSTM 40 40), F (FullyConnected 40 40), F Logit]
5454+ '[ 'D1 40, 'D1 40, 'D1 40, 'D1 40 ]
5555+5656+-- The definition of the "sideways" input, which the network if fed recurrently.
5757+type Shakespearian = RecurrentInputs '[ R (LSTM 40 40), F (FullyConnected 40 40), F Logit]
5858+5959+randomNet :: MonadRandom m => m (Shakespeare, Shakespearian)
6060+randomNet = randomRecurrent
6161+6262+-- | Load the data files and prepare a map of characters to a compressed int representation.
6363+loadShakespeare :: FilePath -> ExceptT String IO (Vector Int, M.Map Char Int, Vector Char)
6464+loadShakespeare path = do
6565+ contents <- lift $ readFile path
6666+ let annotated = annotateCapitals contents
6767+ (m,cs) <- ExceptT . return . note "Couldn't fit data in hotMap" $ hotMap (Proxy :: Proxy 40) annotated
6868+ hot <- ExceptT . return . note "Couldn't generate hot values" $ traverse (`M.lookup` m) annotated
6969+ return (V.fromList hot, m, cs)
7070+7171+trainSlice :: LearningParameters -> Shakespeare -> Shakespearian -> Vector Int -> Int -> Int -> (Shakespeare, Shakespearian)
7272+trainSlice !rate !net !recIns input offset size =
7373+ let e = fmap (x . oneHot) . V.toList $ V.slice offset size input
7474+ in case reverse e of
7575+ (o : l : xs) ->
7676+ let examples = reverse $ (l, Just o) : ((,Nothing) <$> xs)
7777+ in trainRecurrent rate net recIns examples
7878+ _ -> error "Not enough input"
7979+ where
8080+ x = fromMaybe (error "Hot variable didn't fit.")
8181+8282+runShakespeare :: ShakespeareOpts -> ExceptT String IO ()
8383+runShakespeare ShakespeareOpts {..} = do
8484+ (shakespeare, oneHotMap, oneHotDictionary) <- loadShakespeare trainingFile
8585+ (net0, i0) <- lift randomNet
8686+ lift $ foldM_ (\(!net, !io) size -> do
8787+ xs <- take (iterations `div` 15) <$> getRandomRs (0, length shakespeare - size - 1)
8888+ let (!trained, !bestInput) = foldl' (\(!n, !i) offset -> trainSlice rate n i shakespeare offset size) (net, io) xs
8989+ let results = take 100 $ generateParagraph trained bestInput oneHotMap oneHotDictionary ( S1D $ konst 0)
9090+ putStrLn ("TRAINING STEP WITH SIZE: " ++ show size)
9191+ putStrLn (unAnnotateCapitals results)
9292+ return (trained, bestInput)
9393+ ) (net0, i0) [10,10,15,15,20,20,25,25,30,30,35,35,40,40,50 :: Int]
9494+9595+generateParagraph :: forall layers shapes n a. (Last shapes ~ 'D1 n, Head shapes ~ 'D1 n, KnownNat n, Ord a)
9696+ => RecurrentNetwork layers shapes
9797+ -> RecurrentInputs layers
9898+ -> M.Map a Int
9999+ -> Vector a
100100+ -> S ('D1 n)
101101+ -> [a]
102102+generateParagraph n s hotmap hotdict i =
103103+ unfoldr go (s, i)
104104+ where
105105+ go (x, y) =
106106+ do let (ns, o) = runRecurrent n x y
107107+ un <- unHot hotdict o
108108+ re <- makeHot hotmap un
109109+ Just (un, (ns, re))
110110+111111+data ShakespeareOpts = ShakespeareOpts {
112112+ trainingFile :: FilePath
113113+ , iterations :: Int
114114+ , rate :: LearningParameters
115115+ }
116116+117117+shakespeare' :: Parser ShakespeareOpts
118118+shakespeare' = ShakespeareOpts <$> argument str (metavar "TRAIN")
119119+ <*> option auto (long "examples" <> short 'e' <> value 1000000)
120120+ <*> (LearningParameters
121121+ <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
122122+ <*> option auto (long "momentum" <> value 0.95)
123123+ <*> option auto (long "l2" <> value 0.000001)
124124+ )
125125+126126+main :: IO ()
127127+main = do
128128+ shopts <- execParser (info (shakespeare' <**> helper) idm)
129129+ res <- runExceptT $ runShakespeare shopts
130130+ case res of
131131+ Right () -> pure ()
132132+ Left err -> putStrLn err
133133+134134+135135+-- Replace capitals with an annotation and the lower case letter
136136+-- http://fastml.com/one-weird-trick-for-training-char-rnns/
137137+annotateCapitals :: String -> String
138138+annotateCapitals (x : rest)
139139+ | isUpper x
140140+ = '^' : toLower x : annotateCapitals rest
141141+ | otherwise
142142+ = x : annotateCapitals rest
143143+annotateCapitals []
144144+ = []
145145+146146+unAnnotateCapitals :: String -> String
147147+unAnnotateCapitals ('^' : x : rest)
148148+ = toUpper x : unAnnotateCapitals rest
149149+unAnnotateCapitals (x : rest)
150150+ = x : unAnnotateCapitals rest
151151+unAnnotateCapitals []
152152+ = []
153153+154154+-- | Tag the 'Nothing' value of a 'Maybe'
155155+note :: a -> Maybe b -> Either a b
156156+note a = maybe (Left a) Right
+30-15
src/Grenade/Core/Network.hs
···11{-# LANGUAGE DataKinds #-}
22{-# LANGUAGE GADTs #-}
33-{-# LANGUAGE KindSignatures #-}
44-{-# LANGUAGE ScopedTypeVariables #-}
53{-# LANGUAGE TypeOperators #-}
64{-# LANGUAGE TypeFamilies #-}
77-{-# LANGUAGE PolyKinds #-}
85{-# LANGUAGE MultiParamTypeClasses #-}
96{-# LANGUAGE FlexibleContexts #-}
107{-# LANGUAGE FlexibleInstances #-}
1111-{-# LANGUAGE LambdaCase #-}
88+{-|
99+Module : Grenade.Core.Network
1010+Description : Core definition a simple neural etwork
1111+Copyright : (c) Huw Campbell, 2016-2017
1212+License : BSD2
1313+Stability : experimental
1414+1515+This module defines the core data type for the simplest
1616+Neural network we support.
12171818+-}
1319module Grenade.Core.Network (
1420 Layer (..)
1521 , Network (..)
···2026 ) where
21272228import Control.Monad.Random (MonadRandom)
2323-2929+import Data.List ( foldl' )
3030+import Data.Singletons
24312532import Grenade.Core.Shape
26333434+-- | Learning parameters for stochastic gradient descent.
2735data LearningParameters = LearningParameters {
2836 learningRate :: Double
2937 , learningMomentum :: Double
···3341-- | Class for updating a layer. All layers implement this, and it is
3442-- shape independent.
3543class Show x => UpdateLayer x where
4444+ {-# MINIMAL runUpdate, createRandom #-}
3645 -- | The type for the gradient for this layer.
3746 -- Unit if there isn't a gradient to pass back.
3847 type Gradient x :: *
3948 -- | Update a layer with its gradient and learning parameters
4049 runUpdate :: LearningParameters -> x -> Gradient x -> x
5050+4151 -- | Create a random layer, many layers will use pure
4252 createRandom :: MonadRandom m => m x
43535454+ -- | Update a layer with many Gradients
5555+ runUpdates :: LearningParameters -> x -> [Gradient x] -> x
5656+ runUpdates rate = foldl' (runUpdate rate)
5757+4458-- | Class for a layer. All layers implement this, however, they don't
4559-- need to implement it for all shapes, only ones which are appropriate.
4660class UpdateLayer x => Layer x (i :: Shape) (o :: Shape) where
4761 -- | Used in training and scoring. Take the input from the previous
4862 -- layer, and give the output from this layer.
4949- runForwards :: x -> S' i -> S' o
6363+ runForwards :: x -> S i -> S o
5064 -- | Back propagate a step. Takes the current layer, the input that the
5165 -- layer gave from the input and the back propagated derivatives from
5266 -- the layer above.
5367 -- Returns the gradient layer and the derivatives to push back further.
5454- runBackwards :: x -> S' i -> S' o -> (Gradient x, S' i)
6868+ runBackwards :: x -> S i -> S o -> (Gradient x, S i)
55695670-- | Type of a network.
5757--- The [*] type specifies the types of the layers. This is needed for parallel
5858--- running and being all the gradients beck together.
7171+--
7272+-- The [*] type specifies the types of the layers.
7373+--
5974-- The [Shape] type specifies the shapes of data passed between the layers.
6060--- Could be considered to be a heterogeneous list of layers which are able to
7575+--
7676+-- Can be considered to be a heterogeneous list of layers which are able to
6177-- transform the data shapes of the network.
6278data Network :: [*] -> [Shape] -> * where
6363- O :: Layer x i o => !x -> Network '[x] '[i, o]
6464- (:~>) :: Layer x i h => !x -> !(Network xs (h ': hs)) -> Network (x ': xs) (i ': h ': hs)
7979+ O :: (SingI i, SingI o, Layer x i o) => !x -> Network '[x] '[i, o]
8080+ (:~>) :: (SingI i, SingI h, Layer x i h) => !x -> !(Network xs (h ': hs)) -> Network (x ': xs) (i ': h ': hs)
6581infixr 5 :~>
66826783instance Show (Network l h) where
···7490 OG :: UpdateLayer x => Gradient x -> Gradients '[x]
7591 (:/>) :: UpdateLayer x => Gradient x -> Gradients xs -> Gradients (x ': xs)
76927777-7893-- | A network can easily be created by hand with (:~>), but an easy way to initialise a random
7994-- network is with the randomNetwork.
8095class CreatableNetwork (xs :: [*]) (ss :: [Shape]) where
8196 -- | Create a network of the types requested
8297 randomNetwork :: MonadRandom m => m (Network xs ss)
83988484-instance Layer x i o => CreatableNetwork (x ': '[]) (i ': o ': '[]) where
9999+instance (SingI i, SingI o, Layer x i o) => CreatableNetwork (x ': '[]) (i ': o ': '[]) where
85100 randomNetwork = O <$> createRandom
861018787-instance (Layer x i o, CreatableNetwork xs (o ': r ': rs)) => CreatableNetwork (x ': xs) (i ': o ': r ': rs) where
102102+instance (SingI i, SingI o, Layer x i o, CreatableNetwork xs (o ': r ': rs)) => CreatableNetwork (x ': xs) (i ': o ': r ': rs) where
88103 randomNetwork = (:~>) <$> createRandom <*> randomNetwork
+30-20
src/Grenade/Core/Runner.hs
···44{-# LANGUAGE ScopedTypeVariables #-}
55{-# LANGUAGE TypeOperators #-}
66{-# LANGUAGE TypeFamilies #-}
77+{-|
88+Module : Grenade.Core.Shape
99+Description : Core definition of the Shapes of data we understand
1010+Copyright : (c) Huw Campbell, 2016-2017
1111+License : BSD2
1212+Stability : experimental
7131414+This module defines simple back propagation and training functions
1515+for a network.
1616+-}
817module Grenade.Core.Runner (
918 train
1019 , backPropagate
···1625import Grenade.Core.Network
1726import Grenade.Core.Shape
18271919--- | Drive and network and collect its back propogated gradients.
2020-backPropagate :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
2121- => Network layers shapes -> S' input -> S' output -> Gradients layers
2828+-- | Perform reverse automatic differentiation on the network
2929+-- for the current input and expected output.
3030+--
3131+-- /Note:/ The loss function pushed backwards is appropriate
3232+-- for both regression and classification as a squared loss
3333+-- or log-loss respectively. Other loss functions are not yet
3434+-- implemented.
3535+backPropagate :: forall shapes layers.
3636+ Network layers shapes -> S (Head shapes) -> S (Last shapes) -> Gradients layers
2237backPropagate network input target =
2338 fst $ go input network
2439 where
2525- go :: forall j js sublayers. (Head js ~ j, Last js ~ output)
2626- => S' j -- ^ input vector
4040+ go :: forall js sublayers. (Last js ~ Last shapes)
4141+ => S (Head js) -- ^ input vector
2742 -> Network sublayers js -- ^ network to train
2828- -> (Gradients sublayers, S' j)
4343+ -> (Gradients sublayers, S (Head js))
2944 -- handle input from the beginning, feeding upwards.
3045 go !x (layer :~> n)
3146 = let y = runForwards layer x
···44594560 in (OG layer', dWs)
46614747--- | Update a network with new weights after training with an instance.
4848-train :: forall input output shapes layers. (Head shapes ~ input, Last shapes ~ output)
4949- => LearningParameters -- ^ learning rate
5050- -> Network layers shapes -- ^ network to train
5151- -> S' input -> S' output -- ^ target vector
5252- -> Network layers shapes
5353-train rate network input output =
5454- let grads = backPropagate network input output
5555- in applyUpdate rate network grads
5656-6262+-- | Apply one step of stochastic gradient decent across the network.
5763applyUpdate :: LearningParameters -> Network ls ss -> Gradients ls -> Network ls ss
5864applyUpdate rate (O layer) (OG gradient)
5965 = O (runUpdate rate layer gradient)
···6268applyUpdate _ _ _
6369 = error "Impossible for the gradients of a network to have a different length to the network"
64706565--- | Just forwards propagation with no training.
6666-runNet :: Network layers hs
6767- -> S' (Head hs) -- ^ input vector
6868- -> S' (Last hs) -- ^ target vector
7171+-- | Update a network with new weights after training with an instance.
7272+train :: LearningParameters -> Network layers shapes -> S (Head shapes) -> S (Last shapes) -> Network layers shapes
7373+train rate network input output =
7474+ let grads = backPropagate network input output
7575+ in applyUpdate rate network grads
7676+7777+-- | Run the network with input and return the given output.
7878+runNet :: Network layers shapes -> S (Head shapes) -> S (Last shapes)
6979runNet (layer :~> n) !x = let y = runForwards layer x in runNet n y
7080runNet (O layer) !x = runForwards layer x
+140-44
src/Grenade/Core/Shape.hs
···11{-# LANGUAGE DataKinds #-}
22{-# LANGUAGE GADTs #-}
33{-# LANGUAGE KindSignatures #-}
44-{-# LANGUAGE ScopedTypeVariables #-}
55-{-# LANGUAGE TypeOperators #-}
64{-# LANGUAGE TypeFamilies #-}
77-{-# LANGUAGE PolyKinds #-}
88-{-# LANGUAGE MultiParamTypeClasses #-}
55+{-# LANGUAGE TypeOperators #-}
66+{-# LANGUAGE StandaloneDeriving #-}
97{-# LANGUAGE FlexibleContexts #-}
1010-{-# LANGUAGE FlexibleInstances #-}
88+{-# LANGUAGE ScopedTypeVariables #-}
99+{-# LANGUAGE RankNTypes #-}
11101212--- Ghc 8.0 gives a warning on `(+) _ _ = error ...` but ghc 7.10 fails to
1111+-- Ghc 8.0 gives a warning on `n2 _ _ = error ...` but ghc 7.10 fails to
1312-- compile without this default pattern.
1413{-# OPTIONS_GHC -fno-warn-overlapping-patterns #-}
15141515+{-|
1616+Module : Grenade.Core.Shape
1717+Description : Core definition of the Shapes of data we understand
1818+Copyright : (c) Huw Campbell, 2016-2017
1919+License : BSD2
2020+Stability : experimental
2121+2222+This module defines the core data types for the shapes of data that
2323+are understood by Grenade.
2424+-}
1625module Grenade.Core.Shape (
1726 Shape (..)
1818- , S' (..)
2727+ , S (..)
2828+ , randomOfShape
2929+ , fromStorable
1930 ) where
20313232+import Control.DeepSeq (NFData (..))
3333+import Control.Monad.Random ( MonadRandom, getRandom )
3434+3535+import Data.Singletons
2136import Data.Singletons.TypeLits
3737+import Data.Vector.Storable ( Vector )
3838+import qualified Data.Vector.Storable as V
3939+2240import GHC.TypeLits
23414242+import qualified Numeric.LinearAlgebra.Static as H
2443import Numeric.LinearAlgebra.Static
2525-4444+import qualified Numeric.LinearAlgebra as NLA
26452746-- | The current shapes we accept.
2847-- at the moment this is just one, two, and three dimensional
2948-- Vectors/Matricies.
3030-data Shape =
3131- D1 Nat
4949+data Shape
5050+ = D1 Nat
3251 | D2 Nat Nat
3352 | D3 Nat Nat Nat
34533535-instance Num (S' x) where
3636- (+) (S1D' x) (S1D' y) = S1D' (x + y)
3737- (+) (S2D' x) (S2D' y) = S2D' (x + y)
3838- (+) (S3D' x) (S3D' y) = S3D' (x + y)
3939- (+) _ _ = error "Impossible to have different constructors for the same shaped network"
5454+-- | Given a Shape n, these are the possible data structures with that shape.
5555+-- All shapes are held in contiguous memory.
5656+-- 3D is held in a matrix (usually row oriented) which has height depth * rows.
5757+data S (n :: Shape) where
5858+ S1D :: ( KnownNat o ) => R o -> S ('D1 o)
5959+ S2D :: ( KnownNat rows, KnownNat columns ) => L rows columns -> S ('D2 rows columns)
6060+ S3D :: ( KnownNat rows
6161+ , KnownNat columns
6262+ , KnownNat depth
6363+ , KnownNat (rows * depth)) => L (rows * depth) columns -> S ('D3 rows columns depth)
40644141- (-) (S1D' x) (S1D' y) = S1D' (x - y)
4242- (-) (S2D' x) (S2D' y) = S2D' (x - y)
4343- (-) (S3D' x) (S3D' y) = S3D' (x - y)
4444- (-) _ _ = error "Impossible to have different constructors for the same shaped network"
6565+deriving instance Show (S n)
45664646- (*) (S1D' x) (S1D' y) = S1D' (x * y)
4747- (*) (S2D' x) (S2D' y) = S2D' (x * y)
4848- (*) (S3D' x) (S3D' y) = S3D' (x * y)
4949- (*) _ _ = error "Impossible to have different constructors for the same shaped network"
6767+instance SingI x => Num (S x) where
6868+ (+) = n2 (+)
6969+ (-) = n2 (-)
7070+ (*) = n2 (*)
7171+ abs = n1 abs
7272+ signum = n1 signum
7373+ fromInteger x = case (sing :: Sing x) of
7474+ D1Sing -> S1D (konst $ fromInteger x)
7575+ D2Sing -> S2D (konst $ fromInteger x)
7676+ D3Sing -> S3D (konst $ fromInteger x)
50775151- abs (S1D' x) = S1D' (abs x)
5252- abs (S2D' x) = S2D' (abs x)
5353- abs (S3D' x) = S3D' (abs x)
7878+instance SingI x => Fractional (S x) where
7979+ (/) = n2 (/)
8080+ recip = n1 recip
8181+ fromRational x = case (sing :: Sing x) of
8282+ D1Sing -> S1D (konst $ fromRational x)
8383+ D2Sing -> S2D (konst $ fromRational x)
8484+ D3Sing -> S3D (konst $ fromRational x)
54855555- signum (S1D' x) = S1D' (signum x)
5656- signum (S2D' x) = S2D' (signum x)
5757- signum (S3D' x) = S3D' (signum x)
8686+instance SingI x => Floating (S x) where
8787+ pi = case (sing :: Sing x) of
8888+ D1Sing -> S1D (konst pi)
8989+ D2Sing -> S2D (konst pi)
9090+ D3Sing -> S3D (konst pi)
9191+ exp = n1 exp
9292+ log = n1 log
9393+ sqrt = n1 sqrt
9494+ (**) = n2 (**)
9595+ logBase = n2 logBase
9696+ sin = n1 sin
9797+ cos = n1 cos
9898+ tan = n1 tan
9999+ asin = n1 asin
100100+ acos = n1 acos
101101+ atan = n1 atan
102102+ sinh = n1 sinh
103103+ cosh = n1 cosh
104104+ tanh = n1 tanh
105105+ asinh = n1 asinh
106106+ acosh = n1 acosh
107107+ atanh = n1 atanh
581085959- fromInteger _ = error "Unimplemented: fromInteger on Shape"
109109+-- Singletons
110110+-- These could probably be derived with template haskell, but this seems
111111+-- clear and makes adding the KnownNat constraints simple.
112112+data instance Sing (n :: Shape) where
113113+ D1Sing :: KnownNat a => Sing ('D1 a)
114114+ D2Sing :: (KnownNat a, KnownNat b) => Sing ('D2 a b)
115115+ D3Sing :: (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => Sing ('D3 a b c)
601166161--- | Given a Shape n, these are the possible data structures with that shape.
6262--- All shapes are held in contiguous memory.
6363--- 3D is held in a matrix (usually row oriented) which has height depth * rows.
6464-data S' (n :: Shape) where
6565- S1D' :: ( KnownNat o ) => R o -> S' ('D1 o)
6666- S2D' :: ( KnownNat rows, KnownNat columns ) => L rows columns -> S' ('D2 rows columns)
6767- S3D' :: ( KnownNat rows
6868- , KnownNat columns
6969- , KnownNat depth
7070- , KnownNat (rows * depth)) => L (rows * depth) columns -> S' ('D3 rows columns depth)
117117+instance KnownNat a => SingI ('D1 a) where
118118+ sing = D1Sing
119119+instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where
120120+ sing = D2Sing
121121+instance (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => SingI ('D3 a b c) where
122122+ sing = D3Sing
711237272-instance Show (S' n) where
7373- show (S1D' a) = "S1D' " ++ show a
7474- show (S2D' a) = "S2D' " ++ show a
7575- show (S3D' a) = "S3D' " ++ show a
124124+--
125125+-- I haven't made shapes strict, as sometimes they're not needed
126126+-- (the last input gradient back for instance)
127127+--
128128+instance NFData (S x) where
129129+ rnf (S1D x) = rnf x
130130+ rnf (S2D x) = rnf x
131131+ rnf (S3D x) = rnf x
132132+133133+-- | Generate random data of the desired shape
134134+randomOfShape :: forall x m. ( MonadRandom m, SingI x ) => m (S x)
135135+randomOfShape = do
136136+ seed :: Int <- getRandom
137137+ return $ case (sing :: Sing x) of
138138+ D1Sing -> S1D (randomVector seed Uniform * 2 - 1)
139139+ D2Sing -> S2D (uniformSample seed (-1) 1)
140140+ D3Sing -> S3D (uniformSample seed (-1) 1)
141141+142142+-- | Generate a shape from a Storable Vector.
143143+--
144144+-- Returns Nothing if the vector is of the wrong size.
145145+fromStorable :: forall x. SingI x => Vector Double -> Maybe (S x)
146146+fromStorable xs = case sing :: Sing x of
147147+ D1Sing -> S1D <$> H.create xs
148148+ D2Sing -> S2D <$> mkL xs
149149+ D3Sing -> S3D <$> mkL xs
150150+ where
151151+ mkL :: forall rows columns. (KnownNat rows, KnownNat columns)
152152+ => Vector Double -> Maybe (L rows columns)
153153+ mkL v =
154154+ let rows = fromIntegral $ natVal (Proxy :: Proxy rows)
155155+ columns = fromIntegral $ natVal (Proxy :: Proxy columns)
156156+ in if rows * columns == V.length v
157157+ then H.create $ NLA.reshape columns v
158158+ else Nothing
159159+160160+-- Helper function for creating the number instances
161161+n1 :: ( forall a. Floating a => a -> a ) -> S x -> S x
162162+n1 f (S1D x) = S1D (f x)
163163+n1 f (S2D x) = S2D (f x)
164164+n1 f (S3D x) = S3D (f x)
165165+166166+-- Helper function for creating the number instances
167167+n2 :: ( forall a. Floating a => a -> a -> a ) -> S x -> S x -> S x
168168+n2 f (S1D x) (S1D y) = S1D (f x y)
169169+n2 f (S2D x) (S2D y) = S2D (f x y)
170170+n2 f (S3D x) (S3D y) = S3D (f x y)
171171+n2 _ _ _ = error "Impossible to have different constructors for the same shaped network"
+37
src/Grenade/Graph/GraphNetwork.hs
···11+{-# LANGUAGE DataKinds #-}
22+{-# LANGUAGE GADTs #-}
33+{-# LANGUAGE KindSignatures #-}
44+{-# LANGUAGE ScopedTypeVariables #-}
55+{-# LANGUAGE TypeOperators #-}
66+{-# LANGUAGE TypeFamilies #-}
77+{-# LANGUAGE PolyKinds #-}
88+{-# LANGUAGE MultiParamTypeClasses #-}
99+{-# LANGUAGE FlexibleContexts #-}
1010+{-# LANGUAGE FlexibleInstances #-}
1111+{-# LANGUAGE LambdaCase #-}
1212+1313+module Grenade.Graph.Network (
1414+ Layer (..)
1515+ , UpdateLayer (..)
1616+ ) where
1717+1818+import Control.Monad.Random (MonadRandom)
1919+import Data.Singletons
2020+import Data.Singletons.Prelude
2121+2222+import GHC.TypeLits
2323+2424+import Grenade.Core.Shape
2525+import Grenade.Core.Network ( UpdateLayer (..), Layer (..) )
2626+2727+-- | Type of a DAG network
2828+2929+data Fin :: Nat -> * where
3030+ Fin0 :: Fin (n + 1)
3131+ FinS :: Fin n -> Fin (n + 1)
3232+3333+data Edge :: Nat -> * where
3434+ Edge :: Shape -> Fin n -> Edge n
3535+3636+data Node a n where
3737+ Node :: a -> [Edge n] -> Node a n
+26-32
src/Grenade/Layers/Convolution.hs
···11-{-# LANGUAGE BangPatterns #-}
21{-# LANGUAGE DataKinds #-}
32{-# LANGUAGE ScopedTypeVariables #-}
44-{-# LANGUAGE StandaloneDeriving #-}
53{-# LANGUAGE RecordWildCards #-}
64{-# LANGUAGE GADTs #-}
75{-# LANGUAGE TypeOperators #-}
···97{-# LANGUAGE MultiParamTypeClasses #-}
108{-# LANGUAGE FlexibleInstances #-}
119{-# LANGUAGE FlexibleContexts #-}
1212-{-# LANGUAGE PolyKinds #-}
1313-{-# LANGUAGE PatternGuards #-}
1414-1510module Grenade.Layers.Convolution (
1611 Convolution (..)
1712 , Convolution' (..)
···3126import Grenade.Core.Network
3227import Grenade.Core.Shape
3328import Grenade.Layers.Internal.Convolution
2929+import Grenade.Layers.Internal.Update
34303531-- | A convolution layer for a neural network.
3632-- This uses the im2col convolution trick popularised by Caffe, which essentially turns the
···4339-- `out = (in - kernel) / stride + 1` for both dimensions.
4440--
4541-- One probably shouldn't build their own layer, but rather use the randomConvolution function.
4646-data Convolution :: Nat -- ^ Number of channels, for the first layer this could be RGB for instance.
4747- -> Nat -- ^ Number of filters, this is the number of channels output by the layer.
4848- -> Nat -- ^ The number of rows in the kernel filter
4949- -> Nat -- ^ The number of column in the kernel filter
5050- -> Nat -- ^ The row stride of the convolution filter
5151- -> Nat -- ^ The columns stride of the convolution filter
4242+data Convolution :: Nat -- Number of channels, for the first layer this could be RGB for instance.
4343+ -> Nat -- Number of filters, this is the number of channels output by the layer.
4444+ -> Nat -- The number of rows in the kernel filter
4545+ -> Nat -- The number of column in the kernel filter
4646+ -> Nat -- The row stride of the convolution filter
4747+ -> Nat -- The columns stride of the convolution filter
5248 -> * where
5349 Convolution :: ( KnownNat channels
5450 , KnownNat filters
···5854 , KnownNat strideColumns
5955 , KnownNat kernelFlattened
6056 , kernelFlattened ~ (kernelRows * kernelColumns * channels))
6161- => !(L kernelFlattened filters) -- ^ The kernel filter weights
6262- -> !(L kernelFlattened filters) -- ^ The last kernel update (or momentum)
5757+ => !(L kernelFlattened filters) -- The kernel filter weights
5858+ -> !(L kernelFlattened filters) -- The last kernel update (or momentum)
6359 -> Convolution channels filters kernelRows kernelColumns strideRows strideColumns
64606565-data Convolution' :: Nat -- ^ Number of channels, for the first layer this could be RGB for instance.
6666- -> Nat -- ^ Number of filters, this is the number of channels output by the layer.
6767- -> Nat -- ^ The number of rows in the kernel filter
6868- -> Nat -- ^ The number of column in the kernel filter
6969- -> Nat -- ^ The row stride of the convolution filter
7070- -> Nat -- ^ The columns stride of the convolution filter
6161+data Convolution' :: Nat -- Number of channels, for the first layer this could be RGB for instance.
6262+ -> Nat -- Number of filters, this is the number of channels output by the layer.
6363+ -> Nat -- The number of rows in the kernel filter
6464+ -> Nat -- The number of column in the kernel filter
6565+ -> Nat -- The row stride of the convolution filter
6666+ -> Nat -- The columns stride of the convolution filter
7167 -> * where
7268 Convolution' :: ( KnownNat channels
7369 , KnownNat filters
···7773 , KnownNat strideColumns
7874 , KnownNat kernelFlattened
7975 , kernelFlattened ~ (kernelRows * kernelColumns * channels))
8080- => !(L kernelFlattened filters) -- ^ The kernel filter gradient
7676+ => !(L kernelFlattened filters) -- The kernel filter gradient
8177 -> Convolution' channels filters kernelRows kernelColumns strideRows strideColumns
82788379instance Show (Convolution c f k k' s s') where
···109105 , kernelFlattened ~ (kernelRows * kernelColumns * channels))
110106 => m (Convolution channels filters kernelRows kernelColumns strideRows strideColumns)
111107randomConvolution = do
112112- s :: Int <- getRandom
108108+ s <- getRandom
113109 let wN = uniformSample s (-1) 1
114110 mm = konst 0
115111 return $ Convolution wN mm
···124120 ) => UpdateLayer (Convolution channels filters kernelRows kernelColumns strideRows strideColumns) where
125121 type Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols) = (Convolution' channels filters kernelRows kernelCols strideRows strideCols)
126122 runUpdate LearningParameters {..} (Convolution oldKernel oldMomentum) (Convolution' kernelGradient) =
127127- let newMomentum = konst learningMomentum * oldMomentum - konst learningRate * kernelGradient
128128- regulariser = konst (learningRegulariser * learningRate) * oldKernel
129129- newKernel = oldKernel + newMomentum - regulariser
123123+ let (newKernel, newMomentum) = decendMatrix learningRate learningMomentum learningRegulariser oldKernel kernelGradient oldMomentum
130124 in Convolution newKernel newMomentum
131125132126 createRandom = randomConvolution
···146140 , KnownNat (kernelRows * kernelCols * 1)
147141 , KnownNat (outputRows * filters)
148142 ) => Layer (Convolution 1 filters kernelRows kernelCols strideRows strideCols) ('D2 inputRows inputCols) ('D3 outputRows outputCols filters) where
149149- runForwards (Convolution kernel _) (S2D' input) =
143143+ runForwards (Convolution kernel _) (S2D input) =
150144 let ex = extract input
151145 ek = extract kernel
152146 kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
···159153 mt = c LA.<> ek
160154 r = col2vid 1 1 1 1 ox oy mt
161155 rs = fromJust . create $ r
162162- in S3D' rs
156156+ in S3D rs
163157164164- runBackwards (Convolution kernel _) (S2D' input) (S3D' dEdy) =
158158+ runBackwards (Convolution kernel _) (S2D input) (S3D dEdy) =
165159 let ex = extract input
166160 ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
167161 iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
···183177 dW = vs LA.<> tr ek
184178185179 xW = col2im kx ky sx sy ix iy dW
186186- in (Convolution' kN, S2D' . fromJust . create $ xW)
180180+ in (Convolution' kN, S2D . fromJust . create $ xW)
187181188182189183-- | A three dimensional image (or 2d with many channels) can have
···203197 , KnownNat (kernelRows * kernelCols * channels)
204198 , KnownNat (outputRows * filters)
205199 ) => Layer (Convolution channels filters kernelRows kernelCols strideRows strideCols) ('D3 inputRows inputCols channels) ('D3 outputRows outputCols filters) where
206206- runForwards (Convolution kernel _) (S3D' input) =
200200+ runForwards (Convolution kernel _) (S3D input) =
207201 let ex = extract input
208202 ek = extract kernel
209203 ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
···219213 mt = c LA.<> ek
220214 r = col2vid 1 1 1 1 ox oy mt
221215 rs = fromJust . create $ r
222222- in S3D' rs
223223- runBackwards (Convolution kernel _) (S3D' input) (S3D' dEdy) =
216216+ in S3D rs
217217+ runBackwards (Convolution kernel _) (S3D input) (S3D dEdy) =
224218 let ex = extract input
225219 ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
226220 iy = fromIntegral $ natVal (Proxy :: Proxy inputCols)
···243237 dW = vs LA.<> tr ek
244238245239 xW = col2vid kx ky sx sy ix iy dW
246246- in (Convolution' kN, S3D' . fromJust . create $ xW)
240240+ in (Convolution' kN, S3D . fromJust . create $ xW)
···11-{-# LANGUAGE BangPatterns #-}
21{-# LANGUAGE DataKinds #-}
33-{-# LANGUAGE ScopedTypeVariables #-}
44-{-# LANGUAGE StandaloneDeriving #-}
52{-# LANGUAGE TypeOperators #-}
63{-# LANGUAGE TypeFamilies #-}
74{-# LANGUAGE MultiParamTypeClasses #-}
85{-# LANGUAGE FlexibleContexts #-}
99-{-# LANGUAGE FlexibleInstances #-}
1010-116module Grenade.Layers.Flatten (
127 FlattenLayer (..)
138 ) where
···1611import GHC.TypeLits
17121813import Numeric.LinearAlgebra.Static
1919-import Numeric.LinearAlgebra.Data as LA (flatten, toList)
1414+import Numeric.LinearAlgebra.Data as LA ( flatten )
20152116import Grenade.Core.Shape
2217import Grenade.Core.Network
23181919+-- | Flatten Layer
2020+--
2121+-- Flattens input down to D1 from either 2D or 3D data.
2222+--
2323+-- Can also be used to turn a 3D image with only one channel into a 2D image.
2424data FlattenLayer = FlattenLayer
2525 deriving Show
2626···2929 runUpdate _ _ _ = FlattenLayer
3030 createRandom = return FlattenLayer
31313232-3332instance (KnownNat a, KnownNat x, KnownNat y, a ~ (x * z)) => Layer FlattenLayer ('D2 x y) ('D1 a) where
3434- runForwards _ (S2D' y) = S1D' . fromList . toList . flatten . extract $ y
3535- runBackwards _ _ (S1D' y) = ((), S2D' . fromList . toList . unwrap $ y)
3333+ runForwards _ (S2D y) = fromJust' . fromStorable . flatten . extract $ y
3434+ runBackwards _ _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)
36353736instance (KnownNat a, KnownNat x, KnownNat y, KnownNat (x * z), KnownNat z, a ~ (x * y * z)) => Layer FlattenLayer ('D3 x y z) ('D1 a) where
3838- runForwards _ (S3D' y) = S1D' . fromList . toList . flatten . extract $ y
3939- runBackwards _ _ (S1D' y) = ((), S3D' . fromList . toList . unwrap $ y)
3737+ runForwards _ (S3D y) = fromJust' . fromStorable . flatten . extract $ y
3838+ runBackwards _ _ (S1D y) = ((), fromJust' . fromStorable . extract $ y)
3939+4040+instance (KnownNat y, KnownNat x, KnownNat z, z ~ 1) => Layer FlattenLayer ('D3 x y z) ('D2 x y) where
4141+ runForwards _ (S3D y) = S2D y
4242+ runBackwards _ _ (S2D y) = ((), S3D y)
4343+4444+fromJust' :: Maybe x -> x
4545+fromJust' (Just x) = x
4646+fromJust' Nothing = error $ "FlattenLayer error: data shape couldn't be converted."
+9-13
src/Grenade/Layers/FullyConnected.hs
···11{-# LANGUAGE DataKinds #-}
22-{-# LANGUAGE ScopedTypeVariables #-}
32{-# LANGUAGE RecordWildCards #-}
43{-# LANGUAGE TypeOperators #-}
54{-# LANGUAGE TypeFamilies #-}
65{-# LANGUAGE MultiParamTypeClasses #-}
77-{-# LANGUAGE FlexibleInstances #-}
88-96module Grenade.Layers.FullyConnected (
107 FullyConnected (..)
118 , randomFullyConnected
···2017import Grenade.Core.Network
2118import Grenade.Core.Shape
22192020+import Grenade.Layers.Internal.Update
2121+2322-- | A basic fully connected (or inner product) neural network layer.
2423data FullyConnected i o = FullyConnected
2524 !(R o) -- Bias neuron weights
···3837 type Gradient (FullyConnected i o) = (FullyConnected' i o)
39384039 runUpdate LearningParameters {..} (FullyConnected oldBias oldBiasMomentum oldActivations oldMomentum) (FullyConnected' biasGradient activationGradient) =
4141- let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
4242- newBias = oldBias + newBiasMomentum
4343- newMomentum = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
4444- regulariser = konst (learningRegulariser * learningRate) * oldActivations
4545- newActivations = oldActivations + newMomentum - regulariser
4040+ let (newBias, newBiasMomentum) = decendVector learningRate learningMomentum learningRegulariser oldBias biasGradient oldBiasMomentum
4141+ (newActivations, newMomentum) = decendMatrix learningRate learningMomentum learningRegulariser oldActivations activationGradient oldMomentum
4642 in FullyConnected newBias newBiasMomentum newActivations newMomentum
47434844 createRandom = randomFullyConnected
49455046instance (KnownNat i, KnownNat o) => Layer (FullyConnected i o) ('D1 i) ('D1 o) where
5147 -- Do a matrix vector multiplication and return the result.
5252- runForwards (FullyConnected wB _ wN _) (S1D' v) = S1D' (wB + wN #> v)
4848+ runForwards (FullyConnected wB _ wN _) (S1D v) = S1D (wB + wN #> v)
53495450 -- Run a backpropogation step for a full connected layer.
5555- runBackwards (FullyConnected _ _ wN _) (S1D' x) (S1D' dEdy) =
5151+ runBackwards (FullyConnected _ _ wN _) (S1D x) (S1D dEdy) =
5652 let wB' = dEdy
5753 mm' = dEdy `outer` x
5854 -- calcluate derivatives for next step
5955 dWs = tr wN #> dEdy
6060- in (FullyConnected' wB' mm', S1D' dWs)
5656+ in (FullyConnected' wB' mm', S1D dWs)
61576258randomFullyConnected :: (MonadRandom m, KnownNat i, KnownNat o)
6359 => m (FullyConnected i o)
6460randomFullyConnected = do
6565- s1 :: Int <- getRandom
6666- s2 :: Int <- getRandom
6161+ s1 <- getRandom
6262+ s2 <- getRandom
6763 let wB = randomVector s1 Uniform * 2 - 1
6864 wN = uniformSample s2 (-1) 1
6965 bm = konst 0
+2-7
src/Grenade/Layers/Fuse.hs
···11-{-# LANGUAGE BangPatterns #-}
21{-# LANGUAGE DataKinds #-}
32{-# LANGUAGE GADTs #-}
44-{-# LANGUAGE KindSignatures #-}
53{-# LANGUAGE ScopedTypeVariables #-}
64{-# LANGUAGE TypeOperators #-}
75{-# LANGUAGE TypeFamilies #-}
88-{-# LANGUAGE PolyKinds #-}
96{-# LANGUAGE MultiParamTypeClasses #-}
107{-# LANGUAGE FlexibleContexts #-}
118{-# LANGUAGE FlexibleInstances #-}
1212-1313-149module Grenade.Layers.Fuse (
1510 Fuse (..)
1611 ) where
···42374338instance (Layer x i h, Layer y h o) => Layer (Fuse x y i h o) i o where
4439 runForwards (x :$$ y) input =
4545- let yInput :: S' h = runForwards x input
4040+ let yInput :: S h = runForwards x input
4641 in runForwards y yInput
47424843 runBackwards (x :$$ y) input backGradient =
4949- let yInput :: S' h = runForwards x input
4444+ let yInput :: S h = runForwards x input
5045 (y', yGrad) = runBackwards y yInput backGradient
5146 (x', xGrad) = runBackwards x input yGrad
5247 in ((x', y'), xGrad)
+13-11
src/Grenade/Layers/Internal/Convolution.hs
···66 , vid2col
77 ) where
8899-import Foreign ( mallocForeignPtrArray0, withForeignPtr )
99+import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )
1010+1111+import Foreign ( mallocForeignPtrArray, withForeignPtr )
1012import Foreign.Ptr ( Ptr )
11131214import Numeric.LinearAlgebra ( Matrix, flatten, rows, cols )
···2830col2im_c channels height width kernelRows kernelColumns strideRows strideColumns dataCol =
2931 let vec = flatten dataCol
3032 in unsafePerformIO $ do
3131- outPtr <- mallocForeignPtrArray0 (height * width * channels)
3232- let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
3333+ outPtr <- mallocForeignPtrArray (height * width * channels)
3434+ let (inPtr, _) = U.unsafeToForeignPtr0 vec
33353436 withForeignPtr inPtr $ \inPtr' ->
3537 withForeignPtr outPtr $ \outPtr' ->
3636- col2im_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
3838+ col2im_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
37393838- let matVec = U.unsafeFromForeignPtr outPtr 0 (height * width * channels)
4040+ let matVec = U.unsafeFromForeignPtr0 outPtr (height * width * channels)
3941 return $ U.matrixFromVector U.RowMajor (height * channels) width matVec
40424143foreign import ccall unsafe
4244 col2im_cpu
4343- :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
4545+ :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
44464547vid2col :: Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double
4648vid2col kernelRows kernelColumns strideRows strideColumns height width dataVid =
···6365 kernelSize = kernelRows * kernelColumns
6466 numberOfPatches = rowOut * colOut
6567 in unsafePerformIO $ do
6666- outPtr <- mallocForeignPtrArray0 (numberOfPatches * kernelSize * channels)
6767- let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
6868+ outPtr <- mallocForeignPtrArray (numberOfPatches * kernelSize * channels)
6969+ let (inPtr, _) = U.unsafeToForeignPtr0 vec
68706971 withForeignPtr inPtr $ \inPtr' ->
7072 withForeignPtr outPtr $ \outPtr' ->
7171- im2col_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
7373+ im2col_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
72747373- let matVec = U.unsafeFromForeignPtr outPtr 0 (numberOfPatches * kernelSize * channels)
7575+ let matVec = U.unsafeFromForeignPtr0 outPtr (numberOfPatches * kernelSize * channels)
7476 return $ U.matrixFromVector U.RowMajor numberOfPatches (kernelSize * channels) matVec
75777678foreign import ccall unsafe
7779 im2col_cpu
7878- :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
8080+ :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
+14-12
src/Grenade/Layers/Internal/Pooling.hs
···44 , poolBackward
55 ) where
6677-import Foreign ( mallocForeignPtrArray0, withForeignPtr )
77+import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )
88+99+import Foreign ( mallocForeignPtrArray, withForeignPtr )
810import Foreign.Ptr ( Ptr )
9111012import Numeric.LinearAlgebra ( Matrix , flatten )
···1921 colOut = (width - kernelColumns) `div` strideColumns + 1
2022 numberOfPatches = rowOut * colOut
2123 in unsafePerformIO $ do
2222- outPtr <- mallocForeignPtrArray0 (numberOfPatches * channels)
2323- let (inPtr, inOffset, _) = U.unsafeToForeignPtr vec
2424+ outPtr <- mallocForeignPtrArray (numberOfPatches * channels)
2525+ let (inPtr, _) = U.unsafeToForeignPtr0 vec
24262527 withForeignPtr inPtr $ \inPtr' ->
2628 withForeignPtr outPtr $ \outPtr' ->
2727- pool_forwards_cpu inPtr' inOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
2929+ pool_forwards_cpu inPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
28302929- let matVec = U.unsafeFromForeignPtr outPtr 0 (numberOfPatches * channels)
3131+ let matVec = U.unsafeFromForeignPtr0 outPtr (numberOfPatches * channels)
3032 return $ U.matrixFromVector U.RowMajor (rowOut * channels) colOut matVec
31333234foreign import ccall unsafe
3335 pool_forwards_cpu
3434- :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
3636+ :: Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
35373638poolBackward :: Int -> Int -> Int -> Int -> Int -> Int -> Int -> Matrix Double -> Matrix Double -> Matrix Double
3739poolBackward channels height width kernelRows kernelColumns strideRows strideColumns dataIm dataGrad =
3840 let vecIm = flatten dataIm
3941 vecGrad = flatten dataGrad
4042 in unsafePerformIO $ do
4141- outPtr <- mallocForeignPtrArray0 (height * width * channels)
4242- let (imPtr, imOffset, _) = U.unsafeToForeignPtr vecIm
4343- let (gradPtr, gradOffset, _) = U.unsafeToForeignPtr vecGrad
4343+ outPtr <- mallocForeignPtrArray (height * width * channels)
4444+ let (imPtr, _) = U.unsafeToForeignPtr0 vecIm
4545+ let (gradPtr, _) = U.unsafeToForeignPtr0 vecGrad
44464547 withForeignPtr imPtr $ \imPtr' ->
4648 withForeignPtr gradPtr $ \gradPtr' ->
4749 withForeignPtr outPtr $ \outPtr' ->
4848- pool_backwards_cpu imPtr' imOffset gradPtr' gradOffset channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
5050+ pool_backwards_cpu imPtr' gradPtr' channels height width kernelRows kernelColumns strideRows strideColumns outPtr'
49515050- let matVec = U.unsafeFromForeignPtr outPtr 0 (height * width * channels)
5252+ let matVec = U.unsafeFromForeignPtr0 outPtr (height * width * channels)
5153 return $ U.matrixFromVector U.RowMajor (height * channels) width matVec
52545355foreign import ccall unsafe
5456 pool_backwards_cpu
5555- :: Ptr Double -> Int -> Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
5757+ :: Ptr Double -> Ptr Double -> Int -> Int -> Int -> Int -> Int -> Int -> Int -> Ptr Double -> IO ()
+70
src/Grenade/Layers/Internal/Update.hs
···11+{-# LANGUAGE ForeignFunctionInterface #-}
22+module Grenade.Layers.Internal.Update (
33+ decendMatrix
44+ , decendVector
55+ ) where
66+77+import Data.Maybe ( fromJust )
88+import qualified Data.Vector.Storable as U ( unsafeToForeignPtr0, unsafeFromForeignPtr0 )
99+1010+import Foreign ( mallocForeignPtrArray, withForeignPtr )
1111+import Foreign.Ptr ( Ptr )
1212+import GHC.TypeLits
1313+1414+import Numeric.LinearAlgebra ( Vector, flatten )
1515+import Numeric.LinearAlgebra.Static
1616+import qualified Numeric.LinearAlgebra.Devel as U
1717+1818+import System.IO.Unsafe ( unsafePerformIO )
1919+2020+decendMatrix :: (KnownNat rows, KnownNat columns) => Double -> Double -> Double -> L rows columns -> L rows columns -> L rows columns -> (L rows columns, L rows columns)
2121+decendMatrix rate momentum regulariser weights gradient lastUpdate =
2222+ let (rows, cols) = size weights
2323+ len = rows * cols
2424+ -- Most gradients come in in ColumnMajor,
2525+ -- so we'll transpose here before flattening them
2626+ -- into a vector to prevent a copy.
2727+ --
2828+ -- This gives ~15% speed improvement for LSTMs.
2929+ weights' = flatten . tr . extract $ weights
3030+ gradient' = flatten . tr . extract $ gradient
3131+ lastUpdate' = flatten . tr . extract $ lastUpdate
3232+ (vw, vm) = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
3333+3434+ -- Note that it's ColumnMajor, as we did a transpose before
3535+ -- using the internal vectors.
3636+ mw = U.matrixFromVector U.ColumnMajor rows cols vw
3737+ mm = U.matrixFromVector U.ColumnMajor rows cols vm
3838+ in (fromJust . create $ mw, fromJust . create $ mm)
3939+4040+decendVector :: (KnownNat r) => Double -> Double -> Double -> R r -> R r -> R r -> (R r, R r)
4141+decendVector rate momentum regulariser weights gradient lastUpdate =
4242+ let len = size weights
4343+ weights' = extract weights
4444+ gradient' = extract gradient
4545+ lastUpdate' = extract lastUpdate
4646+ (vw, vm) = decendUnsafe len rate momentum regulariser weights' gradient' lastUpdate'
4747+ in (fromJust $ create vw, fromJust $ create vm)
4848+4949+decendUnsafe :: Int -> Double -> Double -> Double -> Vector Double -> Vector Double -> Vector Double -> (Vector Double, Vector Double)
5050+decendUnsafe len rate momentum regulariser weights gradient lastUpdate =
5151+ unsafePerformIO $ do
5252+ outWPtr <- mallocForeignPtrArray len
5353+ outMPtr <- mallocForeignPtrArray len
5454+ let (wPtr, _) = U.unsafeToForeignPtr0 weights
5555+ let (gPtr, _) = U.unsafeToForeignPtr0 gradient
5656+ let (lPtr, _) = U.unsafeToForeignPtr0 lastUpdate
5757+5858+ withForeignPtr wPtr $ \wPtr' ->
5959+ withForeignPtr gPtr $ \gPtr' ->
6060+ withForeignPtr lPtr $ \lPtr' ->
6161+ withForeignPtr outWPtr $ \outWPtr' ->
6262+ withForeignPtr outMPtr $ \outMPtr' ->
6363+ decend_cpu len rate momentum regulariser wPtr' gPtr' lPtr' outWPtr' outMPtr'
6464+6565+ return (U.unsafeFromForeignPtr0 outWPtr len, U.unsafeFromForeignPtr0 outMPtr len)
6666+6767+foreign import ccall unsafe
6868+ decend_cpu
6969+ :: Int -> Double -> Double -> Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> Ptr Double -> IO ()
7070+
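The C kernel behind decend_cpu walks the arrays once, applying momentum SGD with the regularisation folded into the step. A per-element sketch in plain Haskell (decendElement is a hypothetical helper, mirroring the update rule the layers below use):

```haskell
-- Per-element sketch of the update decend_cpu performs: classic
-- momentum SGD, with L2 regularisation scaled by the learning rate.
decendElement :: Double -> Double -> Double  -- rate, momentum, regulariser
              -> Double -> Double -> Double  -- weight, gradient, last momentum
              -> (Double, Double)            -- (new weight, new momentum)
decendElement rate momentum regulariser w g m =
  let m' = momentum * m - rate * g
      w' = w + m' - (regulariser * rate) * w
  in (w', m')
```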
+6-10
src/Grenade/Layers/Logit.hs
···11{-# LANGUAGE DataKinds #-}
22-{-# LANGUAGE ScopedTypeVariables #-}
32{-# LANGUAGE TypeOperators #-}
43{-# LANGUAGE TypeFamilies #-}
54{-# LANGUAGE MultiParamTypeClasses #-}
66-{-# LANGUAGE FlexibleInstances #-}
77-85module Grenade.Layers.Logit (
96 Logit (..)
107 ) where
···2724 createRandom = return Logit
28252926instance (KnownNat i) => Layer Logit ('D1 i) ('D1 i) where
3030- runForwards _ (S1D' y) = S1D' (logistic y)
3131- runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (logistic' y * dEdy))
2727+ runForwards _ (S1D y) = S1D (logistic y)
2828+ runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (logistic' y * dEdy))
32293330instance (KnownNat i, KnownNat j) => Layer Logit ('D2 i j) ('D2 i j) where
3434- runForwards _ (S2D' y) = S2D' (logistic y)
3535- runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (logistic' y * dEdy))
3131+ runForwards _ (S2D y) = S2D (logistic y)
3232+ runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (logistic' y * dEdy))
36333734instance (KnownNat i, KnownNat j, KnownNat k) => Layer Logit ('D3 i j k) ('D3 i j k) where
3838- runForwards _ (S3D' y) = S3D' (logistic y)
3939- runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (logistic' y * dEdy))
4040-3535+ runForwards _ (S3D y) = S3D (logistic y)
3636+ runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (logistic' y * dEdy))
41374238logistic :: Floating a => a -> a
4339logistic x = 1 / (1 + exp (-x))
+4-8
src/Grenade/Layers/Pad.hs
···44{-# LANGUAGE TypeOperators #-}
55{-# LANGUAGE TypeFamilies #-}
66{-# LANGUAGE MultiParamTypeClasses #-}
77-{-# LANGUAGE FlexibleInstances #-}
88-{-# LANGUAGE FlexibleContexts #-}
99-{-# LANGUAGE PolyKinds #-}
1010-117module Grenade.Layers.Pad (
128 Pad (..)
139 ) where
···5046 , (inputRows + padTop + padBottom) ~ outputRows
5147 , (inputColumns + padLeft + padRight) ~ outputColumns
5248 ) => Layer (Pad padLeft padTop padRight padBottom) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
5353- runForwards Pad (S2D' input) =
4949+ runForwards Pad (S2D input) =
5450 let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
5551 padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
5652 padr = fromIntegral $ natVal (Proxy :: Proxy padRight)
5753 padb = fromIntegral $ natVal (Proxy :: Proxy padBottom)
5854 m = extract input
5955 r = diagBlock [konst 0 (padt,padl), m, konst 0 (padb,padr)]
6060- in S2D' . fromJust . create $ r
6161- runBackwards Pad _ (S2D' dEdy) =
5656+ in S2D . fromJust . create $ r
5757+ runBackwards Pad _ (S2D dEdy) =
6258 let padl = fromIntegral $ natVal (Proxy :: Proxy padLeft)
6359 padt = fromIntegral $ natVal (Proxy :: Proxy padTop)
6460 nrows = fromIntegral $ natVal (Proxy :: Proxy inputRows)
6561 ncols = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
6662 m = extract dEdy
6763 vs = subMatrix (padt, padl) (nrows, ncols) m
6868- in ((), S2D' . fromJust . create $ vs)
6464+ in ((), S2D . fromJust . create $ vs)
+8-12
src/Grenade/Layers/Pooling.hs
···11-{-# LANGUAGE BangPatterns #-}
21{-# LANGUAGE DataKinds #-}
32{-# LANGUAGE ScopedTypeVariables #-}
43{-# LANGUAGE StandaloneDeriving #-}
···65{-# LANGUAGE TypeOperators #-}
76{-# LANGUAGE TypeFamilies #-}
87{-# LANGUAGE MultiParamTypeClasses #-}
99-{-# LANGUAGE FlexibleInstances #-}
108{-# LANGUAGE FlexibleContexts #-}
1111-{-# LANGUAGE PolyKinds #-}
1212-139module Grenade.Layers.Pooling (
1410 Pooling (..)
1511 ) where
···5551 , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
5652 , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
5753 ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D2 inputRows inputColumns) ('D2 outputRows outputColumns) where
5858- runForwards Pooling (S2D' input) =
5454+ runForwards Pooling (S2D input) =
5955 let height = fromIntegral $ natVal (Proxy :: Proxy inputRows)
6056 width = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
6157 kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
···6561 ex = extract input
6662 r = poolForward 1 height width kx ky sx sy ex
6763 rs = fromJust . create $ r
6868- in S2D' $ rs
6969- runBackwards Pooling (S2D' input) (S2D' dEdy) =
6464+ in S2D $ rs
6565+ runBackwards Pooling (S2D input) (S2D dEdy) =
7066 let height = fromIntegral $ natVal (Proxy :: Proxy inputRows)
7167 width = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
7268 kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
···7672 ex = extract input
7773 eo = extract dEdy
7874 vs = poolBackward 1 height width kx ky sx sy ex eo
7979- in ((), S2D' . fromJust . create $ vs)
7575+ in ((), S2D . fromJust . create $ vs)
807681778278-- | A three dimensional image can be pooled on each layer.
···9389 , ((outputRows - 1) * strideRows) ~ (inputRows - kernelRows)
9490 , ((outputColumns - 1) * strideColumns) ~ (inputColumns - kernelColumns)
9591 ) => Layer (Pooling kernelRows kernelColumns strideRows strideColumns) ('D3 inputRows inputColumns channels) ('D3 outputRows outputColumns channels) where
9696- runForwards Pooling (S3D' input) =
9292+ runForwards Pooling (S3D input) =
9793 let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
9894 iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
9995 kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
···104100 ex = extract input
105101 r = poolForward ch ix iy kx ky sx sy ex
106102 rs = fromJust . create $ r
107107- in S3D' rs
108108- runBackwards Pooling (S3D' input) (S3D' dEdy) =
103103+ in S3D rs
104104+ runBackwards Pooling (S3D input) (S3D dEdy) =
109105 let ix = fromIntegral $ natVal (Proxy :: Proxy inputRows)
110106 iy = fromIntegral $ natVal (Proxy :: Proxy inputColumns)
111107 kx = fromIntegral $ natVal (Proxy :: Proxy kernelRows)
···116112 ex = extract input
117113 eo = extract dEdy
118114 vs = poolBackward ch ix iy kx ky sx sy ex eo
119119- in ((), S3D' . fromJust . create $ vs)
115115+ in ((), S3D . fromJust . create $ vs)
+6-9
src/Grenade/Layers/Relu.hs
···11{-# LANGUAGE DataKinds #-}
22-{-# LANGUAGE ScopedTypeVariables #-}
32{-# LANGUAGE TypeOperators #-}
43{-# LANGUAGE TypeFamilies #-}
54{-# LANGUAGE MultiParamTypeClasses #-}
66-{-# LANGUAGE FlexibleInstances #-}
77-85module Grenade.Layers.Relu (
96 Relu (..)
107 ) where
···2724 createRandom = return Relu
28252926instance ( KnownNat i) => Layer Relu ('D1 i) ('D1 i) where
3030- runForwards _ (S1D' y) = S1D' (relu y)
2727+ runForwards _ (S1D y) = S1D (relu y)
3128 where
3229 relu = LAS.dvmap (\a -> if a <= 0 then 0 else a)
3333- runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (relu' y * dEdy))
3030+ runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (relu' y * dEdy))
3431 where
3532 relu' = LAS.dvmap (\a -> if a <= 0 then 0 else 1)
36333734instance (KnownNat i, KnownNat j) => Layer Relu ('D2 i j) ('D2 i j) where
3838- runForwards _ (S2D' y) = S2D' (relu y)
3535+ runForwards _ (S2D y) = S2D (relu y)
3936 where
4037 relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
4141- runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (relu' y * dEdy))
3838+ runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (relu' y * dEdy))
4239 where
4340 relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
44414542instance (KnownNat i, KnownNat j, KnownNat k) => Layer Relu ('D3 i j k) ('D3 i j k) where
4646- runForwards _ (S3D' y) = S3D' (relu y)
4343+ runForwards _ (S3D y) = S3D (relu y)
4744 where
4845 relu = LAS.dmmap (\a -> if a <= 0 then 0 else a)
4949- runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (relu' y * dEdy))
4646+ runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (relu' y * dEdy))
5047 where
5148 relu' = LAS.dmmap (\a -> if a <= 0 then 0 else 1)
+6-9
src/Grenade/Layers/Tanh.hs
···11{-# LANGUAGE DataKinds #-}
22-{-# LANGUAGE ScopedTypeVariables #-}
32{-# LANGUAGE TypeOperators #-}
43{-# LANGUAGE TypeFamilies #-}
54{-# LANGUAGE MultiParamTypeClasses #-}
66-{-# LANGUAGE FlexibleInstances #-}
77-85module Grenade.Layers.Tanh (
96 Tanh (..)
107 ) where
···2421 createRandom = return Tanh
25222623instance KnownNat i => Layer Tanh ('D1 i) ('D1 i) where
2727- runForwards _ (S1D' y) = S1D' (tanh y)
2828- runBackwards _ (S1D' y) (S1D' dEdy) = ((), S1D' (tanh' y * dEdy))
2424+ runForwards _ (S1D y) = S1D (tanh y)
2525+ runBackwards _ (S1D y) (S1D dEdy) = ((), S1D (tanh' y * dEdy))
29263027instance (KnownNat i, KnownNat j) => Layer Tanh ('D2 i j) ('D2 i j) where
3131- runForwards _ (S2D' y) = S2D' (tanh y)
3232- runBackwards _ (S2D' y) (S2D' dEdy) = ((), S2D' (tanh' y * dEdy))
2828+ runForwards _ (S2D y) = S2D (tanh y)
2929+ runBackwards _ (S2D y) (S2D dEdy) = ((), S2D (tanh' y * dEdy))
33303431instance (KnownNat i, KnownNat j, KnownNat k) => Layer Tanh ('D3 i j k) ('D3 i j k) where
3535- runForwards _ (S3D' y) = S3D' (tanh y)
3636- runBackwards _ (S3D' y) (S3D' dEdy) = ((), S3D' (tanh' y * dEdy))
3232+ runForwards _ (S3D y) = S3D (tanh y)
3333+ runBackwards _ (S3D y) (S3D dEdy) = ((), S3D (tanh' y * dEdy))
37343835tanh' :: (Floating a) => a -> a
3936tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
+9
src/Grenade/Recurrent.hs
···11+module Grenade.Recurrent (
22+ module X
33+ ) where
44+55+import Grenade.Recurrent.Core.Network as X
66+import Grenade.Recurrent.Core.Runner as X
77+import Grenade.Recurrent.Layers.BasicRecurrent as X
88+import Grenade.Recurrent.Layers.LSTM as X
99+import Grenade.Recurrent.Layers.Trivial as X
+98
src/Grenade/Recurrent/Core/Network.hs
···11+{-# LANGUAGE DataKinds #-}
22+{-# LANGUAGE GADTs #-}
33+{-# LANGUAGE TypeOperators #-}
44+{-# LANGUAGE TypeFamilies #-}
55+{-# LANGUAGE MultiParamTypeClasses #-}
66+{-# LANGUAGE FlexibleContexts #-}
77+{-# LANGUAGE FlexibleInstances #-}
88+{-# LANGUAGE EmptyDataDecls #-}
99+module Grenade.Recurrent.Core.Network (
1010+ Recurrent
1111+ , FeedForward
1212+ , RecurrentLayer (..)
1313+ , RecurrentUpdateLayer (..)
1414+ , RecurrentNetwork (..)
1515+ , RecurrentInputs (..)
1616+ , CreatableRecurrent (..)
1717+ ) where
1818+1919+2020+import Control.Monad.Random ( MonadRandom )
2121+import Data.Singletons ( SingI )
2222+2323+import Grenade.Core.Shape
2424+import Grenade.Core.Network
2525+2626+2727+-- | Witness type to indicate we're building up with a normal feed
2828+-- forward layer.
2929+data FeedForward :: * -> *
3030+-- | Witness type to indicate we're building up with a recurrent layer.
3131+data Recurrent :: * -> *
3232+3333+-- | Class for a recurrent layer.
3434+-- It's quite similar to a normal layer, but takes and returns an extra
3535+-- recurrent data shape on each run.
3636+class UpdateLayer x => RecurrentUpdateLayer x where
3737+ -- | Shape of data that is passed between each subsequent run of the layer
3838+ type RecurrentShape x :: Shape
3939+4040+class (RecurrentUpdateLayer x, SingI (RecurrentShape x)) => RecurrentLayer x (i :: Shape) (o :: Shape) where
4141+ -- | Used in training and scoring. Take the input from the previous
4242+ -- layer, and give the output from this layer.
4343+ runRecurrentForwards :: x -> S (RecurrentShape x) -> S i -> (S (RecurrentShape x), S o)
4444+ -- | Back propagate a step. Takes the current layer, the inputs the
4545+ -- layer was given on the forward pass, and the back propagated derivatives from
4646+ -- the layer above.
4747+ -- Returns the gradient layer and the derivatives to push back further.
4848+ runRecurrentBackwards :: x -> S (RecurrentShape x) -> S i -> S (RecurrentShape x) -> S o -> (Gradient x, S (RecurrentShape x), S i)
4949+5050+data RecurrentNetwork :: [*] -> [Shape] -> * where
5151+ OR :: (SingI i, SingI o, Layer x i o) => !x -> RecurrentNetwork '[FeedForward x] '[i, o]
5252+ (:~~>) :: (SingI i, Layer x i h) => !x -> !(RecurrentNetwork xs (h ': hs)) -> RecurrentNetwork (FeedForward x ': xs) (i ': h ': hs)
5353+ (:~@>) :: (SingI i, RecurrentLayer x i h) => !x -> !(RecurrentNetwork xs (h ': hs)) -> RecurrentNetwork (Recurrent x ': xs) (i ': h ': hs)
5454+infixr 5 :~~>
5555+infixr 5 :~@>
5656+5757+instance Show (RecurrentNetwork l h) where
5858+ show (OR a) = "OR " ++ show a
5959+ show (i :~~> o) = show i ++ "\n:~~>\n" ++ show o
6060+ show (i :~@> o) = show i ++ "\n:~@>\n" ++ show o
6161+6262+6363+-- | Recurrent inputs (sideways shapes on an imaginary unrolled graph)
6464+-- Parameterised on the layers of a Network.
6565+data RecurrentInputs :: [*] -> * where
6666+ ORS :: UpdateLayer x
6767+ => () -> RecurrentInputs '[FeedForward x]
6868+ (:~~+>) :: UpdateLayer x
6969+ => () -> !(RecurrentInputs xs) -> RecurrentInputs (FeedForward x ': xs)
7070+ (:~@+>) :: (SingI (RecurrentShape x), RecurrentUpdateLayer x)
7171+ => !(S (RecurrentShape x)) -> !(RecurrentInputs xs) -> RecurrentInputs (Recurrent x ': xs)
7272+infixr 5 :~~+>
7373+infixr 5 :~@+>
7474+7575+-- | A network can easily be created by hand with (:~~>) and (:~@>), but an easy way to initialise a random
7676+-- recurrent network and a set of random inputs for it is with the randomRecurrent.
7777+class CreatableRecurrent (xs :: [*]) (ss :: [Shape]) where
7878+ -- | Create a network of the types requested
7979+ randomRecurrent :: MonadRandom m => m (RecurrentNetwork xs ss, RecurrentInputs xs)
8080+8181+instance (SingI i, SingI o, Layer x i o) => CreatableRecurrent (FeedForward x ': '[]) (i ': o ': '[]) where
8282+ randomRecurrent = do
8383+ thisLayer <- createRandom
8484+ return (OR thisLayer, ORS ())
8585+8686+instance (SingI i, Layer x i o, CreatableRecurrent xs (o ': r ': rs)) => CreatableRecurrent (FeedForward x ': xs) (i ': o ': r ': rs) where
8787+ randomRecurrent = do
8888+ thisLayer <- createRandom
8989+ (rest, resti) <- randomRecurrent
9090+ return (thisLayer :~~> rest, () :~~+> resti)
9191+9292+instance (SingI i, RecurrentLayer x i o, CreatableRecurrent xs (o ': r ': rs)) => CreatableRecurrent (Recurrent x ': xs) (i ': o ': r ': rs) where
9393+ randomRecurrent = do
9494+ thisLayer <- createRandom
9595+ thisShape <- randomOfShape
9696+ (rest, resti) <- randomRecurrent
9797+ return (thisLayer :~@> rest, thisShape :~@+> resti)
9898+
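As a hedged sketch of how these constructors compose (SmallRNN is a made-up alias, using the LSTM and Logit layers from elsewhere in this patch):

```haskell
{-# LANGUAGE DataKinds #-}
import Control.Monad.Random ( MonadRandom )

import Grenade.Core.Shape
import Grenade.Layers.Logit
import Grenade.Recurrent

-- An LSTM over 4 inputs, followed by a pointwise logistic layer.
-- The shape list pairs up with the layer list, head to tail.
type SmallRNN = RecurrentNetwork '[ Recurrent (LSTM 4 8), FeedForward Logit ]
                                 '[ 'D1 4, 'D1 8, 'D1 8 ]

randomSmallRNN :: MonadRandom m
               => m (SmallRNN, RecurrentInputs '[ Recurrent (LSTM 4 8), FeedForward Logit ])
randomSmallRNN = randomRecurrent
```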
+144
src/Grenade/Recurrent/Core/Runner.hs
···11+{-# LANGUAGE BangPatterns #-}
22+{-# LANGUAGE GADTs #-}
33+{-# LANGUAGE DataKinds #-}
44+{-# LANGUAGE ScopedTypeVariables #-}
55+{-# LANGUAGE TypeOperators #-}
66+{-# LANGUAGE TypeFamilies #-}
77+{-# LANGUAGE FlexibleContexts #-}
88+{-# LANGUAGE RankNTypes #-}
99+{-# LANGUAGE RecordWildCards #-}
1010+module Grenade.Recurrent.Core.Runner (
1111+ trainRecurrent
1212+ , runRecurrent
1313+ ) where
1414+1515+import Data.Singletons.Prelude
1616+import Grenade.Core.Network
1717+import Grenade.Core.Shape
1818+1919+import Grenade.Recurrent.Core.Network
2020+2121+-- | Drive a network and collect its back propagated gradients.
2222+trainRecurrent :: forall shapes layers. SingI (Last shapes)
2323+ => LearningParameters
2424+ -> RecurrentNetwork layers shapes
2525+ -> RecurrentInputs layers
2626+ -> [(S (Head shapes), Maybe (S (Last shapes)))]
2727+ -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
2828+trainRecurrent rate network recinputs examples =
2929+ updateBack $ go inputs network recinputs
3030+ where
3131+ inputs = fst <$> examples
3232+ targets = snd <$> examples
3333+ updateBack (a,recgrad,_) = (a,updateRecInputs rate recinputs recgrad)
3434+3535+ go :: forall js sublayers. (Last js ~ Last shapes)
3636+ => [S (Head js)] -- ^ input vectors
3737+ -> RecurrentNetwork sublayers js -- ^ network to train
3838+ -> RecurrentInputs sublayers
3939+ -> (RecurrentNetwork sublayers js, RecurrentInputs sublayers, [S (Head js)])
4040+4141+ -- This is a simple non-recurrent layer; just map it forwards.
4242+ -- Note we're doing training here; we could just return a list of gradients
4343+ -- (and probably will in future).
4444+ go !xs (layer :~~> n) (() :~~+> nIn)
4545+ = let ys = runForwards layer <$> xs
4646+ -- recursively run the rest of the network, and get the gradients from above.
4747+ (newFN, ig, grads) = go ys n nIn
4848+ -- calculate the gradient for this layer to pass down,
4949+ back = uncurry (runBackwards layer) <$> zip (reverse xs) grads
5050+ -- the new trained layer.
5151+ newlayer = runUpdates rate layer (fst <$> back)
5252+5353+ in (newlayer :~~> newFN, () :~~+> ig, snd <$> back)
5454+5555+ -- This is a recurrent layer, so we need to do a scan from the first input to the last,
5656+ -- passing each step's recurrent shape on to the next step.
5757+ go !xs (layer :~@> n) (g :~@+> nIn)
5858+ = let ys = scanlFrom layer g xs
5959+6060+ (newFN, ig, grads) = go (snd <$> ys) n nIn
6161+6262+ backExamples = zip3 (fst <$> reverse ys) (reverse xs) grads
6363+6464+ (rg, back) = myscanbackward layer backExamples
6565+ -- the new trained layer.
6666+ newlayer = runUpdates rate layer (fst <$> back)
6767+ in (newlayer :~@> newFN, rg :~@+> ig, snd <$> back)
6868+6969+ -- Handle the output layer, bouncing the derivatives back down.
7070+ -- We may not have a target for each example, so when we don't, use a 0 gradient.
7171+ go !xs (OR layer) (ORS ())
7272+ = let ys = runForwards layer <$> xs
7373+ -- recursively run the rest of the network, and get the gradients from above.
7474+ back = uncurry (runBackwards layer) <$> zip xs (zipWith makeError ys targets)
7575+ -- the new trained layer.
7676+ newlayer = runUpdates rate layer (reverse $ fst <$> back)
7777+ in (OR newlayer, ORS (), reverse (snd <$> back))
7878+7979+ go _ _ _ =
8080+ error "Impossible for network and recurrent inputs to have different shapes"
8181+8282+8383+ makeError :: S (Last shapes) -> Maybe (S (Last shapes)) -> S (Last shapes)
8484+ makeError _ Nothing = 0
8585+ makeError y (Just t) = y - t
8686+8787+ updateRecInputs :: forall sublayers.
8888+ LearningParameters
8989+ -> RecurrentInputs sublayers
9090+ -> RecurrentInputs sublayers
9191+ -> RecurrentInputs sublayers
9292+9393+ updateRecInputs l@LearningParameters {..} (() :~~+> xs) (() :~~+> ys)
9494+ = () :~~+> updateRecInputs l xs ys
9595+9696+ updateRecInputs l@LearningParameters {..} (x :~@+> xs) (y :~@+> ys)
9797+ = (x - realToFrac learningRate * y) :~@+> updateRecInputs l xs ys
9898+9999+ updateRecInputs _ (ORS ()) (ORS ())
100100+ = ORS ()
101101+ updateRecInputs _ _ _
102102+ = error "Impossible for updateRecInputs to have different shapes"
103103+104104+scanlFrom :: forall x i o. RecurrentLayer x i o
105105+ => x -- ^ the layer
106106+ -> S (RecurrentShape x) -- ^ place to start
107107+ -> [S i] -- ^ list of inputs to scan through
108108+ -> [(S (RecurrentShape x), S o)] -- ^ list of scan inputs and outputs
109109+scanlFrom !layer !recShape (x:xs) =
110110+ let (lerec, lepush) = runRecurrentForwards layer recShape x
111111+ in (recShape, lepush) : scanlFrom layer lerec xs
112112+scanlFrom _ _ [] = []
113113+114114+myscanbackward :: forall x i o. RecurrentLayer x i o
115115+ => x -- ^ the layer
116116+ -> [(S (RecurrentShape x), S i, S o)] -- ^ the list of inputs and output to scan over
117117+ -> (S (RecurrentShape x), [(Gradient x, S i)]) -- ^ list of gradients to fold and inputs to backprop
118118+myscanbackward layer =
119119+ goX 0
120120+ where
121121+ goX :: S (RecurrentShape x) -> [(S (RecurrentShape x), S i, S o)] -> (S (RecurrentShape x), [(Gradient x, S i)])
122122+ goX !lastback ((recShape, lastin, backgrad):xs) =
123123+ let (layergrad, recgrad, ingrad) = runRecurrentBackwards layer recShape lastin lastback backgrad
124124+ (pushedback, ll) = goX recgrad xs
125125+ in (pushedback, (layergrad, ingrad) : ll)
126126+ goX !lastback [] = (lastback, [])
127127+128128+-- | Just forwards propagation with no training.
129129+runRecurrent :: RecurrentNetwork layers shapes
130130+ -> RecurrentInputs layers -> S (Head shapes)
131131+ -> (RecurrentInputs layers, S (Last shapes))
132132+runRecurrent (layer :~~> n) (() :~~+> nr) !x
133133+ = let ys = runForwards layer x
134134+ (nr', o) = runRecurrent n nr ys
135135+ in (() :~~+> nr', o)
136136+runRecurrent (layer :~@> n) (recin :~@+> nr) !x
137137+ = let (recin', y) = runRecurrentForwards layer recin x
138138+ (nr', o) = runRecurrent n nr y
139139+ in (recin' :~@+> nr', o)
140140+runRecurrent (OR layer) (ORS ()) !x
141141+ = (ORS (), runForwards layer x)
142142+143143+runRecurrent _ _ _
144144+ = error "Impossible for the gradients of a network to have a different length or shape to the network"
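Training over several example sequences is then just a fold which threads the network and its recurrent inputs through trainRecurrent; a minimal sketch (trainMany is a hypothetical helper):

```haskell
import Data.Singletons ( SingI )
import Data.Singletons.Prelude ( Head, Last )

import Grenade.Core.Network
import Grenade.Core.Shape
import Grenade.Recurrent

-- Fold trainRecurrent over many example sequences, carrying the
-- trained network and its recurrent inputs from one to the next.
trainMany :: SingI (Last shapes)
          => LearningParameters
          -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
          -> [[(S (Head shapes), Maybe (S (Last shapes)))]]
          -> (RecurrentNetwork layers shapes, RecurrentInputs layers)
trainMany rate = foldl (\(net, ins) -> trainRecurrent rate net ins)
```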
+92
src/Grenade/Recurrent/Layers/BasicRecurrent.hs
···11+{-# LANGUAGE DataKinds #-}
22+{-# LANGUAGE GADTs #-}
33+{-# LANGUAGE RecordWildCards #-}
44+{-# LANGUAGE TypeOperators #-}
55+{-# LANGUAGE TypeFamilies #-}
66+{-# LANGUAGE MultiParamTypeClasses #-}
77+{-# LANGUAGE FlexibleContexts #-}
88+{-# LANGUAGE UndecidableInstances #-}
99+module Grenade.Recurrent.Layers.BasicRecurrent (
1010+ BasicRecurrent (..)
1111+ , randomBasicRecurrent
1212+ ) where
1313+1414+import Control.Monad.Random ( MonadRandom, getRandom )
1515+1616+import Data.Singletons.TypeLits
1717+1818+import Numeric.LinearAlgebra.Static
1919+2020+import GHC.TypeLits
2121+2222+import Grenade.Core.Network
2323+import Grenade.Core.Shape
2424+import Grenade.Recurrent.Core.Network
2525+2626+data BasicRecurrent :: Nat -- Input layer size
2727+ -> Nat -- Output layer size
2828+ -> * where
2929+ BasicRecurrent :: ( KnownNat input
3030+ , KnownNat output
3131+ , KnownNat matrixCols
3232+ , matrixCols ~ (input + output))
3333+ => !(R output) -- Bias neuron weights
3434+ -> !(R output) -- Bias neuron momentum
3535+ -> !(L output matrixCols) -- Activation
3636+ -> !(L output matrixCols) -- Momentum
3737+ -> BasicRecurrent input output
3838+3939+data BasicRecurrent' :: Nat -- Input layer size
4040+ -> Nat -- Output layer size
4141+ -> * where
4242+ BasicRecurrent' :: ( KnownNat input
4343+ , KnownNat output
4444+ , KnownNat matrixCols
4545+ , matrixCols ~ (input + output))
4646+ => !(R output) -- Bias neuron gradients
4747+ -> !(L output matrixCols)
4848+ -> BasicRecurrent' input output
4949+5050+instance Show (BasicRecurrent i o) where
5151+ show BasicRecurrent {} = "BasicRecurrent"
5252+5353+instance (KnownNat i, KnownNat o, KnownNat (i + o)) => UpdateLayer (BasicRecurrent i o) where
5454+ type Gradient (BasicRecurrent i o) = (BasicRecurrent' i o)
5555+5656+ runUpdate LearningParameters {..} (BasicRecurrent oldBias oldBiasMomentum oldActivations oldMomentum) (BasicRecurrent' biasGradient activationGradient) =
5757+ let newBiasMomentum = konst learningMomentum * oldBiasMomentum - konst learningRate * biasGradient
5858+ newBias = oldBias + newBiasMomentum
5959+ newMomentum = konst learningMomentum * oldMomentum - konst learningRate * activationGradient
6060+ regulariser = konst (learningRegulariser * learningRate) * oldActivations
6161+ newActivations = oldActivations + newMomentum - regulariser
6262+ in BasicRecurrent newBias newBiasMomentum newActivations newMomentum
6363+6464+ createRandom = randomBasicRecurrent
6565+6666+instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentUpdateLayer (BasicRecurrent i o) where
6767+ type RecurrentShape (BasicRecurrent i o) = 'D1 o
6868+6969+instance (KnownNat i, KnownNat o, KnownNat (i + o), i <= (i + o), o ~ ((i + o) - i)) => RecurrentLayer (BasicRecurrent i o) ('D1 i) ('D1 o) where
7070+ -- Do a matrix vector multiplication and return the result.
7171+ runRecurrentForwards (BasicRecurrent wB _ wN _) (S1D lastOutput) (S1D thisInput) =
7272+ let thisOutput = S1D $ wB + wN #> (thisInput # lastOutput)
7373+ in (thisOutput, thisOutput)
7474+7575+ -- Run a backpropagation step for a basic recurrent layer.
7676+ runRecurrentBackwards (BasicRecurrent _ _ wN _) (S1D lastOutput) (S1D thisInput) (S1D dRec) (S1D dEdy) =
7777+ let biasGradient = (dRec + dEdy)
7878+ layerGrad = (dRec + dEdy) `outer` (thisInput # lastOutput)
7979+ -- calculate derivatives for the next step
8080+ (backGrad, recGrad) = split $ tr wN #> (dRec + dEdy)
8181+ in (BasicRecurrent' biasGradient layerGrad, S1D recGrad, S1D backGrad)
8282+8383+randomBasicRecurrent :: (MonadRandom m, KnownNat i, KnownNat o, KnownNat x, x ~ (i + o))
8484+ => m (BasicRecurrent i o)
8585+randomBasicRecurrent = do
8686+ seed1 <- getRandom
8787+ seed2 <- getRandom
8888+ let wB = randomVector seed1 Uniform * 2 - 1
8989+ wN = uniformSample seed2 (-1) 1
9090+ bm = konst 0
9191+ mm = konst 0
9292+ return $ BasicRecurrent wB bm wN mm
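In equations, the forward step above is a single affine map over the input concatenated with the previous output, where W is the output × (input + output) activation matrix:

```latex
y_t = b + W \begin{bmatrix} x_t \\ y_{t-1} \end{bmatrix}
```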
+244
src/Grenade/Recurrent/Layers/LSTM.hs
···11+{-# LANGUAGE BangPatterns #-}
22+{-# LANGUAGE DataKinds #-}
33+{-# LANGUAGE GADTs #-}
44+{-# LANGUAGE RankNTypes #-}
55+{-# LANGUAGE RecordWildCards #-}
66+{-# LANGUAGE TypeOperators #-}
77+{-# LANGUAGE TypeFamilies #-}
88+{-# LANGUAGE MultiParamTypeClasses #-}
99+{-# LANGUAGE FlexibleContexts #-}
1010+{-# LANGUAGE ViewPatterns #-}
1111+module Grenade.Recurrent.Layers.LSTM (
1212+ LSTM (..)
1313+ , LSTMWeights (..)
1414+ , randomLSTM
1515+ ) where
1616+1717+import Control.Monad.Random ( MonadRandom, getRandom )
1818+1919+-- import Data.List ( foldl1' )
2020+import Data.Singletons.TypeLits
2121+2222+import Numeric.LinearAlgebra.Static
2323+2424+import Grenade.Core.Network
2525+import Grenade.Core.Shape
2626+2727+import Grenade.Layers.Internal.Update
2828+2929+import Grenade.Recurrent.Core.Network
3030+3131+-- | Long Short Term Memory Recurrent unit
3232+--
3333+-- This is a Peephole formulation, so the recurrent shape is
3434+-- just the cell state; the previous output is not held or used
3535+-- at all.
3636+data LSTM :: Nat -> Nat -> * where
3737+ LSTM :: ( KnownNat input
3838+ , KnownNat output
3939+ ) => !(LSTMWeights input output) -- Weights
4040+ -> !(LSTMWeights input output) -- Momentums
4141+ -> LSTM input output
4242+4343+data LSTMWeights :: Nat -> Nat -> * where
4444+ LSTMWeights :: ( KnownNat input
4545+ , KnownNat output
4646+ ) => {
4747+ lstmWf :: !(L output input) -- Weight Forget (W_f)
4848+ , lstmUf :: !(L output output) -- Cell State Forget (U_f)
4949+ , lstmBf :: !(R output) -- Bias Forget (b_f)
5050+ , lstmWi :: !(L output input) -- Weight Input (W_i)
5151+ , lstmUi :: !(L output output) -- Cell State Input (U_i)
5252+ , lstmBi :: !(R output) -- Bias Input (b_i)
5353+ , lstmWo :: !(L output input) -- Weight Output (W_o)
5454+ , lstmUo :: !(L output output) -- Cell State Output (U_o)
5555+ , lstmBo :: !(R output) -- Bias Output (b_o)
5656+ , lstmWc :: !(L output input) -- Weight Cell (W_c)
5757+ , lstmBc :: !(R output) -- Bias Cell (b_c)
5858+ } -> LSTMWeights input output
5959+6060+instance Show (LSTM i o) where
6161+ show LSTM {} = "LSTM"
6262+6363+instance (KnownNat i, KnownNat o) => UpdateLayer (LSTM i o) where
6464+ -- The gradients are the same shape as the weights and momentums.
6565+ -- This seems to be a general pattern; maybe it should be enforced.
6666+ type Gradient (LSTM i o) = (LSTMWeights i o)
6767+6868+ -- Run the update function for each matrix/vector group of weights, momentums, and gradients.
6969+ -- Hmm, maybe the function should be used instead of passing in the learning parameters.
7070+ runUpdate LearningParameters {..} (LSTM w m) g =
7171+ let (wf, wf') = u lstmWf w m g
7272+ (uf, uf') = u lstmUf w m g
7373+ (bf, bf') = v lstmBf w m g
7474+ (wi, wi') = u lstmWi w m g
7575+ (ui, ui') = u lstmUi w m g
7676+ (bi, bi') = v lstmBi w m g
7777+ (wo, wo') = u lstmWo w m g
7878+ (uo, uo') = u lstmUo w m g
7979+ (bo, bo') = v lstmBo w m g
8080+ (wc, wc') = u lstmWc w m g
8181+ (bc, bc') = v lstmBc w m g
8282+ in LSTM (LSTMWeights wf uf bf wi ui bi wo uo bo wc bc) (LSTMWeights wf' uf' bf' wi' ui' bi' wo' uo' bo' wc' bc')
8383+ where
8484+ -- Utility function for updating with the momentum, gradients, and weights.
8585+ u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> x -> ((L out ix), (L out ix))
8686+ u e (e -> weights) (e -> momentum) (e -> gradient) =
8787+ decendMatrix learningRate learningMomentum learningRegulariser weights gradient momentum
8888+8989+ v :: forall x ix. (KnownNat ix) => (x -> (R ix)) -> x -> x -> x -> ((R ix), (R ix))
9090+ v e (e -> weights) (e -> momentum) (e -> gradient) =
9191+ decendVector learningRate learningMomentum learningRegulariser weights gradient momentum
9292+9393+ -- There are a lot of updates here, so to try and minimise the number of data copies
9494+ -- we'll create a mutable bucket for each.
9595+ -- runUpdates rate lstm gs =
9696+ -- let combinedGradient = foldl1' uu gs
9797+ -- in runUpdate rate lstm combinedGradient
9898+ -- where
9999+ -- uu :: (KnownNat i, KnownNat o) => LSTMWeights i o -> LSTMWeights i o -> LSTMWeights i o
100100+ -- uu a b =
101101+ -- let wf = u lstmWf a b
102102+ -- uf = u lstmUf a b
103103+ -- bf = v lstmBf a b
104104+ -- wi = u lstmWi a b
105105+ -- ui = u lstmUi a b
106106+ -- bi = v lstmBi a b
107107+ -- wo = u lstmWo a b
108108+ -- uo = u lstmUo a b
109109+ -- bo = v lstmBo a b
110110+ -- wc = u lstmWc a b
111111+ -- bc = v lstmBc a b
112112+ -- in LSTMWeights wf uf bf wi ui bi wo uo bo wc bc
113113+ -- u :: forall x ix out. (KnownNat ix, KnownNat out) => (x -> (L out ix)) -> x -> x -> L out ix
114114+ -- u e (e -> a) (e -> b) = tr $ tr a + tr b
115115+116116+ -- v :: forall x ix. (x -> (R ix)) -> x -> x -> R ix
117117+ -- v e (e -> a) (e -> b) = a + b
118118+119119+ createRandom = randomLSTM
120120+121121+instance (KnownNat i, KnownNat o) => RecurrentUpdateLayer (LSTM i o) where
122122+ -- The recurrent shape is the same size as the output.
123123+ -- However, it's actually the cell state, as this is a peephole-variety LSTM.
124124+ type RecurrentShape (LSTM i o) = 'D1 o
125125+126126+instance (KnownNat i, KnownNat o) => RecurrentLayer (LSTM i o) ('D1 i) ('D1 o) where
127127+ -- Forward propagation for the LSTM layer.
128128+ -- The size of the cell state is also the size of the output.
129129+ runRecurrentForwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) =
130130+ let -- Forget state vector
131131+ f_t = sigmoid $ lstmBf + lstmWf #> input + lstmUf #> cell
132132+ -- Input state vector
133133+ i_t = sigmoid $ lstmBi + lstmWi #> input + lstmUi #> cell
134134+ -- Output state vector
135135+ o_t = sigmoid $ lstmBo + lstmWo #> input + lstmUo #> cell
136136+ -- Cell input state vector
137137+ c_x = tanh $ lstmBc + lstmWc #> input
138138+ -- Cell state
139139+ c_t = f_t * cell + i_t * c_x
140140+ -- Output (it's sometimes recommended to use tanh c_t)
141141+ h_t = o_t * c_t
142142+ in (S1D c_t, S1D h_t)
143143+144144+ -- Run a backpropagation step for an LSTM layer.
145145+ -- We're doing all the derivatives by hand here, so one should
146146+ -- be extra careful when changing this.
147147+ --
148148+ -- There's a test version using the AD library without hmatrix in the test
149149+ -- suite. These should always match.
150150+ runRecurrentBackwards (LSTM (LSTMWeights {..}) _) (S1D cell) (S1D input) (S1D cellGrad) (S1D h_t') =
151151+ -- We're not keeping the Wengert tape during the forward pass,
152152+ -- so we're duplicating some work here.
153153+ --
154154+ -- If I was being generous, I'd call it checkpointing.
155155+ --
156156+ -- Maybe think about better ways to store some intermediate states.
157157+ let -- Forget state vector
158158+ f_s = lstmBf + lstmWf #> input + lstmUf #> cell
159159+ f_t = sigmoid f_s
160160+ -- Input state vector
161161+ i_s = lstmBi + lstmWi #> input + lstmUi #> cell
162162+ i_t = sigmoid i_s
163163+ -- Output state vector
164164+ o_s = lstmBo + lstmWo #> input + lstmUo #> cell
165165+ o_t = sigmoid o_s
166166+ -- Cell input state vector
167167+ c_s = lstmBc + lstmWc #> input
168168+ c_x = tanh c_s
169169+ -- Cell state
170170+ c_t = f_t * cell + i_t * c_x
171171+172172+ -- Reverse Mode AD Derivatives
173173+ c_t' = h_t' * o_t + cellGrad
174174+175175+ f_t' = c_t' * cell
176176+ f_s' = sigmoid' f_s * f_t'
177177+178178+ o_t' = h_t' * c_t
179179+ o_s' = sigmoid' o_s * o_t'
180180+181181+ i_t' = c_t' * c_x
182182+ i_s' = sigmoid' i_s * i_t'
183183+184184+ c_x' = c_t' * i_t
185185+ c_s' = tanh' c_s * c_x'
186186+187187+ -- The derivatives to pass sideways (recurrent) and downwards
188188+ cell' = tr lstmUf #> f_s' + tr lstmUo #> o_s' + tr lstmUi #> i_s' + c_t' * f_t
189189+ input' = tr lstmWf #> f_s' + tr lstmWo #> o_s' + tr lstmWi #> i_s' + tr lstmWc #> c_s'
190190+191191+ -- Calculate the gradient matrices for the input
192192+ lstmWf' = f_s' `outer` input
193193+ lstmWi' = i_s' `outer` input
194194+ lstmWo' = o_s' `outer` input
195195+ lstmWc' = c_s' `outer` input
196196+197197+ -- Calculate the gradient matrices for the cell
198198+ lstmUf' = f_s' `outer` cell
199199+ lstmUi' = i_s' `outer` cell
200200+ lstmUo' = o_s' `outer` cell
201201+202202+ -- The biases just take the gradient values directly, but we'll write it out so it's obvious
203203+ lstmBf' = f_s'
204204+ lstmBi' = i_s'
205205+ lstmBo' = o_s'
206206+ lstmBc' = c_s'
207207+208208+ gradients = LSTMWeights lstmWf' lstmUf' lstmBf' lstmWi' lstmUi' lstmBi' lstmWo' lstmUo' lstmBo' lstmWc' lstmBc'
209209+ in (gradients, S1D cell', S1D input')
210210+211211+-- | Generate an LSTM layer with random weights.
212212+-- One can also just call createRandom from UpdateLayer.
213213+--
214214+-- The forget gate biases are set to 1 to encourage early learning.
215215+--
216216+-- https://github.com/karpathy/char-rnn/commit/0dfeaa454e687dd0278f036552ea1e48a0a408c9
217217+--
218218+randomLSTM :: forall m i o. (MonadRandom m, KnownNat i, KnownNat o)
219219+ => m (LSTM i o)
220220+randomLSTM = do
221221+ let w = (\s -> uniformSample s (-1) 1 ) <$> getRandom
222222+ u = (\s -> uniformSample s (-1) 1 ) <$> getRandom
223223+ v = (\s -> randomVector s Uniform * 2 - 1) <$> getRandom
224224+225225+ w0 = konst 0
226226+ u0 = konst 0
227227+ v0 = konst 0
228228+229229+ LSTM <$> (LSTMWeights <$> w <*> u <*> pure (konst 1) <*> w <*> u <*> v <*> w <*> u <*> v <*> w <*> v)
230230+ <*> pure (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
231231+232232+-- | Maths
233233+--
234234+-- TODO: move these out of this module
235235+sigmoid :: Floating a => a -> a
236236+sigmoid x = 1 / (1 + exp (-x))
237237+238238+sigmoid' :: Floating a => a -> a
239239+sigmoid' x = logix * (1 - logix)
240240+ where
241241+ logix = sigmoid x
242242+243243+tanh' :: (Floating a) => a -> a
244244+tanh' t = 1 - s ^ (2 :: Int) where s = tanh t
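For reference, runRecurrentForwards above computes the usual peephole equations, with ⊙ denoting element-wise multiplication:

```latex
\begin{aligned}
f_t &= \sigma(b_f + W_f x_t + U_f c_{t-1}) \\
i_t &= \sigma(b_i + W_i x_t + U_i c_{t-1}) \\
o_t &= \sigma(b_o + W_o x_t + U_o c_{t-1}) \\
\tilde{c}_t &= \tanh(b_c + W_c x_t) \\
c_t &= f_t \odot c_{t-1} + i_t \odot \tilde{c}_t \\
h_t &= o_t \odot c_t
\end{aligned}
```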
+23
src/Grenade/Recurrent/Layers/Trivial.hs
···11+{-# LANGUAGE DataKinds #-}
22+{-# LANGUAGE TypeOperators #-}
33+{-# LANGUAGE TypeFamilies #-}
44+{-# LANGUAGE MultiParamTypeClasses #-}
55+{-# LANGUAGE FlexibleInstances #-}
66+module Grenade.Recurrent.Layers.Trivial (
77+ Trivial (..)
88+ ) where
99+1010+import Grenade.Core.Network
1111+1212+-- | A trivial layer.
1313+data Trivial = Trivial
1414+ deriving Show
1515+1616+instance UpdateLayer Trivial where
1717+ type Gradient Trivial = ()
1818+ runUpdate _ _ _ = Trivial
1919+ createRandom = return Trivial
2020+2121+instance (a ~ b) => Layer Trivial a b where
2222+ runForwards _ = id
2323+ runBackwards _ _ y = ((), y)
+84
src/Grenade/Utils/OneHot.hs
···11+{-# LANGUAGE DataKinds #-}
22+{-# LANGUAGE GADTs #-}
33+{-# LANGUAGE TypeFamilies #-}
44+{-# LANGUAGE TypeOperators #-}
55+{-# LANGUAGE FlexibleContexts #-}
66+{-# LANGUAGE ScopedTypeVariables #-}
77+{-# LANGUAGE RankNTypes #-}
88+99+module Grenade.Utils.OneHot (
1010+ oneHot
1111+ , hotMap
1212+ , makeHot
1313+ , unHot
1414+ ) where
1515+1616+import Data.List ( group, sort )
1717+1818+import Data.Map ( Map )
1919+import qualified Data.Map as M
2020+2121+import Data.Proxy
2222+import Data.Singletons.TypeLits
2323+2424+import Data.Vector ( Vector )
2525+import qualified Data.Vector as V
2626+2727+import Numeric.LinearAlgebra ( maxIndex )
2828+import Numeric.LinearAlgebra.Devel
2929+import Numeric.LinearAlgebra.Static
3030+3131+import Grenade.Core.Shape
3232+3333+-- | From an int which is hot, create a 1D Shape
3434+-- with one index hot (1) and the rest 0.
3535+-- Returns Nothing if the hot index is at least
3636+-- the length of the vector.
3737+oneHot :: forall n. (KnownNat n)
3838+ => Int -> Maybe (S ('D1 n))
3939+oneHot hot =
4040+ let len = fromIntegral $ natVal (Proxy :: Proxy n)
4141+ in if hot < len
4242+ then
4343+ fmap S1D . create $ runSTVector $ do
4444+ vec <- newVector 0 len
4545+ writeVector vec hot 1
4646+ return vec
4747+ else Nothing
4848+4949+-- | Create a one hot map from any enumerable.
5050+-- Returns the map, and the ordered vector for the reverse transformation.
5151+hotMap :: (Ord a, KnownNat n) => Proxy n -> [a] -> Maybe (Map a Int, Vector a)
5252+hotMap n as =
5353+ let len = fromIntegral $ natVal n
5454+ uniq = [ c | (c:_) <- group $ sort as]
5555+ hotl = length uniq
5656+ in if hotl <= len
5757+ then
5858+ Just (M.fromList $ zip uniq [0..], V.fromList uniq)
5959+ else Nothing
6060+6161+-- | From a map and value, create a 1D Shape
6262+-- with one index hot (1) and the rest 0.
6363+-- Returns Nothing if the hot index is at least
6464+-- the length of the vector, or the map
6565+-- doesn't contain the value.
6666+makeHot :: forall a n. (Ord a, KnownNat n)
6767+ => Map a Int -> a -> Maybe (S ('D1 n))
6868+makeHot m x = do
6969+ hot <- M.lookup x m
7070+ let len = fromIntegral $ natVal (Proxy :: Proxy n)
7171+ if hot < len
7272+ then
7373+ fmap S1D . create $ runSTVector $ do
7474+ vec <- newVector 0 len
7575+ writeVector vec hot 1
7676+ return vec
7777+ else Nothing
7878+7979+unHot :: forall a n. (KnownNat n)
8080+ => Vector a -> (S ('D1 n)) -> Maybe a
8181+unHot v (S1D xs)
8282+ = (V.!?) v
8383+ $ maxIndex (extract xs)
8484+
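A hedged round-trip sketch of these helpers together (the alphabet and character are made up for illustration):

```haskell
{-# LANGUAGE DataKinds #-}
import Data.Proxy ( Proxy (..) )

import Grenade.Core.Shape
import Grenade.Utils.OneHot

-- Build a 26-way hot map over the alphabet, encode 'c' as a
-- one-hot 'D1 26 shape, then decode it back from the vector.
roundTrip :: Maybe Char
roundTrip = do
  (m, v) <- hotMap (Proxy :: Proxy 26) ['a' .. 'z']
  hot    <- makeHot m 'c' :: Maybe (S ('D1 26))
  unHot v hot  -- Just 'c'
```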
+20-10
test/Test/Grenade/Layers/Convolution.hs
···11-{-# LANGUAGE TemplateHaskell #-}
11+{-# LANGUAGE TemplateHaskell #-}
22{-# LANGUAGE DataKinds #-}
33-{-# LANGUAGE GADTs #-}
44-{-# LANGUAGE ScopedTypeVariables #-}
55-{-# LANGUAGE KindSignatures #-}
66-{-# LANGUAGE ConstraintKinds #-}
33+{-# LANGUAGE GADTs #-}
44+{-# LANGUAGE ScopedTypeVariables #-}
55+{-# LANGUAGE KindSignatures #-}
66+{-# LANGUAGE ConstraintKinds #-}
77{-# LANGUAGE TypeOperators #-}
88-98{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
109module Test.Grenade.Layers.Convolution where
1110···3029instance Show OpaqueConvolution where
3130 show (OpaqueConvolution n) = show n
32313232+genConvolution :: ( KnownNat channels
3333+ , KnownNat filters
3434+ , KnownNat kernelRows
3535+ , KnownNat kernelColumns
3636+ , KnownNat strideRows
3737+ , KnownNat strideColumns
3838+ , KnownNat kernelFlattened
3939+ , kernelFlattened ~ (kernelRows * kernelColumns * channels)
4040+ ) => Jack (Convolution channels filters kernelRows kernelColumns strideRows strideColumns)
4141+genConvolution = Convolution <$> uniformSample <*> uniformSample
4242+3343genOpaqueOpaqueConvolution :: Jack OpaqueConvolution
3444genOpaqueOpaqueConvolution = do
3545 Just channels <- someNatVal <$> choose (1, 10)
···4656 p2 = natDict pkc
4757 p3 = natDict pch
4858 in case p1 %* p2 %* p3 of
4949- Dict -> OpaqueConvolution <$> (Convolution <$> uniformSample <*> uniformSample :: Jack (Convolution ch fl kr kc sr sc))
5959+ Dict -> OpaqueConvolution <$> (genConvolution :: Jack (Convolution ch fl kr kc sr sc))
50605161prop_conv_net_witness =
5262 gamble genOpaqueOpaqueConvolution $ \onet ->
···8090 , (unsafeCoerce (Dict :: Dict ()) :: Dict (((outRows - 1) * strideRows) ~ (inRows - kernelRows)))
8191 , (unsafeCoerce (Dict :: Dict ()) :: Dict (((outCols - 1) * strideCols) ~ (inCols - kernelCols)))) of
8292 (Dict, Dict, Dict, Dict) ->
8383- gamble (S3D' <$> uniformSample) $ \(input :: S' ('D3 inRows inCols channels)) ->
8484- let output :: S' ('D3 outRows outCols filters) = runForwards convLayer input
8585- backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S' ('D3 inRows inCols channels))
9393+ gamble (S3D <$> uniformSample) $ \(input :: S ('D3 inRows inCols channels)) ->
9494+ let output :: S ('D3 outRows outCols filters) = runForwards convLayer input
9595+ backed :: (Gradient (Convolution channels filters kernelRows kernelCols strideRows strideCols), S ('D3 inRows inCols channels))
8696 = runBackwards convLayer input output
8797 in backed `seq` True
8898 ) :: Property
+4-4
test/Test/Grenade/Layers/FullyConnected.hs
···4444prop_fully_connected_forwards :: Property
4545prop_fully_connected_forwards =
4646 gamble genOpaqueFullyConnected $ \(OpaqueFullyConnected (fclayer :: FullyConnected i o)) ->
4747- gamble (S1D' <$> randomVector) $ \(input :: S' ('D1 i)) ->
4848- let output :: S' ('D1 o) = runForwards fclayer input
4949- backed :: (Gradient (FullyConnected i o), S' ('D1 i))
5050- = runBackwards fclayer input output
4747+ gamble (S1D <$> randomVector) $ \(input :: S ('D1 i)) ->
4848+ let output :: S ('D1 o) = runForwards fclayer input
4949+ backed :: (Gradient (FullyConnected i o), S ('D1 i))
5050+ = runBackwards fclayer input output
5151 in backed `seq` True
52525353return []
+5-5
test/Test/Grenade/Layers/Pooling.hs
···11-{-# LANGUAGE TemplateHaskell #-}
22-{-# LANGUAGE DataKinds #-}
33-{-# LANGUAGE KindSignatures #-}
44-{-# LANGUAGE GADTs #-}
55-{-# LANGUAGE ScopedTypeVariables #-}
11+{-# LANGUAGE TemplateHaskell #-}
22+{-# LANGUAGE DataKinds #-}
33+{-# LANGUAGE KindSignatures #-}
44+{-# LANGUAGE GADTs #-}
55+{-# LANGUAGE ScopedTypeVariables #-}
66{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
77module Test.Grenade.Layers.Pooling where
88
+101
test/Test/Grenade/Recurrent/Layers/LSTM.hs
···11+{-# LANGUAGE TemplateHaskell #-}
22+{-# LANGUAGE DataKinds #-}
33+{-# LANGUAGE GADTs #-}
44+{-# LANGUAGE ScopedTypeVariables #-}
55+{-# LANGUAGE ConstraintKinds #-}
66+{-# LANGUAGE TypeOperators #-}
77+{-# LANGUAGE FlexibleContexts #-}
88+{-# LANGUAGE RankNTypes #-}
99+1010+{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
1111+module Test.Grenade.Recurrent.Layers.LSTM where
1212+1313+import Disorder.Jack
1414+1515+import Data.Foldable ( toList )
1616+import Data.Singletons.TypeLits
1717+1818+import Grenade
1919+import Grenade.Recurrent
2020+2121+import qualified Numeric.LinearAlgebra as H
2222+import qualified Numeric.LinearAlgebra.Static as S
2323+2424+2525+import qualified Test.Grenade.Recurrent.Layers.LSTM.Reference as Reference
2626+import Test.Jack.Hmatrix
2727+2828+genLSTM :: forall i o. (KnownNat i, KnownNat o) => Jack (LSTM i o)
2929+genLSTM = do
3030+ let w = uniformSample
3131+ u = uniformSample
3232+ v = randomVector
3333+3434+ w0 = S.konst 0
3535+ u0 = S.konst 0
3636+ v0 = S.konst 0
3737+3838+ LSTM <$> (LSTMWeights <$> w <*> u <*> v <*> w <*> u <*> v <*> w <*> u <*> v <*> w <*> v)
3939+ <*> pure (LSTMWeights w0 u0 v0 w0 u0 v0 w0 u0 v0 w0 v0)
4040+4141+prop_lstm_reference_forwards =
4242+ gamble randomVector $ \(input :: S.R 3) ->
4343+ gamble randomVector $ \(cell :: S.R 2) ->
4444+ gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
4545+ let actual = runRecurrentForwards net (S1D cell) (S1D input)
4646+ in case actual of
4747+ ((S1D cellOut) :: S ('D1 2), (S1D output) :: S ('D1 2)) ->
4848+ let cellOut' = Reference.Vector . H.toList . S.extract $ cellOut
4949+ output' = Reference.Vector . H.toList . S.extract $ output
5050+ refNet = Reference.lstmToReference lstmWeights
5151+ refCell = Reference.Vector . H.toList . S.extract $ cell
5252+ refInput = Reference.Vector . H.toList . S.extract $ input
5353+ (refCO, refO) = Reference.runLSTM refNet refCell refInput
5454+ in toList refCO ~~~ toList cellOut' .&&. toList refO ~~~ toList output'
5555+5656+prop_lstm_reference_backwards =
5757+ gamble randomVector $ \(input :: S.R 3) ->
5858+ gamble randomVector $ \(cell :: S.R 2) ->
5959+ gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
6060+ let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
6161+ in case actualBacks of
6262+ (actualGradients, _, _) ->
6363+ let refNet = Reference.lstmToReference lstmWeights
6464+ refCell = Reference.Vector . H.toList . S.extract $ cell
6565+ refInput = Reference.Vector . H.toList . S.extract $ input
6666+ refGradients = Reference.runLSTMback refCell refInput refNet
6767+ in toList refGradients ~~~ toList (Reference.lstmToReference actualGradients)
6868+6969+prop_lstm_reference_backwards_input =
7070+ gamble randomVector $ \(input :: S.R 3) ->
7171+ gamble randomVector $ \(cell :: S.R 2) ->
7272+ gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
7373+ let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
7474+ in case actualBacks of
7575+ (_, _, S1D actualGradients) ->
7676+ let refNet = Reference.lstmToReference lstmWeights
7777+ refCell = Reference.Vector . H.toList . S.extract $ cell
7878+ refInput = Reference.Vector . H.toList . S.extract $ input
7979+ refGradients = Reference.runLSTMbackOnInput refCell refNet refInput
8080+ in toList refGradients ~~~ H.toList (S.extract actualGradients)
8181+8282+prop_lstm_reference_backwards_cell =
8383+ gamble randomVector $ \(input :: S.R 3) ->
8484+ gamble randomVector $ \(cell :: S.R 2) ->
8585+ gamble genLSTM $ \(net@(LSTM lstmWeights _) :: LSTM 3 2) ->
8686+ let actualBacks = runRecurrentBackwards net (S1D cell) (S1D input) (S1D (S.konst 1) :: S ('D1 2)) (S1D (S.konst 1) :: S ('D1 2))
8787+ in case actualBacks of
8888+ (_, S1D actualGradients, _) ->
8989+ let refNet = Reference.lstmToReference lstmWeights
9090+ refCell = Reference.Vector . H.toList . S.extract $ cell
9191+ refInput = Reference.Vector . H.toList . S.extract $ input
9292+ refGradients = Reference.runLSTMbackOnCell refInput refNet refCell
9393+ in toList refGradients ~~~ (H.toList . S.extract $ actualGradients)
9494+9595+9696+(~~~) as bs = all ((< 1e-8) . abs) (zipWith (-) as bs)
9797+infix 4 ~~~
9898+9999+return []
100100+tests :: IO Bool
101101+tests = $quickCheckAll
+149
test/Test/Grenade/Recurrent/Layers/LSTM/Reference.hs
···11+{-# LANGUAGE DataKinds #-}
22+{-# LANGUAGE GADTs #-}
33+{-# LANGUAGE ScopedTypeVariables #-}
44+{-# LANGUAGE ConstraintKinds #-}
55+{-# LANGUAGE TypeOperators #-}
66+{-# LANGUAGE DeriveFunctor #-}
77+{-# LANGUAGE DeriveFoldable #-}
88+{-# LANGUAGE DeriveTraversable #-}
99+{-# LANGUAGE RecordWildCards #-}
1010+{-# LANGUAGE FlexibleContexts #-}
1111+{-# LANGUAGE RankNTypes #-}
1212+1313+{-# OPTIONS_GHC -fno-warn-missing-signatures #-}
1414+module Test.Grenade.Recurrent.Layers.LSTM.Reference where
1515+1616+import Data.Reflection
1717+import Numeric.AD.Mode.Reverse
1818+import Numeric.AD.Internal.Reverse ( Tape )
1919+2020+import qualified Grenade.Recurrent.Layers.LSTM as LSTM
2121+import qualified Numeric.LinearAlgebra.Static as S
2222+import qualified Numeric.LinearAlgebra as H
2323+2424+--
2525+-- This module contains a set of list-only versions of
2626+-- an LSTM layer which can be used with the AD library.
2727+--
2828+-- Using this, we can check to make sure that our fast
2929+-- back propagation implementation is correct.
3030+--
3131+3232+-- | List-only matrix, deriving Functor
3333+data Matrix a = Matrix {
3434+ matrixWeights :: [[a]]
3535+ } deriving (Functor, Foldable, Traversable, Eq, Show)
3636+3737+-- | List-only vector, deriving Functor
3838+data Vector a = Vector {
3939+ vectorWeights :: [a]
4040+ } deriving (Functor, Foldable, Traversable, Eq, Show)
4141+4242+-- | List only LSTM weights
4343+data RefLSTM a = RefLSTM
4444+ { refLstmWf :: Matrix a -- Weight Forget (W_f)
4545+ , refLstmUf :: Matrix a -- Cell State Forget (U_f)
4646+ , refLstmBf :: Vector a -- Bias Forget (b_f)
4747+ , refLstmWi :: Matrix a -- Weight Input (W_i)
4848+ , refLstmUi :: Matrix a -- Cell State Input (U_i)
4949+ , refLstmBi :: Vector a -- Bias Input (b_i)
5050+ , refLstmWo :: Matrix a -- Weight Output (W_o)
5151+ , refLstmUo :: Matrix a -- Cell State Output (U_o)
5252+ , refLstmBo :: Vector a -- Bias Output (b_o)
5353+ , refLstmWc :: Matrix a -- Weight Cell (W_c)
5454+ , refLstmBc :: Vector a -- Bias Cell (b_c)
5555+ } deriving (Functor, Foldable, Traversable, Eq, Show)
5656+5757+lstmToReference :: LSTM.LSTMWeights a b -> RefLSTM Double
5858+lstmToReference LSTM.LSTMWeights {..} =
5959+ let refLstmWf = Matrix . H.toLists . S.extract $ lstmWf -- Weight Forget (W_f)
6060+ refLstmUf = Matrix . H.toLists . S.extract $ lstmUf -- Cell State Forget (U_f)
6161+ refLstmBf = Vector . H.toList . S.extract $ lstmBf -- Bias Forget (b_f)
6262+ refLstmWi = Matrix . H.toLists . S.extract $ lstmWi -- Weight Input (W_i)
6363+ refLstmUi = Matrix . H.toLists . S.extract $ lstmUi -- Cell State Input (U_i)
6464+ refLstmBi = Vector . H.toList . S.extract $ lstmBi -- Bias Input (b_i)
6565+ refLstmWo = Matrix . H.toLists . S.extract $ lstmWo -- Weight Output (W_o)
6666+ refLstmUo = Matrix . H.toLists . S.extract $ lstmUo -- Cell State Output (U_o)
6767+ refLstmBo = Vector . H.toList . S.extract $ lstmBo -- Bias Output (b_o)
6868+ refLstmWc = Matrix . H.toLists . S.extract $ lstmWc -- Weight Cell (W_c)
6969+ refLstmBc = Vector . H.toList . S.extract $ lstmBc -- Bias Cell (b_c)
7070+ in RefLSTM {..}
7171+7272+runLSTM :: Floating a => RefLSTM a -> Vector a -> Vector a -> (Vector a, Vector a)
7373+runLSTM RefLSTM {..} cell input =
7474+ let -- Forget state vector
7575+ f_t = sigmoid $ refLstmBf #+ refLstmWf #> input #+ refLstmUf #> cell
7676+ -- Input state vector
7777+ i_t = sigmoid $ refLstmBi #+ refLstmWi #> input #+ refLstmUi #> cell
7878+ -- Output state vector
7979+ o_t = sigmoid $ refLstmBo #+ refLstmWo #> input #+ refLstmUo #> cell
8080+ -- Cell input state vector
8181+ c_x = fmap tanh $ refLstmBc #+ refLstmWc #> input
8282+ -- Cell state
8383+ c_t = f_t #* cell #+ i_t #* c_x
8484+ -- Output (it's sometimes recommended to use tanh c_t)
8585+ h_t = o_t #* c_t
8686+ in (c_t, h_t)
8787+8888+runLSTMback :: forall a. Floating a => Vector a -> Vector a -> RefLSTM a -> RefLSTM a
8989+runLSTMback cell input =
9090+ grad f
9191+ where
9292+ f :: forall s. Reifies s Tape => RefLSTM (Reverse s a) -> Reverse s a
9393+ f net =
9494+ let cell' = fmap auto cell
9595+ input' = fmap auto input
9696+ (cells, forwarded) = runLSTM net cell' input'
9797+ in sum forwarded + sum cells
9898+9999+runLSTMbackOnInput :: forall a. Floating a => Vector a -> RefLSTM a -> Vector a -> Vector a
100100+runLSTMbackOnInput cell net =
101101+ grad f
102102+ where
103103+ f :: forall s. Reifies s Tape => Vector (Reverse s a) -> Reverse s a
104104+ f input =
105105+ let cell' = fmap auto cell
106106+ net' = fmap auto net
107107+ (cells, forwarded) = runLSTM net' cell' input
108108+ in sum forwarded + sum cells
109109+110110+runLSTMbackOnCell :: forall a. Floating a => Vector a -> RefLSTM a -> Vector a -> Vector a
111111+runLSTMbackOnCell input net =
112112+ grad f
113113+ where
114114+ f :: forall s. Reifies s Tape => Vector (Reverse s a) -> Reverse s a
115115+ f cell =
116116+ let input' = fmap auto input
117117+ net' = fmap auto net
118118+ (cells, forwarded) = runLSTM net' cell input'
119119+ in sum forwarded + sum cells
120120+121121+-- | Helper to multiply a matrix by a vector
122122+matMult :: Num a => Matrix a -> Vector a -> Vector a
123123+matMult (Matrix m) (Vector v) = Vector result
124124+ where
125125+ lrs = map length m
126126+ l = length v
127127+ result = if all (== l) lrs
128128+ then map (\r -> sum $ zipWith (*) r v) m
129129+ else error $ "Matrix has rows of length " ++ show lrs ++
130130+ " but vector is of length " ++ show l
131131+132132+(#>) :: Num a => Matrix a -> Vector a -> Vector a
133133+(#>) = matMult
134134+infixr 8 #>
135135+136136+(#+) :: Num a => Vector a -> Vector a -> Vector a
137137+(#+) (Vector as) (Vector bs) = Vector $ zipWith (+) as bs
138138+infixl 6 #+
139139+140140+(#-) :: Num a => Vector a -> Vector a -> Vector a
141141+(#-) (Vector as) (Vector bs) = Vector $ zipWith (-) as bs
142142+infixl 6 #-
143143+144144+(#*) :: Num a => Vector a -> Vector a -> Vector a
145145+(#*) (Vector as) (Vector bs) = Vector $ zipWith (*) as bs
146146+infixl 7 #*
147147+148148+sigmoid :: (Functor f, Floating a) => f a -> f a
149149+sigmoid xs = (\x -> 1 / (1 + exp (-x))) <$> xs
+2-5
test/Test/Jack/Hmatrix.hs
···4455module Test.Jack.Hmatrix where
6677-import Data.Proxy
87import Disorder.Jack
98109import GHC.TypeLits
···1211import qualified Numeric.LinearAlgebra.Static as HStatic
13121413randomVector :: forall n. KnownNat n => Jack (HStatic.R n)
1515-randomVector = HStatic.fromList <$> vectorOf (fromInteger (natVal (Proxy :: Proxy n))) sizedRealFrac
1414+randomVector = (\s -> HStatic.randomVector s HStatic.Uniform * 2 - 1) <$> sizedNat
16151716uniformSample :: forall m n. (KnownNat m, KnownNat n) => Jack (HStatic.L m n)
1818-uniformSample = HStatic.fromList
1919- <$> vectorOf (fromInteger (natVal (Proxy :: Proxy m)) * fromInteger (natVal (Proxy :: Proxy n)))
2020- sizedRealFrac
1717+uniformSample = (\s -> HStatic.uniformSample s (-1) 1 ) <$> sizedNat