💣 Machine learning which might blow up in your face 💣
1{-# LANGUAGE CPP #-}
2{-# LANGUAGE DataKinds #-}
3{-# LANGUAGE GADTs #-}
4{-# LANGUAGE KindSignatures #-}
5{-# LANGUAGE TypeFamilies #-}
6{-# LANGUAGE TypeOperators #-}
7{-# LANGUAGE StandaloneDeriving #-}
8{-# LANGUAGE FlexibleContexts #-}
9{-# LANGUAGE ScopedTypeVariables #-}
10{-# LANGUAGE RankNTypes #-}
11{-# LANGUAGE UndecidableInstances #-}
{-|
Module      : Grenade.Core.Shape
Description : Dependently typed shapes of data which are passed between layers of a network
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental


-}
21module Grenade.Core.Shape (
22 S (..)
23 , Shape (..)
24 , Sing (..)
25 , SShape (..)
26 , randomOfShape
27 , fromStorable
28 ) where
29
30import Control.DeepSeq (NFData (..))
31import Control.Monad.Random ( MonadRandom, getRandom )
32import Data.Kind (Type)
33import Data.Proxy
34import Data.Serialize
35import Data.Singletons
36import Data.Vector.Storable ( Vector )
37import qualified Data.Vector.Storable as V
38import GHC.TypeLits
39import qualified Numeric.LinearAlgebra.Static as H
40import Numeric.LinearAlgebra.Static
41import qualified Numeric.LinearAlgebra as NLA
42
-- | The shapes of data currently supported: one, two and three
-- dimensional vectors/matrices.
--
-- This type is only used at the kind level (via DataKinds): the promoted
-- constructors 'D1, 'D2 and 'D3 index the concrete data type 'S'.
data Shape
  = D1 Nat
    -- ^ One dimensional vector. Length.
  | D2 Nat Nat
    -- ^ Two dimensional matrix. Row, Column.
  | D3 Nat Nat Nat
    -- ^ Three dimensional matrix. Row, Column, Channels.
54 -- ^ Three dimensional matrix. Row, Column, Channels.
55
-- | Concrete data structures for a 'Shape'.
--
-- All shapes are held in contiguous memory. A three dimensional shape is
-- held in a single matrix (usually row oriented) of height
-- @rows * depth@ and width @columns@.
data S (n :: Shape) where
  -- | One dimensional: a vector of length @len@.
  S1D :: KnownNat len
      => R len
      -> S ('D1 len)

  -- | Two dimensional: a @rows@ by @columns@ matrix.
  S2D :: (KnownNat rows, KnownNat columns)
      => L rows columns
      -> S ('D2 rows columns)

  -- | Three dimensional: all @depth@ channels share one matrix with
  -- @rows * depth@ rows and @columns@ columns.
  S3D :: ( KnownNat rows
         , KnownNat columns
         , KnownNat depth
         , KnownNat (rows * depth) )
      => L (rows * depth) columns
      -> S ('D3 rows columns depth)

deriving instance Show (S n)
77
-- Singleton instances.
--
-- These could probably be derived with Template Haskell, but this seems
-- clear and makes adding the KnownNat constraints simple.
-- We can also keep our code TH free, which is great.
--
-- The singleton constructors take no arguments. Each one instead carries
-- the KnownNat evidence for its dimensions, so matching on it (as
-- 'randomOfShape', 'fromStorable' and 'nk' do) brings those constraints
-- into scope. Both CPP branches must declare them in this same form,
-- because the SingI instances below and those pattern matches are shared
-- by both branches.
#if MIN_VERSION_singletons(2,6,0)
-- In singletons 2.6 Sing switched from a data family to a type family.
type instance Sing = SShape

data SShape :: Shape -> Type where
  D1Sing :: KnownNat a => SShape ('D1 a)
  D2Sing :: (KnownNat a, KnownNat b) => SShape ('D2 a b)
  D3Sing :: (KnownNat (a * c), KnownNat a, KnownNat b, KnownNat c) => SShape ('D3 a b c)
#else
-- Before singletons 2.6 Sing is a data family, so we add an instance for
-- Shape. The constructors must be nullary (not take Sing arguments),
-- otherwise `sing = D1Sing` and the `D1Sing ->` matches do not typecheck.
data instance Sing (n :: Shape) where
  D1Sing :: KnownNat a => Sing ('D1 a)
  D2Sing :: (KnownNat a, KnownNat b) => Sing ('D2 a b)
  D3Sing :: (KnownNat (a * c), KnownNat a, KnownNat b, KnownNat c) => Sing ('D3 a b c)

-- The module export list names 'SShape'; provide it as an alias so the
-- exports also resolve with older singletons.
type SShape = (Sing :: Shape -> Type)
#endif

instance KnownNat a => SingI ('D1 a) where
  sing = D1Sing
instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where
  sing = D2Sing
instance (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => SingI ('D3 a b c) where
  sing = D3Sing
104
-- | Arithmetic on shapes, delegated to the underlying vector/matrix values
-- through 'n1' and 'n2'. Integer literals become constant-filled shapes
-- via 'nk'.
instance SingI x => Num (S x) where
  (+) = n2 (+)
  (-) = n2 (-)
  (*) = n2 (*)
  -- Defined directly: the class default (0 - x) would first build a
  -- constant zero shape with 'nk' and then subtract from it.
  negate = n1 negate
  abs = n1 abs
  signum = n1 signum
  fromInteger x = nk (fromInteger x)
112
-- | Division and reciprocal, delegated to the underlying vector/matrix
-- values. Rational literals become constant-filled shapes via 'nk'.
instance SingI x => Fractional (S x) where
  fromRational r = nk (fromRational r)
  recip s        = n1 recip s
  a / b          = n2 (/) a b
117
-- | Floating-point functions on shapes, each applied to the underlying
-- vector/matrix values through 'n1' or 'n2'. 'pi' becomes a shape filled
-- with pi via 'nk'.
instance SingI x => Floating (S x) where
  pi = nk pi

  -- Binary functions.
  a ** b      = n2 (**) a b
  logBase a b = n2 logBase a b

  -- Exponential, logarithm and square root.
  exp s  = n1 exp s
  log s  = n1 log s
  sqrt s = n1 sqrt s

  -- Trigonometric functions and their inverses.
  sin s  = n1 sin s
  cos s  = n1 cos s
  tan s  = n1 tan s
  asin s = n1 asin s
  acos s = n1 acos s
  atan s = n1 atan s

  -- Hyperbolic functions and their inverses.
  sinh s  = n1 sinh s
  cosh s  = n1 cosh s
  tanh s  = n1 tanh s
  asinh s = n1 asinh s
  acosh s = n1 acosh s
  atanh s = n1 atanh s
137
-- | Fully evaluating a shape forces the vector or matrix it holds.
--
-- Shapes themselves are not strict, as sometimes they're not needed
-- (the last input gradient back, for instance); this instance is how
-- full evaluation is requested when it is wanted.
instance NFData (S x) where
  rnf s = case s of
    S1D v -> rnf v
    S2D m -> rnf m
    S3D m -> rnf m
146
-- | Generate random data of the desired shape.
--
-- A single 'Int' seed is drawn from the random monad. Every element is
-- then sampled uniformly between -1 and 1 (the vector case rescales a
-- 'Uniform' sample with @* 2 - 1@).
randomOfShape :: forall x m. ( MonadRandom m, SingI x ) => m (S x)
randomOfShape = fill <$> (getRandom :: m Int)
  where
    fill :: Int -> S x
    fill seed =
      case (sing :: Sing x) of
        D1Sing -> S1D (randomVector seed Uniform * 2 - 1)
        D2Sing -> S2D (uniformSample seed (-1) 1)
        D3Sing -> S3D (uniformSample seed (-1) 1)
160
-- | Generate a shape from a Storable Vector.
--
-- Returns Nothing if the vector is of the wrong size. Two and three
-- dimensional shapes are built by reshaping the vector with
-- 'NLA.reshape' into rows of length @columns@. A three dimensional shape
-- is read as one matrix of @rows * depth@ rows.
fromStorable :: forall x. SingI x => Vector Double -> Maybe (S x)
fromStorable xs =
  case sing :: Sing x of
    D1Sing -> fmap S1D (H.create xs)
    D2Sing -> fmap S2D (asMatrix xs)
    D3Sing -> fmap S3D (asMatrix xs)
  where
    -- Reshape into a rows x columns matrix. Vectors of the wrong length
    -- are rejected up front, so 'NLA.reshape' never sees a bad size.
    asMatrix :: forall rows columns. (KnownNat rows, KnownNat columns)
             => Vector Double -> Maybe (L rows columns)
    asMatrix v
      | nRows * nCols == V.length v = H.create (NLA.reshape nCols v)
      | otherwise                   = Nothing
      where
        nRows = fromIntegral (natVal (Proxy :: Proxy rows))
        nCols = fromIntegral (natVal (Proxy :: Proxy columns))
183
184
-- | Binary serialisation of a shape as a flat list of its elements.
--
-- Matrices are flattened with 'NLA.flatten' before writing. 'fromStorable'
-- rebuilds the shape when reading, so decoding fails if the stored
-- element count does not fit the expected shape.
instance SingI x => Serialize (S x) where
  put s = putListOf put (NLA.toList (elements s))
    where
      elements :: S n -> Vector Double
      elements (S1D v) = H.extract v
      elements (S2D m) = NLA.flatten (H.extract m)
      elements (S3D m) = NLA.flatten (H.extract m)

  get = do
    xs <- getListOf get
    -- Fail explicitly with a descriptive message, rather than through a
    -- refutable `Just` pattern in do-notation.
    case fromStorable (V.fromList xs) of
      Just shape -> return shape
      Nothing    ->
        fail "Grenade.Core.Shape: stored element count does not match the expected shape"
195
-- | Apply a unary 'Floating' function to the data held in a shape,
-- keeping its constructor. Helper for the numeric instances.
n1 :: ( forall a. Floating a => a -> a ) -> S x -> S x
n1 f s = case s of
  S1D v -> S1D (f v)
  S2D m -> S2D (f m)
  S3D m -> S3D (f m)
201
-- | Combine two shapes of the same type with a binary 'Floating'
-- function. Both arguments share the index @x@, so they always use the
-- same constructor. Helper for the numeric instances.
n2 :: ( forall a. Floating a => a -> a -> a ) -> S x -> S x -> S x
n2 f l r = case (l, r) of
  (S1D a, S1D b) -> S1D (f a b)
  (S2D a, S2D b) -> S2D (f a b)
  (S3D a, S3D b) -> S3D (f a b)
207
-- | A shape whose every element is the given constant. Helper for
-- literals and 'pi' in the numeric instances.
nk :: forall x. (SingI x) => Double -> S x
nk c = fill (sing :: Sing x)
  where
    fill :: Sing x -> S x
    fill D1Sing = S1D (konst c)
    fill D2Sing = S2D (konst c)
    fill D3Sing = S3D (konst c)