···1414import Data.List ( foldl' )
1515import Data.Maybe ( fromMaybe )
16161717-#if ! MIN_VERSION_base(4,13,0)
1818-import Data.Semigroup ( (<>) )
1919-#endif
20172118import qualified Data.Vector as V
2219import Data.Vector ( Vector )
23202421import qualified Data.Map as M
2525-#if ! MIN_VERSION_base(4,13,0)
2626-import Data.Proxy ( Proxy (..) )
2727-#endif
28222923import qualified Data.ByteString as B
3024import Data.Serialize
31253232-import Data.Singletons.Prelude
3326import GHC.TypeLits
34273528import Numeric.LinearAlgebra.Static ( konst )
···4134import Grenade.Utils.OneHot
42354336import System.IO.Unsafe ( unsafeInterleaveIO )
3737+import Data.Proxy
3838+import Prelude.Singletons
44394540-- The defininition for our natural language recurrent network.
4641-- This network is able to learn and generate simple words in
+2-1
grenade.cabal
···3838library
3939 build-depends:
4040 base >= 4.8 && < 5
4141- , bytestring == 0.10.*
4141+ , bytestring >= 0.10.0
4242 , containers >= 0.5 && < 0.7
4343 , cereal >= 0.5 && < 0.6
4444 , deepseq >= 1.4 && < 1.5
···4848 -- Versions of singletons are *tightly* coupled with the
4949 -- GHC version so its fine to drop version bounds.
5050 , singletons
5151+ , singletons-base
5152 , vector >= 0.11 && < 0.13
52535354 ghc-options:
+1-1
src/Grenade/Core/Network.hs
···3434import Control.Monad.Random ( MonadRandom )
35353636import Data.Singletons
3737-import Data.Singletons.Prelude
3837import Data.Serialize
39384039#if MIN_VERSION_base(4,9,0)
···4443import Grenade.Core.Layer
4544import Grenade.Core.LearningParameters
4645import Grenade.Core.Shape
4646+import Prelude.Singletons
47474848-- | Type of a network.
4949--
+2-1
src/Grenade/Core/Runner.hs
···1313 , runNet
1414 ) where
15151616-import Data.Singletons.Prelude
17161817import Grenade.Core.LearningParameters
1918import Grenade.Core.Network
2019import Grenade.Core.Shape
2020+import Data.Singletons
2121+import Prelude.Singletons
21222223-- | Perform reverse automatic differentiation on the network
2324-- for the current input and expected output.
+16-31
src/Grenade/Core/Shape.hs
···2121module Grenade.Core.Shape (
2222 S (..)
2323 , Shape (..)
2424-#if MIN_VERSION_singletons(2,6,0)
2524 , SShape (..)
2626-#else
2727- , Sing (..)
2828-#endif
2929-3025 , randomOfShape
3126 , fromStorable
3227 ) where
33283429import Control.DeepSeq (NFData (..))
3530import Control.Monad.Random ( MonadRandom, getRandom )
3636-3737-#if MIN_VERSION_base(4,13,0)
3831import Data.Kind (Type)
3939-#endif
4032import Data.Proxy
4133import Data.Serialize
4234import Data.Singletons
4343-import Data.Singletons.TypeLits
4435import Data.Vector.Storable ( Vector )
4536import qualified Data.Vector.Storable as V
4646-4747-#if MIN_VERSION_base(4,11,0)
4848-import GHC.TypeLits hiding (natVal)
4949-#else
5037import GHC.TypeLits
5151-#endif
5252-5338import qualified Numeric.LinearAlgebra.Static as H
5439import Numeric.LinearAlgebra.Static
5540import qualified Numeric.LinearAlgebra as NLA
···9984type instance Sing = SShape
1008510186data SShape :: Shape -> Type where
102102- D1Sing :: Sing a -> SShape ('D1 a)
103103- D2Sing :: Sing a -> Sing b -> SShape ('D2 a b)
104104- D3Sing :: KnownNat (a * c) => Sing a -> Sing b -> Sing c -> SShape ('D3 a b c)
8787+ D1Sing :: KnownNat a => SShape ('D1 a)
8888+ D2Sing :: (KnownNat a, KnownNat b) => SShape ('D2 a b)
8989+ D3Sing :: (KnownNat (a * c), KnownNat a, KnownNat b, KnownNat c) => SShape ('D3 a b c)
10590#else
10691data instance Sing (n :: Shape) where
10792 D1Sing :: Sing a -> Sing ('D1 a)
···11095#endif
1119611297instance KnownNat a => SingI ('D1 a) where
113113- sing = D1Sing sing
9898+ sing = D1Sing
11499instance (KnownNat a, KnownNat b) => SingI ('D2 a b) where
115115- sing = D2Sing sing sing
100100+ sing = D2Sing
116101instance (KnownNat a, KnownNat b, KnownNat c, KnownNat (a * c)) => SingI ('D3 a b c) where
117117- sing = D3Sing sing sing sing
102102+ sing = D3Sing
118103119104instance SingI x => Num (S x) where
120105 (+) = n2 (+)
···163148randomOfShape = do
164149 seed :: Int <- getRandom
165150 return $ case (sing :: Sing x) of
166166- D1Sing SNat ->
151151+ D1Sing ->
167152 S1D (randomVector seed Uniform * 2 - 1)
168153169169- D2Sing SNat SNat ->
154154+ D2Sing ->
170155 S2D (uniformSample seed (-1) 1)
171156172172- D3Sing SNat SNat SNat ->
157157+ D3Sing ->
173158 S3D (uniformSample seed (-1) 1)
174159175160-- | Generate a shape from a Storable Vector.
···177162-- Returns Nothing if the vector is of the wrong size.
178163fromStorable :: forall x. SingI x => Vector Double -> Maybe (S x)
179164fromStorable xs = case sing :: Sing x of
180180- D1Sing SNat ->
165165+ D1Sing ->
181166 S1D <$> H.create xs
182167183183- D2Sing SNat SNat ->
168168+ D2Sing ->
184169 S2D <$> mkL xs
185170186186- D3Sing SNat SNat SNat ->
171171+ D3Sing ->
187172 S3D <$> mkL xs
188173 where
189174 mkL :: forall rows columns. (KnownNat rows, KnownNat columns)
···220205n2 f (S3D x) (S3D y) = S3D (f x y)
221206222207-- Helper function for creating the number instances
223223-nk :: forall x. SingI x => Double -> S x
208208+nk :: forall x. (SingI x) => Double -> S x
224209nk x = case (sing :: Sing x) of
225225- D1Sing SNat ->
210210+ D1Sing ->
226211 S1D (konst x)
227212228228- D2Sing SNat SNat ->
213213+ D2Sing ->
229214 S2D (konst x)
230215231231- D3Sing SNat SNat SNat ->
216216+ D3Sing ->
232217 S3D (konst x)
···2121-- import Data.List ( foldl1' )
2222import Data.Proxy
2323import Data.Serialize
2424-import Data.Singletons.TypeLits
2525-2626-#if MIN_VERSION_base(4,9,0)
2724import Data.Kind (Type)
2828-#endif
29253026import qualified Numeric.LinearAlgebra as LA
3127import Numeric.LinearAlgebra.Static
···3329import Grenade.Core
3430import Grenade.Recurrent.Core
3531import Grenade.Layers.Internal.Update
3232+import GHC.TypeLits
363337343835-- | Long Short Term Memory Recurrent unit
+1-1
src/Grenade/Utils/OneHot.hs
···2222import qualified Data.Map as M
23232424import Data.Proxy
2525-import Data.Singletons.TypeLits
26252726import Data.Vector ( Vector )
2827import qualified Data.Vector as V
···3130import Numeric.LinearAlgebra ( maxIndex )
3231import Numeric.LinearAlgebra.Devel
3332import Numeric.LinearAlgebra.Static
3333+import GHC.TypeLits
34343535import Grenade.Core.Shape
3636
+7-1
stack.yaml
···2020#
2121# resolver: ./custom-snapshot.yaml
2222# resolver: https://example.com/snapshots/2018-01-01.yaml
2323-resolver: lts-18.28
2323+resolver: lts-20.18
24242525# User packages to be built.
2626# Various formats can be used as shown in the example below.
···7070#
7171# Allow a newer minor version of GHC than the snapshot specifies
7272# compiler-check: newer-minor
7373+7474+nix:
7575+ enable: true
7676+ packages:
7777+ - blas
7878+ - lapack