💥 Machine learning which might blow up in your face 💥
at master 197 lines 6.1 kB view raw
{-# LANGUAGE CPP                   #-}
{-# LANGUAGE DataKinds             #-}
{-# LANGUAGE BangPatterns          #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TypeOperators         #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-|
Module      : Grenade.Core.Network
Description : Core definition of a Neural Network
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental

This module defines the core data types and functions
for non-recurrent neural networks.
-}

module Grenade.Core.Network (
    Network (..)
  , Gradients (..)
  , Tapes (..)

  , runNetwork
  , runGradient
  , applyUpdate

  , randomNetwork
  ) where

import           Control.Monad.Random ( MonadRandom )

import           Data.Singletons
import           Data.Serialize

#if MIN_VERSION_base(4,9,0)
import           Data.Kind (Type)
#endif

import           Grenade.Core.Layer
import           Grenade.Core.LearningParameters
import           Grenade.Core.Shape
import           Prelude.Singletons

-- | Type of a network.
--
--   The @[Type]@ type parameter lists the types of the layers, in order.
--
--   The @[Shape]@ type parameter lists the shapes of the data passed
--   between the layers (one more shape than there are layers).
--
--   A network is therefore a heterogeneous list of layers, each of which
--   transforms one data shape into the next.
data Network :: [Type] -> [Shape] -> Type where
    -- | The empty network: carries only the (singleton) shape it accepts.
    NNil  :: SingI i
          => Network '[] '[i]

    -- | Prepend a layer taking shape @i@ to @h@ onto a network whose
    --   first shape is @h@. Both fields are strict.
    (:~>) :: (SingI i, SingI h, Layer x i h)
          => !x
          -> !(Network xs (h ': hs))
          -> Network (x ': xs) (i ': h ': hs)
infixr 5 :~>

instance Show (Network '[] '[i]) where
  show NNil = "NNil"
instance (Show x, Show (Network xs rs)) => Show (Network (x ': xs) (i ': rs)) where
  show (x :~> xs) = show x ++ "\n~>\n" ++ show xs

-- | Gradient of a network.
--
--   Parameterised on the layers of the network.
data Gradients :: [Type] -> Type where
   -- | No gradients for the empty network.
   GNil  :: Gradients '[]

   -- | One 'Gradient' per layer, in network order.
   (:/>) :: UpdateLayer x
         => Gradient x
         -> Gradients xs
         -> Gradients (x ': xs)

-- | Wengert tape of a network.
--
--   Parameterised on the layers and shapes of the network; stores one
--   'Tape' per layer, recorded during the forward pass.
data Tapes :: [Type] -> [Shape] -> Type where
   TNil  :: SingI i
         => Tapes '[] '[i]

   (:\>) :: (SingI i, SingI h, Layer x i h)
         => !(Tape x i h)
         -> !(Tapes xs (h ': hs))
         -> Tapes (x ': xs) (i ': h ': hs)


-- | Run a network forwards over some input data.
--
--   Returns the Wengert tape required for back propagation, together
--   with the final output of the network.
runNetwork :: forall layers shapes.
              Network layers shapes
           -> S (Head shapes)
           -> (Tapes layers shapes, S (Last shapes))
runNetwork = step
  where
    -- Walk the layer list, threading the activation through each layer
    -- and recording each layer's tape as we go.
    step :: forall js ss. (Last js ~ Last shapes)
         => Network ss js
         -> S (Head js)
         -> (Tapes ss js, S (Last js))
    step NNil !v = (TNil, v)
    step (layer :~> rest) !v =
      let (tape, out)     = runForwards layer v
          (tapes, final)  = step rest out
      in  (tape :\> tapes, final)


-- | Run a loss gradient back through the network.
--
--   Requires the Wengert tape generated by 'runNetwork' on the input
--   that produced this loss.
--
--   Gives the gradients for every layer, and the gradient across the
--   input (which may not be required).
runGradient :: forall layers shapes.
               Network layers shapes
            -> Tapes layers shapes
            -> S (Last shapes)
            -> (Gradients layers, S (Head shapes))
runGradient network allTapes lossGrad =
  back network allTapes
  where
    -- Recurse to the end of the network, then propagate the loss
    -- gradient back through each layer on the way out, collecting the
    -- per-layer parameter gradients.
    back :: forall js ss. (Last js ~ Last shapes)
         => Network ss js
         -> Tapes ss js
         -> (Gradients ss, S (Head js))
    back NNil TNil = (GNil, lossGrad)
    back (layer :~> rest) (tape :\> restTapes) =
      let (grads, fromAbove)   = back rest restTapes
          (layerGrad, toBelow) = runBackwards layer tape fromAbove
      in  (layerGrad :/> grads, toBelow)


-- | Apply one step of stochastic gradient descent across the network.
applyUpdate :: LearningParameters
            -> Network layers shapes
            -> Gradients layers
            -> Network layers shapes
applyUpdate rate (layer :~> rest) (gradient :/> grest) =
  runUpdate rate layer gradient :~> applyUpdate rate rest grest
applyUpdate _ NNil GNil =
  NNil

-- | A network can easily be created by hand with '(:~>)', but an easy way
--   to initialise a random network is with 'randomNetwork'.
class CreatableNetwork (xs :: [Type]) (ss :: [Shape]) where
  -- | Create a network with randomly initialised weights.
  --
  --   Calls to this function will not compile if the type of the neural
  --   network is not sound.
  randomNetwork :: MonadRandom m => m (Network xs ss)

instance SingI i => CreatableNetwork '[] '[i] where
  randomNetwork = return NNil

instance (SingI i, SingI o, Layer x i o, CreatableNetwork xs (o ': rs)) => CreatableNetwork (x ': xs) (i ': o ': rs) where
  randomNetwork = (:~>) <$> createRandom <*> randomNetwork

-- | Add very simple serialisation to the network
instance SingI i => Serialize (Network '[] '[i]) where
  put NNil = pure ()
  get      = return NNil

instance (SingI i, SingI o, Layer x i o, Serialize x, Serialize (Network xs (o ': rs))) => Serialize (Network (x ': xs) (i ': o ': rs)) where
  put (x :~> r) = put x >> put r
  get           = (:~>) <$> get <*> get


-- | Ultimate composition.
--
--   This allows a complete network to be treated as a layer in a larger
--   network.
instance CreatableNetwork sublayers subshapes => UpdateLayer (Network sublayers subshapes) where
  type Gradient (Network sublayers subshapes) = Gradients sublayers
  runUpdate    = applyUpdate
  createRandom = randomNetwork

-- | Ultimate composition.
--
--   This allows a complete network to be treated as a layer in a larger
--   network.
instance ( CreatableNetwork sublayers subshapes
         , i ~ Head subshapes
         , o ~ Last subshapes
         ) => Layer (Network sublayers subshapes) i o where
  -- The tape of a network-as-layer is simply the network's own tapes.
  type Tape (Network sublayers subshapes) i o = Tapes sublayers subshapes
  runForwards  = runNetwork
  runBackwards = runGradient