-- 💣 Machine learning which might blow up in your face 💣
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE DataKinds #-}
3{-# LANGUAGE BangPatterns #-}
4{-# LANGUAGE GADTs #-}
5{-# LANGUAGE ScopedTypeVariables #-}
6{-# LANGUAGE TypeOperators #-}
7{-# LANGUAGE TypeFamilies #-}
8{-# LANGUAGE MultiParamTypeClasses #-}
9{-# LANGUAGE FlexibleContexts #-}
10{-# LANGUAGE FlexibleInstances #-}
11{-|
12Module : Grenade.Core.Network
13Description : Core definition of a Neural Network
14Copyright : (c) Huw Campbell, 2016-2017
15License : BSD2
16Stability : experimental
17
18This module defines the core data types and functions
19for non-recurrent neural networks.
20-}
21
22module Grenade.Core.Network (
23 Network (..)
24 , Gradients (..)
25 , Tapes (..)
26
27 , runNetwork
28 , runGradient
29 , applyUpdate
30
31 , randomNetwork
32 ) where
33
34import Control.Monad.Random ( MonadRandom )
35
36import Data.Singletons
37import Data.Serialize
38
39#if MIN_VERSION_base(4,9,0)
40import Data.Kind (Type)
41#endif
42
43import Grenade.Core.Layer
44import Grenade.Core.LearningParameters
45import Grenade.Core.Shape
46import Prelude.Singletons
47
-- | Type of a network.
--
-- The @[Type]@ parameter specifies the types of the layers, in order.
--
-- The @[Shape]@ parameter specifies the shapes of data passed between the
-- layers; it always has exactly one more element than the layer list
-- (the input shape, each intermediate shape, and the output shape).
--
-- Can be considered to be a heterogeneous list of layers which are able to
-- transform the data shapes of the network.
data Network :: [Type] -> [Shape] -> Type where
  -- | The empty network: no layers, and a single shape @i@ which is
  --   both its input and its output.
  NNil  :: SingI i
        => Network '[] '[i]

  -- | Prepend a layer @x@, taking shape @i@ to shape @h@, onto a
  --   network whose input shape is @h@. Both fields are strict.
  (:~>) :: (SingI i, SingI h, Layer x i h)
        => !x
        -> !(Network xs (h ': hs))
        -> Network (x ': xs) (i ': h ': hs)
infixr 5 :~>
65
-- | The empty network renders as its constructor name.
instance Show (Network '[] '[i]) where
  show NNil = "NNil"

-- | Each layer is rendered by its own 'Show' instance, with the cons
--   operator between layers on its own line.
instance (Show x, Show (Network xs rs)) => Show (Network (x ': xs) (i ': rs)) where
  show (layer :~> rest) = concat [show layer, "\n~>\n", show rest]
70
-- | Gradient of a network.
--
-- Parameterised on the layers of the network, one gradient per layer.
-- Shapes are not tracked here: a layer's 'Gradient' type is indexed
-- only by the layer type itself.
data Gradients :: [Type] -> Type where
  -- | No layers, no gradients.
  GNil  :: Gradients '[]

  -- | The gradient for the head layer, consed onto the gradients for
  --   the remaining layers.
  (:/>) :: UpdateLayer x
        => Gradient x
        -> Gradients xs
        -> Gradients (x ': xs)
81
-- | Wengert tape of a network.
--
-- Parameterised on the layers and shapes of the network. Holds the
-- per-layer 'Tape' values recorded by 'runNetwork' during the forwards
-- pass, which 'runGradient' consumes during the backwards pass.
data Tapes :: [Type] -> [Shape] -> Type where
  -- | No layers, no tapes.
  TNil  :: SingI i
        => Tapes '[] '[i]

  -- | The tape for the head layer, consed onto the tapes for the
  --   remaining layers. Both fields are strict.
  (:\>) :: (SingI i, SingI h, Layer x i h)
        => !(Tape x i h)
        -> !(Tapes xs (h ': hs))
        -> Tapes (x ': xs) (i ': h ': hs)
93
94
-- | Running a network forwards with some input data.
--
-- This gives the output, and the Wengert tape required for back
-- propagation.
runNetwork :: forall layers shapes.
              Network layers shapes
           -> S (Head shapes)
           -> (Tapes layers shapes, S (Last shapes))
runNetwork = walk
  where
    walk :: forall js ss. (Last js ~ Last shapes)
         => Network ss js
         -> S (Head js)
         -> (Tapes ss js, S (Last js))
    -- Each layer records its tape on the way forwards, so the tapes
    -- line up one-for-one with the layers.
    walk (layer :~> rest) !input =
      let (tape, activated) = runForwards layer input
          (tapes, output)   = walk rest activated
      in  (tape :\> tapes, output)

    walk NNil !input = (TNil, input)
117
118
119-- | Running a loss gradient back through the network.
120--
121-- This requires a Wengert tape, generated with the appropriate input
122-- for the loss.
123--
124-- Gives the gradients for the layer, and the gradient across the
125-- input (which may not be required).
126runGradient :: forall layers shapes.
127 Network layers shapes
128 -> Tapes layers shapes
129 -> S (Last shapes)
130 -> (Gradients layers, S (Head shapes))
131runGradient net tapes o =
132 go net tapes
133 where
134 go :: forall js ss. (Last js ~ Last shapes)
135 => Network ss js
136 -> Tapes ss js
137 -> (Gradients ss, S (Head js))
138 go (layer :~> n) (tape :\> nt) =
139 let (gradients, feed) = go n nt
140 (layer', backGrad) = runBackwards layer tape feed
141 in (layer' :/> gradients, backGrad)
142
143 go NNil TNil
144 = (GNil, o)
145
146
147-- | Apply one step of stochastic gradient descent across the network.
148applyUpdate :: LearningParameters
149 -> Network layers shapes
150 -> Gradients layers
151 -> Network layers shapes
152applyUpdate rate (layer :~> rest) (gradient :/> grest)
153 = runUpdate rate layer gradient :~> applyUpdate rate rest grest
154
155applyUpdate _ NNil GNil
156 = NNil
157
-- | A network can easily be created by hand with (:~>), but an easy way to
-- initialise a random network is with 'randomNetwork'.
class CreatableNetwork (xs :: [Type]) (ss :: [Shape]) where
  -- | Create a network with randomly initialised weights.
  --
  -- Calls to this function will not compile if the type of the neural
  -- network is not sound.
  randomNetwork :: MonadRandom m => m (Network xs ss)
166
-- | The empty network needs no random state.
instance SingI i => CreatableNetwork '[] '[i] where
  randomNetwork = pure NNil

-- | Initialise the head layer, then the rest of the network, in that
--   order of random draws.
instance (SingI i, SingI o, Layer x i o, CreatableNetwork xs (o ': rs)) => CreatableNetwork (x ': xs) (i ': o ': rs) where
  randomNetwork = do
    layer <- createRandom
    rest  <- randomNetwork
    pure (layer :~> rest)
172
-- | Add very simple serialisation to the network
instance SingI i => Serialize (Network '[] '[i]) where
  put NNil = return ()
  get      = pure NNil

instance (SingI i, SingI o, Layer x i o, Serialize x, Serialize (Network xs (o ': rs))) => Serialize (Network (x ': xs) (i ': o ': rs)) where
  -- Serialise the head layer first, then the tail of the network.
  put (layer :~> rest) = do
    put layer
    put rest
  -- Deserialise in the same order: head layer, then the tail.
  get = do
    layer <- get
    rest  <- get
    pure (layer :~> rest)
181
182
183-- | Ultimate composition.
184--
185-- This allows a complete network to be treated as a layer in a larger network.
186instance CreatableNetwork sublayers subshapes => UpdateLayer (Network sublayers subshapes) where
187 type Gradient (Network sublayers subshapes) = Gradients sublayers
188 runUpdate = applyUpdate
189 createRandom = randomNetwork
190
-- | Ultimate composition.
--
-- This allows a complete network to be treated as a layer in a larger
-- network: the layer's input and output shapes are the first and last
-- shapes of the sub-network, its tape is the sub-network's 'Tapes', and
-- the forwards/backwards passes are whole-network 'runNetwork' and
-- 'runGradient'.
instance (CreatableNetwork sublayers subshapes, i ~ (Head subshapes), o ~ (Last subshapes)) => Layer (Network sublayers subshapes) i o where
  type Tape (Network sublayers subshapes) i o = Tapes sublayers subshapes
  runForwards = runNetwork
  runBackwards = runGradient
197 runBackwards = runGradient