+1
-1
LICENSE
+1
-1
LICENSE
+6
-2
README.md
+6
-2
README.md
···
20
20
```haskell
21
21
type MNIST
22
22
= Network
23
-
'[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
24
-
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
23
+
'[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu
24
+
, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, FlattenLayer, Relu
25
+
, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
26
+
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10
27
+
, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256
28
+
, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
25
29
26
30
randomMnist :: MonadRandom m => m MNIST
27
31
randomMnist = randomNetwork
+10
examples/LICENSE
+10
examples/LICENSE
···
1
+
Copyright (c) 2016-2017, Huw Campbell
2
+
All rights reserved.
3
+
4
+
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
5
+
6
+
1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
7
+
8
+
2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
9
+
10
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+108
examples/grenade-examples.cabal
+108
examples/grenade-examples.cabal
···
1
+
name: grenade-examples
2
+
version: 0.0.1
3
+
license: BSD2
4
+
license-file: LICENSE
5
+
author: Huw Campbell <huw.campbell@gmail.com>
6
+
maintainer: Huw Campbell <huw.campbell@gmail.com>
7
+
copyright: (c) 2016-2017 Huw Campbell.
8
+
synopsis: grenade-examples
9
+
category: System
10
+
cabal-version: >= 1.8
11
+
build-type: Simple
12
+
description: grenade-examples
13
+
14
+
source-repository head
15
+
type: git
16
+
location: https://github.com/HuwCampbell/grenade.git
17
+
18
+
library
19
+
20
+
executable feedforward
21
+
ghc-options: -Wall -threaded -O2
22
+
main-is: main/feedforward.hs
23
+
build-depends: base
24
+
, grenade
25
+
, attoparsec
26
+
, bytestring
27
+
, cereal
28
+
, either
29
+
, optparse-applicative == 0.13.*
30
+
, text == 1.2.*
31
+
, mtl >= 2.2.1 && < 2.3
32
+
, hmatrix
33
+
, transformers
34
+
, singletons
35
+
, semigroups
36
+
, MonadRandom
37
+
38
+
executable mnist
39
+
ghc-options: -Wall -threaded -O2
40
+
main-is: main/mnist.hs
41
+
build-depends: base
42
+
, grenade
43
+
, attoparsec
44
+
, either
45
+
, optparse-applicative == 0.13.*
46
+
, text == 1.2.*
47
+
, mtl >= 2.2.1 && < 2.3
48
+
, hmatrix >= 0.18 && < 0.19
49
+
, transformers
50
+
, semigroups
51
+
, singletons
52
+
, MonadRandom
53
+
, vector
54
+
55
+
executable gan-mnist
56
+
ghc-options: -Wall -threaded -O2
57
+
main-is: main/gan-mnist.hs
58
+
build-depends: base
59
+
, grenade
60
+
, attoparsec
61
+
, bytestring
62
+
, cereal
63
+
, either
64
+
, optparse-applicative == 0.13.*
65
+
, text == 1.2.*
66
+
, mtl >= 2.2.1 && < 2.3
67
+
, hmatrix >= 0.18 && < 0.19
68
+
, transformers
69
+
, semigroups
70
+
, singletons
71
+
, MonadRandom
72
+
, vector
73
+
74
+
executable recurrent
75
+
ghc-options: -Wall -threaded -O2
76
+
main-is: main/recurrent.hs
77
+
build-depends: base
78
+
, grenade
79
+
, attoparsec
80
+
, either
81
+
, optparse-applicative == 0.13.*
82
+
, text == 1.2.*
83
+
, mtl >= 2.2.1 && < 2.3
84
+
, hmatrix >= 0.18 && < 0.19
85
+
, transformers
86
+
, semigroups
87
+
, singletons
88
+
, MonadRandom
89
+
90
+
executable shakespeare
91
+
ghc-options: -Wall -threaded -O2
92
+
main-is: main/shakespeare.hs
93
+
build-depends: base
94
+
, grenade
95
+
, attoparsec
96
+
, bytestring
97
+
, cereal
98
+
, either
99
+
, optparse-applicative == 0.13.*
100
+
, text == 1.2.*
101
+
, mtl >= 2.2.1 && < 2.3
102
+
, hmatrix >= 0.18 && < 0.19
103
+
, transformers
104
+
, semigroups
105
+
, singletons
106
+
, vector
107
+
, MonadRandom
108
+
, containers
+36
-121
grenade.cabal
+36
-121
grenade.cabal
···
4
4
license-file: LICENSE
5
5
author: Huw Campbell <huw.campbell@gmail.com>
6
6
maintainer: Huw Campbell <huw.campbell@gmail.com>
7
-
copyright: (c) 2015 Huw Campbell.
7
+
copyright: (c) 2016-2017 Huw Campbell.
8
8
synopsis: grenade
9
9
category: System
10
10
cabal-version: >= 1.8
···
12
12
description: grenade.
13
13
14
14
extra-source-files:
15
-
cbits/im2col.h
16
-
cbits/im2col.c
17
-
cbits/gradient_decent.h
18
-
cbits/gradient_decent.c
19
-
cbits/pad.h
20
-
cbits/pad.c
15
+
cbits/im2col.h
16
+
cbits/im2col.c
17
+
cbits/gradient_decent.h
18
+
cbits/gradient_decent.c
19
+
cbits/pad.h
20
+
cbits/pad.c
21
21
22
22
source-repository head
23
23
type: git
···
27
27
build-depends:
28
28
base >= 4.8 && < 5
29
29
, bytestring == 0.10.*
30
-
, containers
31
-
, deepseq
32
-
, either == 4.4.*
33
-
, cereal
30
+
, containers >= 0.5 && < 0.6
31
+
, cereal >= 0.5 && < 0.6
32
+
, deepseq >= 1.4 && < 1.5
34
33
, exceptions == 0.8.*
35
34
, hmatrix == 0.18.*
36
-
, MonadRandom
37
-
, mtl >= 2.2.1 && < 2.3
38
-
, primitive
35
+
, MonadRandom >= 0.4 && < 0.6
36
+
, mtl >= 2.2.1 && < 2.3
37
+
, primitive >= 0.6 && < 0.7
39
38
, text == 1.2.*
40
-
, transformers
41
-
, singletons >= 2.1 && < 2.3
42
-
, vector == 0.11.*
39
+
, singletons >= 2.1 && < 2.3
40
+
, vector >= 0.11 && < 0.13
43
41
44
42
ghc-options:
45
43
-Wall
46
44
hs-source-dirs:
47
45
src
46
+
47
+
if impl(ghc < 8.0)
48
+
ghc-options: -fno-warn-incomplete-patterns
48
49
49
50
50
51
exposed-modules:
···
55
56
Grenade.Core.Network
56
57
Grenade.Core.Runner
57
58
Grenade.Core.Shape
58
-
Grenade.Layers.Crop
59
+
60
+
Grenade.Layers
59
61
Grenade.Layers.Concat
60
62
Grenade.Layers.Convolution
63
+
Grenade.Layers.Crop
61
64
Grenade.Layers.Dropout
65
+
Grenade.Layers.Elu
62
66
Grenade.Layers.FullyConnected
63
-
Grenade.Layers.Reshape
67
+
Grenade.Layers.Inception
64
68
Grenade.Layers.Logit
65
69
Grenade.Layers.Merge
66
-
Grenade.Layers.Relu
67
-
Grenade.Layers.Elu
68
-
Grenade.Layers.Tanh
69
70
Grenade.Layers.Pad
70
71
Grenade.Layers.Pooling
72
+
Grenade.Layers.Relu
73
+
Grenade.Layers.Reshape
74
+
Grenade.Layers.Softmax
75
+
Grenade.Layers.Tanh
76
+
Grenade.Layers.Trivial
71
77
72
78
Grenade.Layers.Internal.Convolution
73
79
Grenade.Layers.Internal.Pad
···
81
87
Grenade.Recurrent.Core.Network
82
88
Grenade.Recurrent.Core.Runner
83
89
90
+
Grenade.Recurrent.Layers
84
91
Grenade.Recurrent.Layers.BasicRecurrent
85
92
Grenade.Recurrent.Layers.LSTM
86
93
87
94
Grenade.Utils.OneHot
88
95
89
-
includes: cbits/im2col.h
90
-
cbits/gradient_decent.h
91
-
cbits/pad.h
92
-
c-sources: cbits/im2col.c
93
-
cbits/gradient_decent.c
94
-
cbits/pad.c
95
-
96
-
cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
96
+
includes: cbits/im2col.h
97
+
cbits/gradient_decent.h
98
+
cbits/pad.h
99
+
c-sources: cbits/im2col.c
100
+
cbits/gradient_decent.c
101
+
cbits/pad.c
97
102
98
-
executable feedforward
99
-
ghc-options: -Wall -threaded -O2
100
-
main-is: main/feedforward.hs
101
-
build-depends: base
102
-
, grenade
103
-
, attoparsec
104
-
, bytestring
105
-
, cereal
106
-
, either
107
-
, optparse-applicative == 0.13.*
108
-
, text == 1.2.*
109
-
, mtl >= 2.2.1 && < 2.3
110
-
, hmatrix
111
-
, transformers
112
-
, singletons
113
-
, semigroups
114
-
, MonadRandom
115
-
116
-
executable mnist
117
-
ghc-options: -Wall -threaded -O2
118
-
main-is: main/mnist.hs
119
-
build-depends: base
120
-
, grenade
121
-
, attoparsec
122
-
, either
123
-
, optparse-applicative == 0.13.*
124
-
, text == 1.2.*
125
-
, mtl >= 2.2.1 && < 2.3
126
-
, hmatrix >= 0.18 && < 0.19
127
-
, transformers
128
-
, semigroups
129
-
, singletons
130
-
, MonadRandom
131
-
, vector
132
-
133
-
executable gan-mnist
134
-
ghc-options: -Wall -threaded -O2
135
-
main-is: main/gan-mnist.hs
136
-
build-depends: base
137
-
, grenade
138
-
, attoparsec
139
-
, bytestring
140
-
, cereal
141
-
, either
142
-
, optparse-applicative == 0.13.*
143
-
, text == 1.2.*
144
-
, mtl >= 2.2.1 && < 2.3
145
-
, hmatrix >= 0.18 && < 0.19
146
-
, transformers
147
-
, semigroups
148
-
, singletons
149
-
, MonadRandom
150
-
, vector
151
-
152
-
executable recurrent
153
-
ghc-options: -Wall -threaded -O2
154
-
main-is: main/recurrent.hs
155
-
build-depends: base
156
-
, grenade
157
-
, attoparsec
158
-
, either
159
-
, optparse-applicative == 0.13.*
160
-
, text == 1.2.*
161
-
, mtl >= 2.2.1 && < 2.3
162
-
, hmatrix >= 0.18 && < 0.19
163
-
, transformers
164
-
, semigroups
165
-
, singletons
166
-
, MonadRandom
167
-
168
-
169
-
executable shakespeare
170
-
ghc-options: -Wall -threaded -O2
171
-
main-is: main/shakespeare.hs
172
-
build-depends: base
173
-
, grenade
174
-
, attoparsec
175
-
, bytestring
176
-
, cereal
177
-
, either
178
-
, optparse-applicative == 0.13.*
179
-
, text == 1.2.*
180
-
, mtl >= 2.2.1 && < 2.3
181
-
, hmatrix >= 0.18 && < 0.19
182
-
, transformers
183
-
, semigroups
184
-
, singletons
185
-
, vector
186
-
, MonadRandom
187
-
, containers
188
-
103
+
cc-options: -std=c99 -O3 -msse4.2 -Wall -Werror -DCABAL=1
189
104
190
105
test-suite test
191
106
type: exitcode-stdio-1.0
main/feedforward.hs
examples/main/feedforward.hs
main/feedforward.hs
examples/main/feedforward.hs
+4
-3
main/gan-mnist.hs
examples/main/gan-mnist.hs
+4
-3
main/gan-mnist.hs
examples/main/gan-mnist.hs
···
58
58
import Grenade.Utils.OneHot
59
59
60
60
type Discriminator = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 1, Logit]
61
-
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 1, 'D1 1]
61
+
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 1, 'D1 1]
62
62
63
63
type Generator = Network '[ FullyConnected 100 10240, Relu, Reshape, Convolution 10 10 5 5 1 1, Relu, Convolution 10 1 1 1 1 1, Logit, Reshape]
64
64
'[ 'D1 100, 'D1 10240, 'D1 10240, 'D3 32 32 10, 'D3 28 28 10, 'D3 28 28 10, 'D3 28 28 1, 'D3 28 28 1, 'D2 28 28 ]
···
77
77
(discriminatorTapeFake, guessFake) = runNetwork discriminator fakeExample
78
78
79
79
(discriminator'real, _) = runGradient discriminator discriminatorTapeReal ( guessReal - 1 )
80
-
(discriminator'fake, push) = runGradient discriminator discriminatorTapeFake guessFake
80
+
(discriminator'fake, _) = runGradient discriminator discriminatorTapeFake guessFake
81
+
(_, push) = runGradient discriminator discriminatorTapeFake ( guessFake - 1)
81
82
82
-
(generator', _) = runGradient generator generatorTape (-push)
83
+
(generator', _) = runGradient generator generatorTape push
83
84
84
85
newDiscriminator = foldl' (applyUpdate rate { learningRegulariser = learningRegulariser rate * 10}) discriminator [ discriminator'real, discriminator'fake ]
85
86
newGenerator = applyUpdate rate generator generator'
+18
-2
main/mnist.hs
examples/main/mnist.hs
+18
-2
main/mnist.hs
examples/main/mnist.hs
···
34
34
35
35
-- With the mnist data from Kaggle normalised to doubles between 0 and 1, learning rate of 0.01 and 15 iterations,
36
36
-- this network should get down to about a 1.3% error rate.
37
-
type MNIST = Network '[ Convolution 1 10 5 5 1 1, Pooling 2 2 2 2, Relu, Convolution 10 16 5 5 1 1, Pooling 2 2 2 2, Reshape, Relu, FullyConnected 256 80, Logit, FullyConnected 80 10, Logit]
38
-
'[ 'D2 28 28, 'D3 24 24 10, 'D3 12 12 10, 'D3 12 12 10, 'D3 8 8 16, 'D3 4 4 16, 'D1 256, 'D1 256, 'D1 80, 'D1 80, 'D1 10, 'D1 10]
37
+
--
38
+
-- /NOTE:/ This model is actually too complex for MNIST, and one should use the type given in the readme instead.
39
+
-- This one is just here to demonstrate Inception layers in use.
40
+
--
41
+
type MNIST =
42
+
Network
43
+
'[ Reshape
44
+
, Inception 28 28 1 5 5 5, Pooling 2 2 2 2, Relu
45
+
, Inception 14 14 15 5 5 5, Pooling 2 2 2 2, Relu
46
+
, Reshape
47
+
, FullyConnected 735 80, Logit
48
+
, FullyConnected 80 10, Logit]
49
+
'[ 'D2 28 28, 'D3 28 28 1
50
+
, 'D3 28 28 15, 'D3 14 14 15, 'D3 14 14 15
51
+
, 'D3 14 14 15, 'D3 7 7 15, 'D3 7 7 15
52
+
, 'D1 735
53
+
, 'D1 80, 'D1 80
54
+
, 'D1 10, 'D1 10]
39
55
40
56
randomMnist :: MonadRandom m => m MNIST
41
57
randomMnist = randomNetwork
main/recurrent.hs
examples/main/recurrent.hs
main/recurrent.hs
examples/main/recurrent.hs
main/shakespeare.hs
examples/main/shakespeare.hs
main/shakespeare.hs
examples/main/shakespeare.hs
+49
-20
src/Grenade.hs
+49
-20
src/Grenade.hs
···
1
1
module Grenade (
2
-
module X
2
+
-- | This is an empty module which simply re-exports public definitions
3
+
-- for machine learning with Grenade.
4
+
5
+
-- * Exported modules
6
+
--
7
+
-- | The core types and runners for Grenade.
8
+
module Grenade.Core
9
+
10
+
-- | The neural network layer zoo
11
+
, module Grenade.Layers
12
+
13
+
14
+
-- * Overview of the library
15
+
-- $library
16
+
17
+
-- * Example usage
18
+
-- $example
19
+
3
20
) where
4
21
5
-
import Grenade.Core.LearningParameters as X
6
-
import Grenade.Core.Layer as X
7
-
import Grenade.Core.Network as X
8
-
import Grenade.Core.Runner as X
9
-
import Grenade.Core.Shape as X
10
-
import Grenade.Layers.Concat as X
11
-
import Grenade.Layers.Crop as X
12
-
import Grenade.Layers.Dropout as X
13
-
import Grenade.Layers.Pad as X
14
-
import Grenade.Layers.Pooling as X
15
-
import Grenade.Layers.Reshape as X
16
-
import Grenade.Layers.FullyConnected as X
17
-
import Grenade.Layers.Logit as X
18
-
import Grenade.Layers.Merge as X
19
-
import Grenade.Layers.Convolution as X
20
-
import Grenade.Layers.Relu as X
21
-
import Grenade.Layers.Elu as X
22
-
import Grenade.Layers.Tanh as X
23
-
import Grenade.Layers.Softmax as X
22
+
import Grenade.Core
23
+
import Grenade.Layers
24
+
25
+
{- $library
26
+
Grenade is a purely functional deep learning library.
27
+
28
+
It provides an expressive type level API for the construction
29
+
of complex neural network architectures. Backing this API is and
30
+
implementation written using BLAS and LAPACK, mostly provided by
31
+
the hmatrix library.
32
+
-}
33
+
34
+
{- $example
35
+
A few examples are provided at https://github.com/HuwCampbell/grenade
36
+
under the examples folder.
37
+
38
+
The starting place is to write your neural network type and a
39
+
function to create a random layer of that type. The following
40
+
is a simple example which runs a logistic regression.
41
+
42
+
> type MyNet = Network '[ FullyConnected 10 1, Logit ] '[ 'D1 10, 'D1 1, 'D1 1 ]
43
+
>
44
+
> randomMyNet :: MonadRandom MyNet
45
+
> randomMyNet = randomNetwork
46
+
47
+
The function `randomMyNet` witnesses the `CreatableNetwork`
48
+
constraint of the neural network, that is it ensures the network
49
+
can be built, and hence, that the architecture is sound.
50
+
-}
51
+
52
+
+10
-5
src/Grenade/Core.hs
+10
-5
src/Grenade/Core.hs
···
1
1
module Grenade.Core (
2
-
module X
2
+
module Grenade.Core.Layer
3
+
, module Grenade.Core.LearningParameters
4
+
, module Grenade.Core.Network
5
+
, module Grenade.Core.Runner
6
+
, module Grenade.Core.Shape
3
7
) where
4
8
5
-
import Grenade.Core.Layer as X
6
-
import Grenade.Core.LearningParameters as X
7
-
import Grenade.Core.Shape as X
8
-
import Grenade.Core.Network as X
9
+
import Grenade.Core.Layer
10
+
import Grenade.Core.LearningParameters
11
+
import Grenade.Core.Network
12
+
import Grenade.Core.Runner
13
+
import Grenade.Core.Shape
+4
src/Grenade/Core/LearningParameters.hs
+4
src/Grenade/Core/LearningParameters.hs
+2
-7
src/Grenade/Core/Network.hs
+2
-7
src/Grenade/Core/Network.hs
···
1
-
{-# LANGUAGE CPP #-}
2
1
{-# LANGUAGE DataKinds #-}
3
2
{-# LANGUAGE BangPatterns #-}
4
3
{-# LANGUAGE GADTs #-}
···
18
17
This module defines the core data types and functions
19
18
for non-recurrent neural networks.
20
19
-}
21
-
22
-
#if __GLASGOW_HASKELL__ < 800
23
-
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
24
-
#endif
25
20
26
21
module Grenade.Core.Network (
27
22
Network (..)
···
47
42
48
43
-- | Type of a network.
49
44
--
50
-
-- The [*] type specifies the types of the layers.
45
+
-- The @[*]@ type specifies the types of the layers.
51
46
--
52
-
-- The [Shape] type specifies the shapes of data passed between the layers.
47
+
-- The @[Shape]@ type specifies the shapes of data passed between the layers.
53
48
--
54
49
-- Can be considered to be a heterogeneous list of layers which are able to
55
50
-- transform the data shapes of the network.
-9
src/Grenade/Core/Shape.hs
-9
src/Grenade/Core/Shape.hs
···
1
-
{-# LANGUAGE CPP #-}
2
1
{-# LANGUAGE DataKinds #-}
3
2
{-# LANGUAGE GADTs #-}
4
3
{-# LANGUAGE KindSignatures #-}
···
8
7
{-# LANGUAGE FlexibleContexts #-}
9
8
{-# LANGUAGE ScopedTypeVariables #-}
10
9
{-# LANGUAGE RankNTypes #-}
11
-
12
-
-- Ghc 7.10 fails to recognise n2 is complete.
13
-
#if __GLASGOW_HASKELL__ < 800
14
-
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
15
-
#endif
16
10
{-|
17
11
Module : Grenade.Core.Shape
18
12
Description : Core definition of the Shapes of data we understand
···
65
59
-- All shapes are held in contiguous memory.
66
60
-- 3D is held in a matrix (usually row oriented) which has height depth * rows.
67
61
data S (n :: Shape) where
68
-
-- | One dimensional data
69
62
S1D :: ( KnownNat len )
70
63
=> R len
71
64
-> S ('D1 len)
72
65
73
-
-- | Two dimensional data
74
66
S2D :: ( KnownNat rows, KnownNat columns )
75
67
=> L rows columns
76
68
-> S ('D2 rows columns)
77
69
78
-
-- | Three dimensional data
79
70
S3D :: ( KnownNat rows
80
71
, KnownNat columns
81
72
, KnownNat depth
+33
src/Grenade/Layers.hs
+33
src/Grenade/Layers.hs
···
1
+
module Grenade.Layers (
2
+
module Grenade.Layers.Concat
3
+
, module Grenade.Layers.Convolution
4
+
, module Grenade.Layers.Crop
5
+
, module Grenade.Layers.Elu
6
+
, module Grenade.Layers.FullyConnected
7
+
, module Grenade.Layers.Inception
8
+
, module Grenade.Layers.Logit
9
+
, module Grenade.Layers.Merge
10
+
, module Grenade.Layers.Pad
11
+
, module Grenade.Layers.Pooling
12
+
, module Grenade.Layers.Reshape
13
+
, module Grenade.Layers.Relu
14
+
, module Grenade.Layers.Softmax
15
+
, module Grenade.Layers.Tanh
16
+
, module Grenade.Layers.Trivial
17
+
) where
18
+
19
+
import Grenade.Layers.Concat
20
+
import Grenade.Layers.Convolution
21
+
import Grenade.Layers.Crop
22
+
import Grenade.Layers.Elu
23
+
import Grenade.Layers.Pad
24
+
import Grenade.Layers.FullyConnected
25
+
import Grenade.Layers.Inception
26
+
import Grenade.Layers.Logit
27
+
import Grenade.Layers.Merge
28
+
import Grenade.Layers.Pooling
29
+
import Grenade.Layers.Reshape
30
+
import Grenade.Layers.Relu
31
+
import Grenade.Layers.Softmax
32
+
import Grenade.Layers.Tanh
33
+
import Grenade.Layers.Trivial
+4
-2
src/Grenade/Layers/Concat.hs
+4
-2
src/Grenade/Layers/Concat.hs
···
9
9
{-# LANGUAGE ScopedTypeVariables #-}
10
10
{-# LANGUAGE StandaloneDeriving #-}
11
11
{-|
12
-
Module : Grenade.Core.Network
13
-
Description : Core definition a simple neural etwork
12
+
Module : Grenade.Layers.Concat
13
+
Description : Concatenation layer
14
14
Copyright : (c) Huw Campbell, 2016-2017
15
15
License : BSD2
16
16
Stability : experimental
17
+
18
+
This module provides the concatenation layer, whic used to run two separate layers in parallel and combine their outputs.
17
19
-}
18
20
module Grenade.Layers.Concat (
19
21
Concat (..)
+26
-12
src/Grenade/Layers/Inception.hs
+26
-12
src/Grenade/Layers/Inception.hs
···
9
9
{-# LANGUAGE ScopedTypeVariables #-}
10
10
{-|
11
11
Module : Grenade.Core.Network
12
-
Description : Core definition a simple neural etwork
12
+
Description : Inception style parallel convolutional network composition.
13
13
Copyright : (c) Huw Campbell, 2016-2017
14
14
License : BSD2
15
15
Stability : experimental
16
+
17
+
Export an Inception style type, which can be used to build up
18
+
complex multiconvolution size networks.
16
19
-}
17
20
module Grenade.Layers.Inception (
18
21
Inception
···
25
28
import Grenade.Layers.Pad
26
29
import Grenade.Layers.Concat
27
30
28
-
31
+
-- | Type of an inception layer.
32
+
--
33
+
-- It looks like a bit of a handful, but is actually pretty easy to use.
34
+
--
35
+
-- The first three type parameters are the size of the (3D) data the
36
+
-- inception layer will take. It will emit 3D data with the number of
37
+
-- channels being the sum of @chx@, @chy@, @chz@, which are the number
38
+
-- of convolution filters in the 3x3, 5x5, and 7x7 convolutions Layers
39
+
-- respectively.
40
+
--
41
+
-- The network get padded effectively before each convolution filters
42
+
-- such that the output dimension is the same x and y as the input.
29
43
type Inception rows cols channels chx chy chz
30
-
= Network '[ Concat ('D3 (rows - 2) (cols - 2) (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 (rows - 2) (cols - 2) chz) (Inception7x7 rows cols channels chz) ]
31
-
'[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy + chz) ]
44
+
= Network '[ Concat ('D3 rows cols (chx + chy)) (InceptionS rows cols channels chx chy) ('D3 rows cols chz) (Inception7x7 rows cols channels chz) ]
45
+
'[ 'D3 rows cols channels, 'D3 rows cols (chx + chy + chz) ]
32
46
33
47
type InceptionS rows cols channels chx chy
34
-
= Network '[ Concat ('D3 (rows - 2) (cols - 2) chx) (Inception3x3 rows cols channels chx) ('D3 (rows - 2) (cols - 2) chy) (Inception5x5 rows cols channels chy) ]
35
-
'[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) (chx + chy) ]
48
+
= Network '[ Concat ('D3 rows cols chx) (Inception3x3 rows cols channels chx) ('D3 rows cols chy) (Inception5x5 rows cols channels chy) ]
49
+
'[ 'D3 rows cols channels, 'D3 rows cols (chx + chy) ]
36
50
37
51
type Inception3x3 rows cols channels chx
38
-
= Network '[ Convolution channels chx 3 3 1 1 ]
39
-
'[ 'D3 rows cols channels, 'D3 (rows -2) (cols -2) chx ]
52
+
= Network '[ Pad 1 1 1 1, Convolution channels chx 3 3 1 1 ]
53
+
'[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 rows cols chx ]
40
54
41
55
type Inception5x5 rows cols channels chx
42
-
= Network '[ Pad 1 1 1 1, Convolution channels chx 5 5 1 1 ]
43
-
'[ 'D3 rows cols channels, 'D3 (rows + 2) (cols + 2) channels, 'D3 (rows - 2) (cols - 2) chx ]
56
+
= Network '[ Pad 2 2 2 2, Convolution channels chx 5 5 1 1 ]
57
+
'[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 rows cols chx ]
44
58
45
59
type Inception7x7 rows cols channels chx
46
-
= Network '[ Pad 2 2 2 2, Convolution channels chx 7 7 1 1 ]
47
-
'[ 'D3 rows cols channels, 'D3 (rows + 4) (cols + 4) channels, 'D3 (rows - 2) (cols - 2) chx ]
60
+
= Network '[ Pad 3 3 3 3, Convolution channels chx 7 7 1 1 ]
61
+
'[ 'D3 rows cols channels, 'D3 (rows + 6) (cols + 6) channels, 'D3 rows cols chx ]
48
62
+3
-2
src/Grenade/Layers/Trivial.hs
+3
-2
src/Grenade/Layers/Trivial.hs
···
9
9
10
10
import Data.Serialize
11
11
12
-
import Grenade.Core.Network
12
+
import Grenade.Core
13
13
14
14
-- | A trivial layer.
15
15
data Trivial = Trivial
···
25
25
createRandom = return Trivial
26
26
27
27
instance (a ~ b) => Layer Trivial a b where
28
-
runForwards _ = id
28
+
type Tape Trivial a b = ()
29
+
runForwards _ a = ((), a)
29
30
runBackwards _ _ y = ((), y)
+4
-6
src/Grenade/Recurrent.hs
+4
-6
src/Grenade/Recurrent.hs
···
1
1
module Grenade.Recurrent (
2
-
module X
2
+
module Grenade.Recurrent.Core
3
+
, module Grenade.Recurrent.Layers
3
4
) where
4
5
5
-
import Grenade.Recurrent.Core.Layer as X
6
-
import Grenade.Recurrent.Core.Network as X
7
-
import Grenade.Recurrent.Core.Runner as X
8
-
import Grenade.Recurrent.Layers.BasicRecurrent as X
9
-
import Grenade.Recurrent.Layers.LSTM as X
6
+
import Grenade.Recurrent.Core
7
+
import Grenade.Recurrent.Layers
+6
-3
src/Grenade/Recurrent/Core.hs
+6
-3
src/Grenade/Recurrent/Core.hs
···
1
1
module Grenade.Recurrent.Core (
2
-
module X
2
+
module Grenade.Recurrent.Core.Layer
3
+
, module Grenade.Recurrent.Core.Network
4
+
, module Grenade.Recurrent.Core.Runner
3
5
) where
4
6
5
-
import Grenade.Recurrent.Core.Layer as X
6
-
import Grenade.Recurrent.Core.Network as X
7
+
import Grenade.Recurrent.Core.Layer
8
+
import Grenade.Recurrent.Core.Network
9
+
import Grenade.Recurrent.Core.Runner
+1
-5
src/Grenade/Recurrent/Core/Network.hs
+1
-5
src/Grenade/Recurrent/Core/Network.hs
···
1
-
{-# LANGUAGE CPP #-}
2
1
{-# LANGUAGE DataKinds #-}
3
2
{-# LANGUAGE GADTs #-}
4
3
{-# LANGUAGE TypeOperators #-}
···
10
9
{-# LANGUAGE RankNTypes #-}
11
10
{-# LANGUAGE BangPatterns #-}
12
11
{-# LANGUAGE ScopedTypeVariables #-}
13
-
14
-
#if __GLASGOW_HASKELL__ < 800
15
-
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
16
-
#endif
12
+
{-# LANGUAGE UndecidableInstances #-}
17
13
18
14
module Grenade.Recurrent.Core.Network (
19
15
Recurrent
-5
src/Grenade/Recurrent/Core/Runner.hs
-5
src/Grenade/Recurrent/Core/Runner.hs
···
3
3
{-# LANGUAGE DataKinds #-}
4
4
{-# LANGUAGE ScopedTypeVariables #-}
5
5
{-# LANGUAGE TypeOperators #-}
6
-
{-# LANGUAGE CPP #-}
7
6
{-# LANGUAGE TypeFamilies #-}
8
7
{-# LANGUAGE FlexibleContexts #-}
9
8
{-# LANGUAGE RankNTypes #-}
10
9
{-# LANGUAGE RecordWildCards #-}
11
-
12
-
#if __GLASGOW_HASKELL__ < 800
13
-
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
14
-
#endif
15
10
16
11
module Grenade.Recurrent.Core.Runner (
17
12
trainRecurrent
+7
src/Grenade/Recurrent/Layers.hs
+7
src/Grenade/Recurrent/Layers.hs
-5
src/Grenade/Recurrent/Layers/BasicRecurrent.hs
-5
src/Grenade/Recurrent/Layers/BasicRecurrent.hs
···
1
-
{-# LANGUAGE CPP #-}
2
1
{-# LANGUAGE DataKinds #-}
3
2
{-# LANGUAGE GADTs #-}
4
3
{-# LANGUAGE RecordWildCards #-}
···
8
7
{-# LANGUAGE FlexibleContexts #-}
9
8
{-# LANGUAGE UndecidableInstances #-}
10
9
11
-
-- GHC 7.10 doesn't see recurrent run functions as total.
12
-
#if __GLASGOW_HASKELL__ < 800
13
-
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
14
-
#endif
15
10
module Grenade.Recurrent.Layers.BasicRecurrent (
16
11
BasicRecurrent (..)
17
12
, randomBasicRecurrent
-6
src/Grenade/Recurrent/Layers/LSTM.hs
-6
src/Grenade/Recurrent/Layers/LSTM.hs
···
1
1
{-# LANGUAGE BangPatterns #-}
2
-
{-# LANGUAGE CPP #-}
3
2
{-# LANGUAGE DataKinds #-}
4
3
{-# LANGUAGE GADTs #-}
5
4
{-# LANGUAGE RankNTypes #-}
···
10
9
{-# LANGUAGE FlexibleContexts #-}
11
10
{-# LANGUAGE ViewPatterns #-}
12
11
{-# LANGUAGE ScopedTypeVariables #-}
13
-
14
-
-- GHC 7.10 doesn't see recurrent run functions as total.
15
-
#if __GLASGOW_HASKELL__ < 800
16
-
{-# OPTIONS_GHC -fno-warn-incomplete-patterns #-}
17
-
#endif
18
12
19
13
module Grenade.Recurrent.Layers.LSTM (
20
14
LSTM (..)
+9
test/Test/Grenade/Layers/PadCrop.hs
+9
test/Test/Grenade/Layers/PadCrop.hs
···
30
30
(_ , grad) = runBackwards net tapes d
31
31
in d ~~~ res .&&. grad ~~~ d
32
32
33
+
prop_pad_crop_2d :: Property
34
+
prop_pad_crop_2d =
35
+
let net :: Network '[Pad 2 3 4 6, Crop 2 3 4 6] '[ 'D2 7 9, 'D2 16 15, 'D2 7 9 ]
36
+
net = Pad :~> Crop :~> NNil
37
+
in gamble genOfShape $ \(d :: S ('D2 7 9)) ->
38
+
let (tapes, res) = runForwards net d
39
+
(_ , grad) = runBackwards net tapes d
40
+
in d ~~~ res .&&. grad ~~~ d
41
+
33
42
(~~~) :: S x -> S x -> Bool
34
43
(S1D x) ~~~ (S1D y) = norm_Inf (x - y) < 0.00001
35
44
(S2D x) ~~~ (S2D y) = norm_Inf (x - y) < 0.00001