💣 Machine learning which might blow up in your face 💣

Cleanup imports and move examples to new project

+1 -1
LICENSE
··· 1 - Copyright (c) 2016, Huw Campbell 1 + Copyright (c) 2016-2017, Huw Campbell 2 2 All rights reserved. 3 3 4 4 Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
+6 -2
README.md
··· 20 20 ```haskell 21 21 type MNIST 22 22 = Network 23 - '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit] 24 - '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10] 23 + '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu 24 + , Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu 25 + , FullyConnected 256 80, Logit, FullyConnected 80 10, Logit] 26 + '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10 27 + , 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256 28 + , 'D1 80, 'D1 80, 'D1 10, 'D1 10] 25 29 26 30 randomMnist :: MonadRandom m => m MNIST 27 31 randomMnist = randomNetwork
+10
examples/LICENSE
··· 1 + Copyright (c) 2016-2017, Huw Campbell 2 + All rights reserved. 3 + 4 + Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met: 5 + 6 + 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer. 7 + 8 + 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution. 9 + 10 + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+108
examples/grenade-examples.cabal
··· 1 + name: grenade-examples 2 + version: 0.0.1 3 + license: BSD2 4 + license-file: LICENSE 5 + author: Huw Campbell <huw.campbell@gmail.com> 6 + maintainer: Huw Campbell <huw.campbell@gmail.com> 7 + copyright: (c) 2016-2017 Huw Campbell. 8 + synopsis: grenade-examples 9 + category: System 10 + cabal-version: >= 1.8 11 + build-type: Simple 12 + description: grenade-examples 13 + 14 + source-repository head 15 + type: git 16 + location: https://github.com/HuwCampbell/grenade.git 17 + 18 + library 19 + 20 + executable feedforward 21 + ghc-options: -Wall -threaded -O2 22 + main-is: main/feedforward.hs 23 + build-depends: base 24 + , grenade 25 + , attoparsec 26 + , bytestring 27 + , cereal 28 + , either 29 + , optparse-applicative == 0.13.* 30 + , text == 1.2.* 31 + , mtl >= 2.2.1 && < 2.3 32 + , hmatrix 33 + , transformers 34 + , singletons 35 + , semigroups 36 + , MonadRandom 37 + 38 + executable mnist 39 + ghc-options: -Wall -threaded -O2 40 + main-is: main/mnist.hs 41 + build-depends: base 42 + , grenade 43 + , attoparsec 44 + , either 45 + , optparse-applicative == 0.13.* 46 + , text == 1.2.* 47 + , mtl >= 2.2.1 && < 2.3 48 + , hmatrix >= 0.18 && < 0.19 49 + , transformers 50 + , semigroups 51 + , singletons 52 + , MonadRandom 53 + , vector 54 + 55 + executable gan-mnist 56 + ghc-options: -Wall -threaded -O2 57 + main-is: main/gan-mnist.hs 58 + build-depends: base 59 + , grenade 60 + , attoparsec 61 + , bytestring 62 + , cereal 63 + , either 64 + , optparse-applicative == 0.13.* 65 + , text == 1.2.* 66 + , mtl >= 2.2.1 && < 2.3 67 + , hmatrix >= 0.18 && < 0.19 68 + , transformers 69 + , semigroups 70 + , singletons 71 + , MonadRandom 72 + , vector 73 + 74 + executable recurrent 75 + ghc-options: -Wall -threaded -O2 76 + main-is: main/recurrent.hs 77 + build-depends: base 78 + , grenade 79 + , attoparsec 80 + , either 81 + , optparse-applicative == 0.13.* 82 + , text == 1.2.* 83 + , mtl >= 2.2.1 && < 2.3 84 + , hmatrix >= 0.18 && < 0.19 85 + , transformers 86 + , semigroups 87 + , singletons 88 + , MonadRandom 89 + 90 + executable shakespeare 91 + ghc-options: -Wall -threaded -O2 92 + main-is: main/shakespeare.hs 93 + build-depends: base 94 + , grenade 95 + , attoparsec 96 + , bytestring 97 + , cereal 98 + , either 99 + , optparse-applicative == 0.13.* 100 + , text == 1.2.* 101 + , mtl >= 2.2.1 && < 2.3 102 + , hmatrix >= 0.18 && < 0.19 103 + , transformers 104 + , semigroups 105 + , singletons 106 + , vector 107 + , MonadRandom 108 + , containers
+36 -121
grenade.cabal
··· 4 4 license-file: LICENSE 5 5 author: Huw Campbell <huw.campbell@gmail.com> 6 6 maintainer: Huw Campbell <huw.campbell@gmail.com> 7 - copyright: (c) 2015 Huw Campbell. 7 + copyright: (c) 2016-2017 Huw Campbell. 8 8 synopsis: grenade 9 9 category: System 10 10 cabal-version: >= 1.8 ··· 12 12 description: grenade. 13 13 14 14 extra-source-files: 15 - cbits/im2col.h 16 - cbits/im2col.c 17 - cbits/gradient_decent.h 18 - cbits/gradient_decent.c 19 - cbits/pad.h 20 - cbits/pad.c 15 + cbits/im2col.h 16 + cbits/im2col.c 17 + cbits/gradient_decent.h 18 + cbits/gradient_decent.c 19 + cbits/pad.h 20 + cbits/pad.c 21 21 22 22 source-repository head 23 23 type: git ··· 27 27 build-depends: 28 28 base >= 4.8 && < 5 29 29 , bytestring == 0.10.* 30 - , containers 31 - , deepseq 32 - , either == 4.4.* 33 - , cereal 30 + , containers >= 0.5 && < 0.6 31 + , cereal >= 0.5 && < 0.6 32 + , deepseq >= 1.4 && < 1.5 34 33 , exceptions == 0.8.* 35 34 , hmatrix == 0.18.* 36 - , MonadRandom 37 - , mtl >= 2.2.1 && < 2.3 38 - , primitive 35 + , MonadRandom >= 0.4 && < 0.6 36 + , mtl >= 2.2.1 && < 2.3 37 + , primitive >= 0.6 && < 0.7 39 38 , text == 1.2.* 40 - , transformers 41 - , singletons >= 2.1 && < 2.3 42 - , vector == 0.11.* 39 + , singletons >= 2.1 && < 2.3 40 + , vector >= 0.11 && < 0.13 43 41 44 42 ghc-options: 45 43 -Wall 46 44 hs-source-dirs: 47 45 src 46 + 47 + if impl(ghc < 8.0) 48 + ghc-options: -fno-warn-incomplete-patterns 48 49 49 50 50 51 exposed-modules: ··· 55 56 Grenade.Core.Network 56 57 Grenade.Core.Runner 57 58 Grenade.Core.Shape 58 - Grenade.Layers.Crop 59 + 60 + Grenade.Layers 59 61 Grenade.Layers.Concat 60 62 Grenade.Layers.Convolution 63 + Grenade.Layers.Crop 61 64 Grenade.Layers.Dropout 65 + Grenade.Layers.Elu 62 66 Grenade.Layers.FullyConnected 63 - Grenade.Layers.Reshape 67 + Grenade.Layers.Inception 64 68 Grenade.Layers.Logit 65 69 Grenade.Layers.Merge 66 - Grenade.Layers.Relu 67 - Grenade.Layers.Elu 68 - Grenade.Layers.Tanh 69 70 Grenade.Layers.Pad 70 71 Grenade.Layers.Pooling 72 + Grenade.Layers.Relu 73 + Grenade.Layers.Reshape 74 + Grenade.Layers.Softmax 75 + Grenade.Layers.Tanh 76 + Grenade.Layers.Trivial 71 77 72 78 Grenade.Layers.Internal.Convolution 73 79 Grenade.Layers.Internal.Pad ··· 81 87 Grenade.Recurrent.Core.Network 82 88 Grenade.Recurrent.Core.Runner 83 89 90 + Grenade.Recurrent.Layers 84 91 Grenade.Recurrent.Layers.BasicRecurrent 85 92 Grenade.Recurrent.Layers.LSTM 86 93 87 94 Grenade.Utils.OneHot 88 95 89 - includes: cbits/im2col.h 90 - cbits/gradient_decent.h 91 - cbits/pad.h 92 - c-sources: cbits/im2col.c 93 - cbits/gradient_decent.c 94 - cbits/pad.c 95 - 96 - cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1 96 + includes: cbits/im2col.h 97 + cbits/gradient_decent.h 98 + cbits/pad.h 99 + c-sources: cbits/im2col.c 100 + cbits/gradient_decent.c 101 + cbits/pad.c 97 102 98 - executable feedforward 99 - ghc-options: -Wall -threaded -O2 100 - main-is: main/feedforward.hs 101 - build-depends: base 102 - , grenade 103 - , attoparsec 104 - , bytestring 105 - , cereal 106 - , either 107 - , optparse-applicative == 0.13.* 108 - , text == 1.2.* 109 - , mtl >= 2.2.1 && < 2.3 110 - , hmatrix 111 - , transformers 112 - , singletons 113 - , semigroups 114 - , MonadRandom 115 - 116 - executable mnist 117 - ghc-options: -Wall -threaded -O2 118 - main-is: main/mnist.hs 119 - build-depends: base 120 - , grenade 121 - , attoparsec 122 - , either 123 - , optparse-applicative == 0.13.* 124 - , text == 1.2.* 125 - , mtl >= 2.2.1 && < 2.3 126 - , hmatrix >= 0.18 && < 0.19 127 - , transformers 128 - , semigroups 129 - , singletons 130 - , MonadRandom 131 - , vector 132 - 133 - executable gan-mnist 134 - ghc-options: -Wall -threaded -O2 135 - main-is: main/gan-mnist.hs 136 - build-depends: base 137 - , grenade 138 - , attoparsec 139 - , bytestring 140 - , cereal 141 - , either 142 - , optparse-applicative == 0.13.* 143 - , text == 1.2.* 144 - , mtl >= 2.2.1 && < 2.3 145 - , hmatrix >= 0.18 && < 0.19 146 - , transformers 147 - , semigroups 148 - , singletons 149 - , MonadRandom 150 - , vector 151 - 152 - executable recurrent 153 - ghc-options: -Wall -threaded -O2 154 - main-is: main/recurrent.hs 155 - build-depends: base 156 - , grenade 157 - , attoparsec 158 - , either 159 - , optparse-applicative == 0.13.* 160 - , text == 1.2.* 161 - , mtl >= 2.2.1 && < 2.3 162 - , hmatrix >= 0.18 && < 0.19 163 - , transformers 164 - , semigroups 165 - , singletons 166 - , MonadRandom 167 - 168 - 169 - executable shakespeare 170 - ghc-options: -Wall -threaded -O2 171 - main-is: main/shakespeare.hs 172 - build-depends: base 173 - , grenade 174 - , attoparsec 175 - , bytestring 176 - , cereal 177 - , either 178 - , optparse-applicative == 0.13.* 179 - , text == 1.2.* 180 - , mtl >= 2.2.1 && < 2.3 181 - , hmatrix >= 0.18 && < 0.19 182 - , transformers 183 - , semigroups 184 - , singletons 185 - , vector 186 - , MonadRandom 187 - , containers 188 - 103 + cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1 189 104 190 105 test-suite test 191 106 type: exitcode-stdio-1.0
main/feedforward.hs examples/main/feedforward.hs
+4 -3
main/gan-mnist.hs examples/main/gan-mnist.hs
··· 58 58 import Grenade.Utils.OneHot 59 59 60 60 type Discriminator = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 1, Logit] 61 - '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 1, 'D1 1] 61 + '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 1, 'D1 1] 62 62 63 63 type Generator = Network '[ FullyConnected 100 10240, Relu, Reshape, Convolution 10 10 5 5 1 1, Relu, Convolution 10 1 1 1 1 1, Logit, Reshape] 64 64 '[ 'D1 100, 'D1 10240, 'D1 10240, 'D3 32 32 10, 'D3 28 28 10, 'D3 28 28 10, 'D3 28 28 1, 'D3 28 28 1, 'D2 28 28 ] ··· 77 77 (discriminatorTapeFake, guessFake) = runNetwork discriminator fakeExample 78 78 79 79 (discriminator'real, _) = runGradient discriminator discriminatorTapeReal ( guessReal - 1 ) 80 - (discriminator'fake, push) = runGradient discriminator discriminatorTapeFake guessFake 80 + (discriminator'fake, _) = runGradient discriminator discriminatorTapeFake guessFake 81 + (_, push) = runGradient discriminator discriminatorTapeFake ( guessFake - 1) 81 82 82 - (generator', _) = runGradient generator generatorTape (-push) 83 + (generator', _) = runGradient generator generatorTape push 83 84 84 85 newDiscriminator = foldl' (applyUpdate rate { learningRegulariser = learningRegulariser rate * 10}) discriminator [ discriminator'real, discriminator'fake ] 85 86 newGenerator = applyUpdate rate generator generator'
+18 -2
main/mnist.hs examples/main/mnist.hs
··· 34 34 35 35 -- With the mnist data from Kaggle normalised to doubles between 0 and 1, learning rate of 0.01 and 15 iterations, 36 36 -- this network should get down to about a 1.3% error rate. 37 - type MNIST = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit] 38 - '[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10] 37 + -- 38 + -- /NOTE:/ This model is actually too complex for MNIST, and one should use the type given in the readme instead. 39 + -- This one is just here to demonstrate Inception layers in use. 40 + -- 41 + type MNIST = 42 + Network 43 + '[ Reshape 44 + , Inception 28 28 1 5 5 5, Pooling 2 2 2 2, Relu 45 + , Inception 14 14 15 5 5 5, Pooling 2 2 2 2, Relu 46 + , Reshape 47 + , FullyConnected 735 80, Logit 48 + , FullyConnected 80 10, Logit] 49 + '[ 'D2 28 28, 'D3 28 28 1 50 + , 'D3 28 28 15, 'D3 14 14 15, 'D3 14 14 15 51 + , 'D3 14 14 15, 'D3 7 7 15, 'D3 7 7 15 52 + , 'D1 735 53 + , 'D1 80, 'D1 80 54 + , 'D1 10, 'D1 10] 39 55 40 56 randomMnist :: MonadRandom m => m MNIST 41 57 randomMnist = randomNetwork
main/recurrent.hs examples/main/recurrent.hs
main/shakespeare.hs examples/main/shakespeare.hs
+49 -20
src/Grenade.hs
··· 1 1 module Grenade ( 2 - module X 2 + -- | This is an empty module which simply re-exports public definitions 3 + -- for machine learning with Grenade. 4 + 5 + -- * Exported modules 6 + -- 7 + -- | The core types and runners for Grenade. 8 + module Grenade.Core 9 + 10 + -- | The neural network layer zoo 11 + , module Grenade.Layers 12 + 13 + 14 + -- * Overview of the library 15 + -- $library 16 + 17 + -- * Example usage 18 + -- $example 19 + 3 20 ) where 4 21 5 - import Grenade.Core.LearningParameters as X 6 - import Grenade.Core.Layer as X 7 - import Grenade.Core.Network as X 8 - import Grenade.Core.Runner as X 9 - import Grenade.Core.Shape as X 10 - import Grenade.Layers.Concat as X 11 - import Grenade.Layers.Crop as X 12 - import Grenade.Layers.Dropout as X 13 - import Grenade.Layers.Pad as X 14 - import Grenade.Layers.Pooling as X 15 - import Grenade.Layers.Reshape as X 16 - import Grenade.Layers.FullyConnected as X 17 - import Grenade.Layers.Logit as X 18 - import Grenade.Layers.Merge as X 19 - import Grenade.Layers.Convolution as X 20 - import Grenade.Layers.Relu as X 21 - import Grenade.Layers.Elu as X 22 - import Grenade.Layers.Tanh as X 23 - import Grenade.Layers.Softmax as X 22 + import Grenade.Core 23 + import Grenade.Layers 24 + 25 + {- $library 26 + Grenade is a purely functional deep learning library. 27 + 28 + It provides an expressive type level API for the construction 29 + of complex neural network architectures. Backing this API is and 30 + implementation written using BLAS and LAPACK, mostly provided by 31 + the hmatrix library. 32 + -} 33 + 34 + {- $example 35 + A few examples are provided at https://github.com/HuwCampbell/grenade 36 + under the examples folder. 37 + 38 + The starting place is to write your neural network type and a 39 + function to create a random layer of that type. The following 40 + is a simple example which runs a logistic regression. 41 + 42 + > type MyNet = Network '[ FullyConnected 10 1, Logit ] '[ 'D1 10, 'D1 1, 'D1 1 ] 43 + > 44 + > randomMyNet :: MonadRandom MyNet 45 + > randomMyNet = randomNetwork 46 + 47 + The function `randomMyNet` witnesses the `CreatableNetwork` 48 + constraint of the neural network, that is it ensures the network 49 + can be built, and hence, that the architecture is sound. 50 + -} 51 + 52 +
+10 -5
src/Grenade/Core.hs
··· 1 1 module Grenade.Core ( 2 - module X 2 + module Grenade.Core.Layer 3 + , module Grenade.Core.LearningParameters 4 + , module Grenade.Core.Network 5 + , module Grenade.Core.Runner 6 + , module Grenade.Core.Shape 3 7 ) where 4 8 5 - import Grenade.Core.Layer as X 6 - import Grenade.Core.LearningParameters as X 7 - import Grenade.Core.Shape as X 8 - import Grenade.Core.Network as X 9 + import Grenade.Core.Layer 10 + import Grenade.Core.LearningParameters 11 + import Grenade.Core.Network 12 + import Grenade.Core.Runner 13 + import Grenade.Core.Shape
+4
src/Grenade/Core/LearningParameters.hs
··· 1 1 module Grenade.Core.LearningParameters ( 2 + -- | This module contains learning algorithm specific 3 + -- code. Currently, this module should be consifered 4 + -- unstable, due to issue #26. 5 + 2 6 LearningParameters (..) 3 7 ) where 4 8
+2 -7
src/Grenade/Core/Network.hs
··· 1 - {-# LANGUAGE CPP #-} 2 1 {-# LANGUAGE DataKinds #-} 3 2 {-# LANGUAGE BangPatterns #-} 4 3 {-# LANGUAGE GADTs #-} ··· 18 17 This module defines the core data types and functions 19 18 for non-recurrent neural networks. 20 19 -} 21 - 22 - #if __GLASGOW_HASKELL__ < 800 23 - {-# OPTIONS_GHC -fno-warn-incomplete-patterns #-} 24 - #endif 25 20 26 21 module Grenade.Core.Network ( 27 22 Network (..) ··· 47 42 48 43 -- | Type of a network. 49 44 -- 50 - -- The [*] type specifies the types of the layers. 45 + -- The @[*]@ type specifies the types of the layers. 51 46 -- 52 - -- The [Shape] type specifies the shapes of data passed between the layers. 47 + -- The @[Shape]@ type specifies the shapes of data passed between the layers. 53 48 -- 54 49 -- Can be considered to be a heterogeneous list of layers which are able to 55 50 -- transform the data shapes of the network.
-9
src/Grenade/Core/Shape.hs
··· 1 - {-# LANGUAGE CPP #-} 2 1 {-# LANGUAGE DataKinds #-} 3 2 {-# LANGUAGE GADTs #-} 4 3 {-# LANGUAGE KindSignatures #-} ··· 8 7 {-# LANGUAGE FlexibleContexts #-} 9 8 {-# LANGUAGE ScopedTypeVariables #-} 10 9 {-# LANGUAGE RankNTypes #-} 11 - 12 - -- Ghc 7.10 fails to recognise n2 is complete. 13 - #if __GLASGOW_HASKELL__ < 800 14 - {-# OPTIONS_GHC -fno-warn-incomplete-patterns #-} 15 - #endif 16 10 {-| 17 11 Module : Grenade.Core.Shape 18 12 Description : Core definition of the Shapes of data we understand ··· 65 59 -- All shapes are held in contiguous memory. 66 60 -- 3D is held in a matrix (usually row oriented) which has height depth * rows. 67 61 data S (n :: Shape) where 68 - -- | One dimensional data 69 62 S1D :: ( KnownNat len ) 70 63 => R len 71 64 -> S ('D1 len) 72 65 73 - -- | Two dimensional data 74 66 S2D :: ( KnownNat rows, KnownNat columns ) 75 67 => L rows columns 76 68 -> S ('D2 rows columns) 77 69 78 - -- | Three dimensional data 79 70 S3D :: ( KnownNat rows 80 71 , KnownNat columns 81 72 , KnownNat depth
+33
src/Grenade/Layers.hs
··· 1 + module Grenade.Layers ( 2 + module Grenade.Layers.Concat 3 + , module Grenade.Layers.Convolution 4 + , module Grenade.Layers.Crop 5 + , module Grenade.Layers.Elu 6 + , module Grenade.Layers.FullyConnected 7 + , module Grenade.Layers.Inception 8 + , module Grenade.Layers.Logit 9 + , module Grenade.Layers.Merge 10 + , module Grenade.Layers.Pad 11 + , module Grenade.Layers.Pooling 12 + , module Grenade.Layers.Reshape 13 + , module Grenade.Layers.Relu 14 + , module Grenade.Layers.Softmax 15 + , module Grenade.Layers.Tanh 16 + , module Grenade.Layers.Trivial 17 + ) where 18 + 19 + import Grenade.Layers.Concat 20 + import Grenade.Layers.Convolution 21 + import Grenade.Layers.Crop 22 + import Grenade.Layers.Elu 23 + import Grenade.Layers.Pad 24 + import Grenade.Layers.FullyConnected 25 + import Grenade.Layers.Inception 26 + import Grenade.Layers.Logit 27 + import Grenade.Layers.Merge 28 + import Grenade.Layers.Pooling 29 + import Grenade.Layers.Reshape 30 + import Grenade.Layers.Relu 31 + import Grenade.Layers.Softmax 32 + import Grenade.Layers.Tanh 33 + import Grenade.Layers.Trivial
+4 -2
src/Grenade/Layers/Concat.hs
··· 9 9 {-# LANGUAGE ScopedTypeVariables #-} 10 10 {-# LANGUAGE StandaloneDeriving #-} 11 11 {-| 12 - Module : Grenade.Core.Network 13 - Description : Core definition a simple neural etwork 12 + Module : Grenade.Layers.Concat 13 + Description : Concatenation layer 14 14 Copyright : (c) Huw Campbell, 2016-2017 15 15 License : BSD2 16 16 Stability : experimental 17 + 18 + This module provides the concatenation layer, whic used to run two separate layers in parallel and combine their outputs. 17 19 -} 18 20 module Grenade.Layers.Concat ( 19 21 Concat (..)
+26 -12
src/Grenade/Layers/Inception.hs
··· 9 9 {-# LANGUAGE ScopedTypeVariables #-} 10 10 {-| 11 11 Module : Grenade.Core.Network 12 - Description : Core definition a simple neural etwork 12 + Description : Inception style parallel convolutional network composition. 13 13 Copyright : (c) Huw Campbell, 2016-2017 14 14 License : BSD2 15 15 Stability : experimental 16 + 17 + Export an Inception style type, which can be used to build up 18 + complex multiconvolution size networks. 16 19 -} 17 20 module Grenade.Layers.Inception ( 18 21 Inception ··· 25 28 import Grenade.Layers.Pad 26 29 import Grenade.Layers.Concat 27 30 28 - 31 + -- | Type of an inception layer. 32 + -- 33 + -- It looks like a bit of a handful, but is actually pretty easy to use. 34 + -- 35 + -- The first three type parameters are the size of the (3D) data the 36 + -- inception layer will take. It will emit 3D data with the number of 37 + -- channels being the sum of @chx@, @chy@, @chz@, which are the number 38 + -- of convolution filters in the 3x3, 5x5, and 7x7 convolutions Layers 39 + -- respectively. 40 + -- 41 + -- The network get padded effectively before each convolution filters 42 + -- such that the output dimension is the same x and y as the input. 29 43 type Inception rows cols channels chx chy chz 30 - = Network '[ Concat ('D3 (rows - 2) (cols - 2) (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 (rows - 2) (cols - 2) chz) (Inception7x7 rows cols channels chz) ] 31 - '[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy + chz) ] 44 + = Network '[ Concat ('D3 rows cols (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 rows cols chz) (Inception7x7 rows cols channels chz) ] 45 + '[ 'D3 rows cols channels, 'D3 rows cols (chx + chy + chz) ] 32 46 33 47 type InceptionS rows cols channels chx chy 34 - = Network '[ Concat ('D3 (rows - 2) (cols - 2) chx) (Inception3x3 rows cols channels chx) ('D3 (rows - 2) (cols - 2) chy) (Inception5x5 rows cols channels chy) ] 35 - '[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy) ] 48 + = Network '[ Concat ('D3 rows cols chx) (Inception3x3 rows cols channels chx) ('D3 rows cols chy) (Inception5x5 rows cols channels chy) ] 49 + '[ 'D3 rows cols channels, 'D3 rows cols (chx + chy) ] 36 50 37 51 type Inception3x3 rows cols channels chx 38 - = Network '[ Convolution channels chx 3 3 1 1 ] 39 - '[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) chx ] 52 + = Network '[ Pad 1 1 1 1, Convolution channels chx 3 3 1 1 ] 53 + '[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 rows cols chx ] 40 54 41 55 type Inception5x5 rows cols channels chx 42 - = Network '[ Pad 1 1 1 1, Convolution channels chx 5 5 1 1 ] 43 - '[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 (rows - 2) (cols - 2) chx ] 56 + = Network '[ Pad 2 2 2 2, Convolution channels chx 5 5 1 1 ] 57 + '[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 rows cols chx ] 44 58 45 59 type Inception7x7 rows cols channels chx 46 - = Network '[ Pad 2 2 2 2, Convolution channels chx 7 7 1 1 ] 47 - '[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 (rows - 2) (cols - 2) chx ] 60 + = Network '[ Pad 3 3 3 3, Convolution channels chx 7 7 1 1 ] 61 + '[ 'D3 rows cols channels, 'D3 (rows + 6) (cols + 6) channels, 'D3 rows cols chx ] 48 62
+3 -2
src/Grenade/Layers/Trivial.hs
··· 9 9 10 10 import Data.Serialize 11 11 12 - import Grenade.Core.Network 12 + import Grenade.Core 13 13 14 14 -- | A trivial layer. 15 15 data Trivial = Trivial ··· 25 25 createRandom = return Trivial 26 26 27 27 instance (a ~ b) => Layer Trivial a b where 28 - runForwards _ = id 28 + type Tape Trivial a b = () 29 + runForwards _ a = ((), a) 29 30 runBackwards _ _ y = ((), y)
+4 -6
src/Grenade/Recurrent.hs
··· 1 1 module Grenade.Recurrent ( 2 - module X 2 + module Grenade.Recurrent.Core 3 + , module Grenade.Recurrent.Layers 3 4 ) where 4 5 5 - import Grenade.Recurrent.Core.Layer as X 6 - import Grenade.Recurrent.Core.Network as X 7 - import Grenade.Recurrent.Core.Runner as X 8 - import Grenade.Recurrent.Layers.BasicRecurrent as X 9 - import Grenade.Recurrent.Layers.LSTM as X 6 + import Grenade.Recurrent.Core 7 + import Grenade.Recurrent.Layers
+6 -3
src/Grenade/Recurrent/Core.hs
··· 1 1 module Grenade.Recurrent.Core ( 2 - module X 2 + module Grenade.Recurrent.Core.Layer 3 + , module Grenade.Recurrent.Core.Network 4 + , module Grenade.Recurrent.Core.Runner 3 5 ) where 4 6 5 - import Grenade.Recurrent.Core.Layer as X 6 - import Grenade.Recurrent.Core.Network as X 7 + import Grenade.Recurrent.Core.Layer 8 + import Grenade.Recurrent.Core.Network 9 + import Grenade.Recurrent.Core.Runner
+1 -5
src/Grenade/Recurrent/Core/Network.hs
··· 1 - {-# LANGUAGE CPP #-} 2 1 {-# LANGUAGE DataKinds #-} 3 2 {-# LANGUAGE GADTs #-} 4 3 {-# LANGUAGE TypeOperators #-} ··· 10 9 {-# LANGUAGE RankNTypes #-} 11 10 {-# LANGUAGE BangPatterns #-} 12 11 {-# LANGUAGE ScopedTypeVariables #-} 13 - 14 - #if __GLASGOW_HASKELL__ < 800 15 - {-# OPTIONS_GHC -fno-warn-incomplete-patterns #-} 16 - #endif 12 + {-# LANGUAGE UndecidableInstances #-} 17 13 18 14 module Grenade.Recurrent.Core.Network ( 19 15 Recurrent
-5
src/Grenade/Recurrent/Core/Runner.hs
··· 3 3 {-# LANGUAGE DataKinds #-} 4 4 {-# LANGUAGE ScopedTypeVariables #-} 5 5 {-# LANGUAGE TypeOperators #-} 6 - {-# LANGUAGE CPP #-} 7 6 {-# LANGUAGE TypeFamilies #-} 8 7 {-# LANGUAGE FlexibleContexts #-} 9 8 {-# LANGUAGE RankNTypes #-} 10 9 {-# LANGUAGE RecordWildCards #-} 11 - 12 - #if __GLASGOW_HASKELL__ < 800 13 - {-# OPTIONS_GHC -fno-warn-incomplete-patterns #-} 14 - #endif 15 10 16 11 module Grenade.Recurrent.Core.Runner ( 17 12 trainRecurrent
+7
src/Grenade/Recurrent/Layers.hs
··· 1 + module Grenade.Recurrent.Layers ( 2 + module Grenade.Recurrent.Layers.BasicRecurrent 3 + , module Grenade.Recurrent.Layers.LSTM 4 + ) where 5 + 6 + import Grenade.Recurrent.Layers.BasicRecurrent 7 + import Grenade.Recurrent.Layers.LSTM
-5
src/Grenade/Recurrent/Layers/BasicRecurrent.hs
··· 1 - {-# LANGUAGE CPP #-} 2 1 {-# LANGUAGE DataKinds #-} 3 2 {-# LANGUAGE GADTs #-} 4 3 {-# LANGUAGE RecordWildCards #-} ··· 8 7 {-# LANGUAGE FlexibleContexts #-} 9 8 {-# LANGUAGE UndecidableInstances #-} 10 9 11 - -- GHC 7.10 doesn't see recurrent run functions as total. 12 - #if __GLASGOW_HASKELL__ < 800 13 - {-# OPTIONS_GHC -fno-warn-incomplete-patterns #-} 14 - #endif 15 10 module Grenade.Recurrent.Layers.BasicRecurrent ( 16 11 BasicRecurrent (..) 17 12 , randomBasicRecurrent
-6
src/Grenade/Recurrent/Layers/LSTM.hs
··· 1 1 {-# LANGUAGE BangPatterns #-} 2 - {-# LANGUAGE CPP #-} 3 2 {-# LANGUAGE DataKinds #-} 4 3 {-# LANGUAGE GADTs #-} 5 4 {-# LANGUAGE RankNTypes #-} ··· 10 9 {-# LANGUAGE FlexibleContexts #-} 11 10 {-# LANGUAGE ViewPatterns #-} 12 11 {-# LANGUAGE ScopedTypeVariables #-} 13 - 14 - -- GHC 7.10 doesn't see recurrent run functions as total. 15 - #if __GLASGOW_HASKELL__ < 800 16 - {-# OPTIONS_GHC -fno-warn-incomplete-patterns #-} 17 - #endif 18 12 19 13 module Grenade.Recurrent.Layers.LSTM ( 20 14 LSTM (..)
+9
test/Test/Grenade/Layers/PadCrop.hs
··· 30 30 (_ , grad) = runBackwards net tapes d 31 31 in d ~~~ res .&&. grad ~~~ d 32 32 33 + prop_pad_crop_2d :: Property 34 + prop_pad_crop_2d = 35 + let net :: Network '[Pad 2 3 4 6, Crop 2 3 4 6] '[ 'D2 7 9, 'D2 16 15, 'D2 7 9 ] 36 + net = Pad :~> Crop :~> NNil 37 + in gamble genOfShape $ \(d :: S ('D2 7 9)) -> 38 + let (tapes, res) = runForwards net d 39 + (_ , grad) = runBackwards net tapes d 40 + in d ~~~ res .&&. grad ~~~ d 41 + 33 42 (~~~) :: S x -> S x -> Bool 34 43 (S1D x) ~~~ (S1D y) = norm_Inf (x - y) < 0.00001 35 44 (S2D x) ~~~ (S2D y) = norm_Inf (x - y) < 0.00001