···11-Copyright (c) 2016, Huw Campbell
11+Copyright (c) 2016-2017, Huw Campbell
22All rights reserved.
3344Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
···11+Copyright (c) 2016-2017, Huw Campbell
22+All rights reserved.
33+44+Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
55+66+1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
77+88+2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
99+1010+THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
···34343535-- With the mnist data from Kaggle normalised to doubles between 0 and 1, learning rate of 0.01 and 15 iterations,
3636-- this network should get down to about a 1.3% error rate.
3737-type MNIST = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
3838- '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
3737+--
3838+-- /NOTE:/ This model is actually too complex for MNIST, and one should use the type given in the readme instead.
3939+-- This one is just here to demonstrate Inception layers in use.
4040+--
4141+type MNIST =
4242+ Network
4343+ '[ Reshape
4444+ , Inception 28 28 1 5 5 5, Pooling 2 2 2 2, Relu
4545+ , Inception 14 14 15 5 5 5, Pooling 2 2 2 2, Relu
4646+ , Reshape
4747+ , FullyConnected 735 80, Logit
4848+ , FullyConnected 80 10, Logit]
4949+ '[ 'D2 28 28, 'D3 28 28 1
5050+ , 'D3 28 28 15, 'D3 14 14 15, 'D3 14 14 15
5151+ , 'D3 14 14 15, 'D3 7 7 15, 'D3 7 7 15
5252+ , 'D1 735
5353+ , 'D1 80, 'D1 80
5454+ , 'D1 10, 'D1 10]
39554056randomMnist :: MonadRandom m => m MNIST
4157randomMnist = randomNetwork
main/recurrent.hs
examples/main/recurrent.hs
main/shakespeare.hs
examples/main/shakespeare.hs
+49-20
src/Grenade.hs
···11module Grenade (
22- module X
22+ -- | This is an empty module which simply re-exports public definitions
33+ -- for machine learning with Grenade.
44+55+ -- * Exported modules
66+ --
77+ -- | The core types and runners for Grenade.
88+ module Grenade.Core
99+1010+ -- | The neural network layer zoo
1111+ , module Grenade.Layers
1212+1313+1414+ -- * Overview of the library
1515+ -- $library
1616+1717+ -- * Example usage
1818+ -- $example
1919+320 ) where
42155-import Grenade.Core.LearningParameters as X
66-import Grenade.Core.Layer as X
77-import Grenade.Core.Network as X
88-import Grenade.Core.Runner as X
99-import Grenade.Core.Shape as X
1010-import Grenade.Layers.Concat as X
1111-import Grenade.Layers.Crop as X
1212-import Grenade.Layers.Dropout as X
1313-import Grenade.Layers.Pad as X
1414-import Grenade.Layers.Pooling as X
1515-import Grenade.Layers.Reshape as X
1616-import Grenade.Layers.FullyConnected as X
1717-import Grenade.Layers.Logit as X
1818-import Grenade.Layers.Merge as X
1919-import Grenade.Layers.Convolution as X
2020-import Grenade.Layers.Relu as X
2121-import Grenade.Layers.Elu as X
2222-import Grenade.Layers.Tanh as X
2323-import Grenade.Layers.Softmax as X
2222+import Grenade.Core
2323+import Grenade.Layers
2424+2525+{- $library
2626+Grenade is a purely functional deep learning library.
2727+2828+It provides an expressive type level API for the construction
2929+of complex neural network architectures. Backing this API is and
3030+implementation written using BLAS and LAPACK, mostly provided by
3131+the hmatrix library.
3232+-}
3333+3434+{- $example
3535+A few examples are provided at https://github.com/HuwCampbell/grenade
3636+under the examples folder.
3737+3838+The starting place is to write your neural network type and a
3939+function to create a random layer of that type. The following
4040+is a simple example which runs a logistic regression.
4141+4242+> type MyNet = Network '[ FullyConnected 10 1, Logit ] '[ 'D1 10, 'D1 1, 'D1 1 ]
4343+>
4444+> randomMyNet :: MonadRandom MyNet
4545+> randomMyNet = randomNetwork
4646+4747+The function `randomMyNet` witnesses the `CreatableNetwork`
4848+constraint of the neural network, that is it ensures the network
4949+can be built, and hence, that the architecture is sound.
5050+-}
5151+5252+
+10-5
src/Grenade/Core.hs
···11module Grenade.Core (
22- module X
22+ module Grenade.Core.Layer
33+ , module Grenade.Core.LearningParameters
44+ , module Grenade.Core.Network
55+ , module Grenade.Core.Runner
66+ , module Grenade.Core.Shape
37 ) where
4855-import Grenade.Core.Layer as X
66-import Grenade.Core.LearningParameters as X
77-import Grenade.Core.Shape as X
88-import Grenade.Core.Network as X
99+import Grenade.Core.Layer
1010+import Grenade.Core.LearningParameters
1111+import Grenade.Core.Network
1212+import Grenade.Core.Runner
1313+import Grenade.Core.Shape
+4
src/Grenade/Core/LearningParameters.hs
···11module Grenade.Core.LearningParameters (
22+ -- | This module contains learning algorithm specific
33+ -- code. Currently, this module should be consifered
44+ -- unstable, due to issue #26.
55+26 LearningParameters (..)
37 ) where
48
+2-7
src/Grenade/Core/Network.hs
···11-{-# LANGUAGE CPP #-}
21{-# LANGUAGE DataKinds #-}
32{-# LANGUAGE BangPatterns #-}
43{-# LANGUAGE GADTs #-}
···1817This module defines the core data types and functions
1918for non-recurrent neural networks.
2019-}
2121-2222-#if __GLASGOW_HASKELL__ < 800
2323-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
2424-#endif
25202621module Grenade.Core.Network (
2722 Network (..)
···47424843-- | Type of a network.
4944--
5050--- The [*] type specifies the types of the layers.
4545+-- The @[*]@ type specifies the types of the layers.
5146--
5252--- The [Shape] type specifies the shapes of data passed between the layers.
4747+-- The @[Shape]@ type specifies the shapes of data passed between the layers.
5348--
5449-- Can be considered to be a heterogeneous list of layers which are able to
5550-- transform the data shapes of the network.
-9
src/Grenade/Core/Shape.hs
···11-{-# LANGUAGE CPP #-}
21{-# LANGUAGE DataKinds #-}
32{-# LANGUAGE GADTs #-}
43{-# LANGUAGE KindSignatures #-}
···87{-# LANGUAGE FlexibleContexts #-}
98{-# LANGUAGE ScopedTypeVariables #-}
109{-# LANGUAGE RankNTypes #-}
1111-1212--- Ghc 7.10 fails to recognise n2 is complete.
1313-#if __GLASGOW_HASKELL__ < 800
1414-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
1515-#endif
1610{-|
1711Module : Grenade.Core.Shape
1812Description : Core definition of the Shapes of data we understand
···6559-- All shapes are held in contiguous memory.
6660-- 3D is held in a matrix (usually row oriented) which has height depth * rows.
6761data S (n :: Shape) where
6868- -- | One dimensional data
6962 S1D :: ( KnownNat len )
7063 => R len
7164 -> S ('D1 len)
72657373- -- | Two dimensional data
7466 S2D :: ( KnownNat rows, KnownNat columns )
7567 => L rows columns
7668 -> S ('D2 rows columns)
77697878- -- | Three dimensional data
7970 S3D :: ( KnownNat rows
8071 , KnownNat columns
8172 , KnownNat depth
···99{-# LANGUAGE ScopedTypeVariables #-}
1010{-# LANGUAGE StandaloneDeriving #-}
1111{-|
1212-Module : Grenade.Core.Network
1313-Description : Core definition a simple neural etwork
1212+Module : Grenade.Layers.Concat
1313+Description : Concatenation layer
1414Copyright : (c) Huw Campbell, 2016-2017
1515License : BSD2
1616Stability : experimental
1717+1818+This module provides the concatenation layer, whic used to run two separate layers in parallel and combine their outputs.
1719-}
1820module Grenade.Layers.Concat (
1921 Concat (..)
+26-12
src/Grenade/Layers/Inception.hs
···99{-# LANGUAGE ScopedTypeVariables #-}
1010{-|
1111Module : Grenade.Core.Network
1212-Description : Core definition a simple neural etwork
1212+Description : Inception style parallel convolutional network composition.
1313Copyright : (c) Huw Campbell, 2016-2017
1414License : BSD2
1515Stability : experimental
1616+1717+Export an Inception style type, which can be used to build up
1818+complex multiconvolution size networks.
1619-}
1720module Grenade.Layers.Inception (
1821 Inception
···2528import Grenade.Layers.Pad
2629import Grenade.Layers.Concat
27302828-3131+-- | Type of an inception layer.
3232+--
3333+-- It looks like a bit of a handful, but is actually pretty easy to use.
3434+--
3535+-- The first three type parameters are the size of the (3D) data the
3636+-- inception layer will take. It will emit 3D data with the number of
3737+-- channels being the sum of @chx@, @chy@, @chz@, which are the number
3838+-- of convolution filters in the 3x3, 5x5, and 7x7 convolutions Layers
3939+-- respectively.
4040+--
4141+-- The network get padded effectively before each convolution filters
4242+-- such that the output dimension is the same x and y as the input.
2943type Inception rows cols channels chx chy chz
3030- = Network '[ Concat ('D3 (rows - 2) (cols - 2) (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 (rows - 2) (cols - 2) chz) (Inception7x7 rows cols channels chz) ]
3131- '[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy + chz) ]
4444+ = Network '[ Concat ('D3 rows cols (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 rows cols chz) (Inception7x7 rows cols channels chz) ]
4545+ '[ 'D3 rows cols channels, 'D3 rows cols (chx + chy + chz) ]
32463347type InceptionS rows cols channels chx chy
3434- = Network '[ Concat ('D3 (rows - 2) (cols - 2) chx) (Inception3x3 rows cols channels chx) ('D3 (rows - 2) (cols - 2) chy) (Inception5x5 rows cols channels chy) ]
3535- '[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy) ]
4848+ = Network '[ Concat ('D3 rows cols chx) (Inception3x3 rows cols channels chx) ('D3 rows cols chy) (Inception5x5 rows cols channels chy) ]
4949+ '[ 'D3 rows cols channels, 'D3 rows cols (chx + chy) ]
36503751type Inception3x3 rows cols channels chx
3838- = Network '[ Convolution channels chx 3 3 1 1 ]
3939- '[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) chx ]
5252+ = Network '[ Pad 1 1 1 1, Convolution channels chx 3 3 1 1 ]
5353+ '[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 rows cols chx ]
40544155type Inception5x5 rows cols channels chx
4242- = Network '[ Pad 1 1 1 1, Convolution channels chx 5 5 1 1 ]
4343- '[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 (rows - 2) (cols - 2) chx ]
5656+ = Network '[ Pad 2 2 2 2, Convolution channels chx 5 5 1 1 ]
5757+ '[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 rows cols chx ]
44584559type Inception7x7 rows cols channels chx
4646- = Network '[ Pad 2 2 2 2, Convolution channels chx 7 7 1 1 ]
4747- '[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 (rows - 2) (cols - 2) chx ]
6060+ = Network '[ Pad 3 3 3 3, Convolution channels chx 7 7 1 1 ]
6161+ '[ 'D3 rows cols channels, 'D3 (rows + 6) (cols + 6) channels, 'D3 rows cols chx ]
4862
+3-2
src/Grenade/Layers/Trivial.hs
···991010import Data.Serialize
11111212-import Grenade.Core.Network
1212+import Grenade.Core
13131414-- | A trivial layer.
1515data Trivial = Trivial
···2525 createRandom = return Trivial
26262727instance (a ~ b) => Layer Trivial a b where
2828- runForwards _ = id
2828+ type Tape Trivial a b = ()
2929+ runForwards _ a = ((), a)
2930 runBackwards _ _ y = ((), y)
+4-6
src/Grenade/Recurrent.hs
···11module Grenade.Recurrent (
22- module X
22+ module Grenade.Recurrent.Core
33+ , module Grenade.Recurrent.Layers
34 ) where
4555-import Grenade.Recurrent.Core.Layer as X
66-import Grenade.Recurrent.Core.Network as X
77-import Grenade.Recurrent.Core.Runner as X
88-import Grenade.Recurrent.Layers.BasicRecurrent as X
99-import Grenade.Recurrent.Layers.LSTM as X
66+import Grenade.Recurrent.Core
77+import Grenade.Recurrent.Layers
+6-3
src/Grenade/Recurrent/Core.hs
···11module Grenade.Recurrent.Core (
22- module X
22+ module Grenade.Recurrent.Core.Layer
33+ , module Grenade.Recurrent.Core.Network
44+ , module Grenade.Recurrent.Core.Runner
35 ) where
4655-import Grenade.Recurrent.Core.Layer as X
66-import Grenade.Recurrent.Core.Network as X
77+import Grenade.Recurrent.Core.Layer
88+import Grenade.Recurrent.Core.Network
99+import Grenade.Recurrent.Core.Runner
+1-5
src/Grenade/Recurrent/Core/Network.hs
···11-{-# LANGUAGE CPP #-}
21{-# LANGUAGE DataKinds #-}
32{-# LANGUAGE GADTs #-}
43{-# LANGUAGE TypeOperators #-}
···109{-# LANGUAGE RankNTypes #-}
1110{-# LANGUAGE BangPatterns #-}
1211{-# LANGUAGE ScopedTypeVariables #-}
1313-1414-#if __GLASGOW_HASKELL__ < 800
1515-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
1616-#endif
1212+{-# LANGUAGE UndecidableInstances #-}
17131814module Grenade.Recurrent.Core.Network (
1915 Recurrent
-5
src/Grenade/Recurrent/Core/Runner.hs
···33{-# LANGUAGE DataKinds #-}
44{-# LANGUAGE ScopedTypeVariables #-}
55{-# LANGUAGE TypeOperators #-}
66-{-# LANGUAGE CPP #-}
76{-# LANGUAGE TypeFamilies #-}
87{-# LANGUAGE FlexibleContexts #-}
98{-# LANGUAGE RankNTypes #-}
109{-# LANGUAGE RecordWildCards #-}
1111-1212-#if __GLASGOW_HASKELL__ < 800
1313-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
1414-#endif
15101611module Grenade.Recurrent.Core.Runner (
1712 trainRecurrent
···11-{-# LANGUAGE CPP #-}
21{-# LANGUAGE DataKinds #-}
32{-# LANGUAGE GADTs #-}
43{-# LANGUAGE RecordWildCards #-}
···87{-# LANGUAGE FlexibleContexts #-}
98{-# LANGUAGE UndecidableInstances #-}
1091111--- GHC 7.10 doesn't see recurrent run functions as total.
1212-#if __GLASGOW_HASKELL__ < 800
1313-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
1414-#endif
1510module Grenade.Recurrent.Layers.BasicRecurrent (
1611 BasicRecurrent (..)
1712 , randomBasicRecurrent
-6
src/Grenade/Recurrent/Layers/LSTM.hs
···11{-# LANGUAGE BangPatterns #-}
22-{-# LANGUAGE CPP #-}
32{-# LANGUAGE DataKinds #-}
43{-# LANGUAGE GADTs #-}
54{-# LANGUAGE RankNTypes #-}
···109{-# LANGUAGE FlexibleContexts #-}
1110{-# LANGUAGE ViewPatterns #-}
1211{-# LANGUAGE ScopedTypeVariables #-}
1313-1414--- GHC 7.10 doesn't see recurrent run functions as total.
1515-#if __GLASGOW_HASKELL__ < 800
1616-{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
1717-#endif
18121913module Grenade.Recurrent.Layers.LSTM (
2014 LSTM (..)
+9
test/Test/Grenade/Layers/PadCrop.hs
···3030 (_ , grad) = runBackwards net tapes d
3131 in d ~~~ res .&&. grad ~~~ d
32323333+prop_pad_crop_2d :: Property
3434+prop_pad_crop_2d =
3535+ let net :: Network '[Pad 2 3 4 6, Crop 2 3 4 6] '[ 'D2 7 9, 'D2 16 15, 'D2 7 9 ]
3636+ net = Pad :~> Crop :~> NNil
3737+ in gamble genOfShape $ \(d :: S ('D2 7 9)) ->
3838+ let (tapes, res) = runForwards net d
3939+ (_ , grad) = runBackwards net tapes d
4040+ in d ~~~ res .&&. grad ~~~ d
4141+3342(~~~) :: S x -> S x -> Bool
3443(S1D x) ~~~ (S1D y) = norm_Inf (x - y) < 0.00001
3544(S2D x) ~~~ (S2D y) = norm_Inf (x - y) < 0.00001