💣 Machine learning which might blow up in your face 💣
at master 218 lines 6.2 kB view raw
{-# LANGUAGE CPP                  #-}
{-# LANGUAGE DataKinds            #-}
{-# LANGUAGE FlexibleContexts     #-}
{-# LANGUAGE GADTs                #-}
{-# LANGUAGE KindSignatures       #-}
{-# LANGUAGE RankNTypes           #-}
{-# LANGUAGE ScopedTypeVariables  #-}
{-# LANGUAGE StandaloneDeriving   #-}
{-# LANGUAGE TypeFamilies         #-}
{-# LANGUAGE TypeOperators        #-}
{-# LANGUAGE UndecidableInstances #-}
{-|
Module      : Grenade.Core.Shape
Description : Dependently typed shapes of data which are passed between layers of a network
Copyright   : (c) Huw Campbell, 2016-2017
License     : BSD2
Stability   : experimental

Type-level shapes ('Shape') and the concrete, shape-indexed data ('S')
that is passed between the layers of a network. Because the dimensions
of a shape are type-level naturals, data of different shapes has
different types.
-}
module Grenade.Core.Shape (
    S (..)
  , Shape (..)
  , Sing (..)
  , SShape (..)
  , randomOfShape
  , fromStorable
  ) where

import           Control.DeepSeq              (NFData (..))
import           Control.Monad.Random         (MonadRandom, getRandom)
import           Data.Kind                    (Type)
import           Data.Proxy
import           Data.Serialize
import           Data.Singletons
import           Data.Vector.Storable         (Vector)
import qualified Data.Vector.Storable         as V
import           GHC.TypeLits
import qualified Numeric.LinearAlgebra        as NLA
import           Numeric.LinearAlgebra.Static
import qualified Numeric.LinearAlgebra.Static as H

-- | The shapes currently supported: one, two, and three dimensional
-- vectors and matrices.
--
-- These are only used with DataKinds, as the kind 'Shape', with the
-- promoted types @'D1@, @'D2@ and @'D3@.
data Shape
  = D1 Nat
    -- ^ One dimensional vector. Length.
  | D2 Nat Nat
    -- ^ Two dimensional matrix. Rows, Columns.
  | D3 Nat Nat Nat
    -- ^ Three dimensional matrix. Rows, Columns, Depth (channels).

-- | Concrete data structures for a 'Shape'.
--
-- All shapes are held in contiguous memory.
-- 3D data is held in a single matrix (usually row oriented) whose height
-- is @rows * depth@ and whose width is @columns@.
data S (n :: Shape) where
  -- One dimensional data: a vector of length len.
  S1D :: ( KnownNat len )
      => R len
      -> S ('D1 len)

  -- Two dimensional data: a matrix with the given rows and columns.
  S2D :: ( KnownNat rows, KnownNat columns )
      => L rows columns
      -> S ('D2 rows columns)

  -- Three dimensional data: held as one matrix of height rows * depth
  -- and width columns.
  S3D :: ( KnownNat rows
         , KnownNat columns
         , KnownNat depth
         , KnownNat (rows * depth))
      => L (rows * depth) columns
      -> S ('D3 rows columns depth)

deriving instance Show (S n)

-- Singleton instances.
--
-- These could probably be derived with template haskell, but this seems
-- clear and makes adding the KnownNat constraints simple.
-- We can also keep our code TH free, which is great.
--
-- Every singleton constructor is nullary and carries the KnownNat evidence
-- for its dimensions, so matching on one brings those constraints into
-- scope. Both CPP branches must declare constructors of the same form,
-- because the SingI instances and the pattern matches below are shared.
#if MIN_VERSION_singletons(2,6,0)
-- In singletons 2.6 Sing switched from a data family to a type family.
type instance Sing = SShape

data SShape :: Shape -> Type where
  D1Sing :: KnownNat a => SShape ('D1 a)
  D2Sing :: (KnownNat a, KnownNat b) => SShape ('D2 a b)
  D3Sing :: (KnownNat (a * c), KnownNat a, KnownNat b, KnownNat c) => SShape ('D3 a b c)
#else
-- Before singletons 2.6 Sing is a data family, so the constructors live in
-- a data instance. They mirror SShape exactly (nullary, with KnownNat
-- evidence) so that the shared code below typechecks on this branch too.
data instance Sing (n :: Shape) where
  D1Sing :: KnownNat a => Sing ('D1 a)
  D2Sing :: (KnownNat a, KnownNat b) => Sing ('D2 a b)
  D3Sing :: (KnownNat (a * c), KnownNat a, KnownNat b, KnownNat c) => Sing ('D3 a b c)

-- Provide the exported SShape name on this branch as well.
type SShape = (Sing :: Shape -> Type)
#endif

instance KnownNat a => SingI ('D1 a) where
  sing = D1Sing
instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where
  sing = D2Sing
instance (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => SingI ('D3 a b c) where
  sing = D3Sing

-- Numeric instances.
--
-- Unary and binary operations are lifted onto the underlying hmatrix value
-- with n1 and n2. Literals and constants are built with konst via nk, which
-- needs to inspect the shape singleton, hence the SingI x constraint.
instance SingI x => Num (S x) where
  (+) = n2 (+)
  (-) = n2 (-)
  (*) = n2 (*)
  abs = n1 abs
  signum = n1 signum
  fromInteger x = nk (fromInteger x)

instance SingI x => Fractional (S x) where
  (/) = n2 (/)
  recip = n1 recip
  fromRational x = nk (fromRational x)

instance SingI x => Floating (S x) where
  pi = nk pi
  exp = n1 exp
  log = n1 log
  sqrt = n1 sqrt
  (**) = n2 (**)
  logBase = n2 logBase
  sin = n1 sin
  cos = n1 cos
  tan = n1 tan
  asin = n1 asin
  acos = n1 acos
  atan = n1 atan
  sinh = n1 sinh
  cosh = n1 cosh
  tanh = n1 tanh
  asinh = n1 asinh
  acosh = n1 acosh
  atanh = n1 atanh

--
-- I haven't made shapes strict, as sometimes they're not needed
-- (the last input gradient back for instance)
--
instance NFData (S x) where
  rnf (S1D x) = rnf x
  rnf (S2D x) = rnf x
  rnf (S3D x) = rnf x

-- | Generate random data of the desired shape.
--
-- A single Int seed is drawn from the 'MonadRandom' source and used to
-- fill the whole shape with values between -1 and 1.
randomOfShape :: forall x m. ( MonadRandom m, SingI x ) => m (S x)
randomOfShape = do
  seed :: Int <- getRandom
  return $ case (sing :: Sing x) of
    D1Sing ->
      -- Rescale a Uniform sample to the range [-1, 1].
      S1D (randomVector seed Uniform * 2 - 1)

    D2Sing ->
      S2D (uniformSample seed (-1) 1)

    D3Sing ->
      S3D (uniformSample seed (-1) 1)

-- | Generate a shape from a Storable Vector.
--
-- Matrices are filled by reshaping the vector into rows of @columns@
-- elements; for 3D shapes the matrix has @rows * depth@ rows.
--
-- Returns Nothing if the vector is of the wrong size.
fromStorable :: forall x. SingI x => Vector Double -> Maybe (S x)
fromStorable xs = case sing :: Sing x of
    D1Sing ->
      S1D <$> H.create xs

    D2Sing ->
      S2D <$> mkL xs

    D3Sing ->
      S3D <$> mkL xs
  where
    -- Check the element count against the static dimensions before
    -- reshaping, so a mismatch yields Nothing.
    mkL :: forall rows columns. (KnownNat rows, KnownNat columns)
        => Vector Double -> Maybe (L rows columns)
    mkL v =
      let rows    = fromIntegral $ natVal (Proxy :: Proxy rows)
          columns = fromIntegral $ natVal (Proxy :: Proxy columns)
      in  if rows * columns == V.length v
            then H.create $ NLA.reshape columns v
            else Nothing

-- | Serialised as a list of all elements. Matrices are flattened with
-- 'NLA.flatten', the inverse of the 'NLA.reshape' used by 'fromStorable'.
instance SingI x => Serialize (S x) where
  put (S1D x) = putListOf put . NLA.toList . H.extract $ x
  put (S2D x) = putListOf put . NLA.toList . NLA.flatten . H.extract $ x
  put (S3D x) = putListOf put . NLA.toList . NLA.flatten . H.extract $ x

  -- Decoding fails in Get, with an explicit message, when the number of
  -- decoded elements does not match the expected shape.
  get = do
    elems <- getListOf get
    maybe
      (fail "Grenade.Core.Shape: decoded element count does not match the shape")
      return
      (fromStorable (V.fromList elems))

-- Helper function for creating the number instances: lift a unary
-- numeric function onto the data held by any shape.
n1 :: ( forall a. Floating a => a -> a ) -> S x -> S x
n1 f (S1D x) = S1D (f x)
n1 f (S2D x) = S2D (f x)
n1 f (S3D x) = S3D (f x)

-- Helper function for creating the number instances: lift a binary
-- numeric function onto two values of the same shape.
n2 :: ( forall a. Floating a => a -> a -> a ) -> S x -> S x -> S x
n2 f (S1D x) (S1D y) = S1D (f x y)
n2 f (S2D x) (S2D y) = S2D (f x y)
n2 f (S3D x) (S3D y) = S3D (f x y)

-- Helper function for creating the number instances: build a shape of
-- the requested type with konst from a single Double.
nk :: forall x. (SingI x) => Double -> S x
nk x = case (sing :: Sing x) of
  D1Sing ->
    S1D (konst x)

  D2Sing ->
    S2D (konst x)

  D3Sing ->
    S3D (konst x)