💣 Machine learning which might blow up in your face 💣
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
import Control.Monad
import Control.Monad.Random
import Data.List ( foldl' )

import qualified Data.ByteString as B
import Data.Serialize
#if ! MIN_VERSION_base(4,13,0)
import Data.Semigroup ( (<>) )
#endif
import GHC.TypeLits

import qualified Numeric.LinearAlgebra.Static as SA

import Options.Applicative

import Grenade


-- The definition of our simple feed-forward network.
-- The type-level lists represent the layers and the shapes passed between them.
-- For this demonstration we use relu, tanh and logit non-linear units, which can
-- easily be substituted for one another.
--
-- With around 100000 examples, this should show two clear circles which have been
-- learned by the network.
type FFNet = Network '[ FullyConnected 2 40, Tanh, FullyConnected 40 10, Relu, FullyConnected 10 1, Logit ]
                     '[ 'D1 2, 'D1 40, 'D1 40, 'D1 10, 'D1 10, 'D1 1, 'D1 1]
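-- The shape list has one more entry than the layer list: each shape is the
-- output of the layer before it and the input of the layer after it, so data
-- flows through the network as
--   'D1 2 -> FullyConnected 2 40 -> 'D1 40 -> Tanh -> 'D1 40
--        -> FullyConnected 40 10 -> 'D1 10 -> Relu -> 'D1 10
--        -> FullyConnected 10 1  -> 'D1 1  -> Logit -> 'D1 1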

randomNet :: MonadRandom m => m FFNet
randomNet = randomNetwork

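-- Train the network on n points drawn uniformly from the square [-1, 1]^2.
-- A point is labelled 1 when it lies inside either circle of radius 0.33
-- (centred at (0.33, 0.33) and (-0.33, -0.33)) and 0 otherwise; one step of
-- `train` is folded over each example in turn.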
netTrain :: FFNet -> LearningParameters -> Int -> IO FFNet
netTrain net0 rate n = do
    inps <- replicateM n $ do
      s <- getRandom
      return $ S1D $ SA.randomVector s SA.Uniform * 2 - 1
    let outs = flip map inps $ \(S1D v) ->
                 if v `inCircle` (fromRational 0.33, 0.33) || v `inCircle` (fromRational (-0.33), 0.33)
                   then S1D $ fromRational 1
                   else S1D $ fromRational 0

    let trained = foldl' trainEach net0 (zip inps outs)
    return trained

  where
    inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool
    v `inCircle` (o, r) = SA.norm_2 (v - o) <= r
    trainEach !network (i,o) = train rate network i o

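-- Deserialise a previously saved network from disk, failing if the bytes
-- cannot be decoded as an FFNet.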
netLoad :: FilePath -> IO FFNet
netLoad modelPath = do
  modelData <- B.readFile modelPath
  either fail return $ runGet (get :: Get FFNet) modelData

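-- Print the network's output over a 51 x 21 grid covering the input square,
-- rendered as ASCII art (' ' for outputs near 0, '#' for outputs near 1).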
netScore :: FFNet -> IO ()
netScore network = do
    let testIns = [ [ (x,y) | x <- [0..50] ]
                            | y <- [0..20] ]
        outMat  = fmap (fmap (\(x,y) -> (render . normx) $ runNet network (S1D $ SA.vector [x / 25 - 1, y / 10 - 1]))) testIns
    putStrLn $ unlines outMat

  where
    render n' | n' <= 0.2 = ' '
              | n' <= 0.4 = '.'
              | n' <= 0.6 = '-'
              | n' <= 0.8 = '='
              | otherwise = '#'

    normx :: S ('D1 1) -> Double
    normx (S1D r) = SA.mean r

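-- Command line options: number of training examples, learning parameters, and
-- optional paths to load a network from and save the trained network to.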
data FeedForwardOpts = FeedForwardOpts Int LearningParameters (Maybe FilePath) (Maybe FilePath)

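-- Parser for the options above. Defaults: 100000 examples, learning rate 0.01,
-- momentum 0.9 and an l2 regulariser of 0.0005.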
feedForward' :: Parser FeedForwardOpts
feedForward' =
  FeedForwardOpts <$> option auto (long "examples" <> short 'e' <> value 100000)
                  <*> (LearningParameters
                      <$> option auto (long "train_rate" <> short 'r' <> value 0.01)
                      <*> option auto (long "momentum" <> value 0.9)
                      <*> option auto (long "l2" <> value 0.0005)
                      )
                  <*> optional (strOption (long "load"))
                  <*> optional (strOption (long "save"))

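-- Load an existing network (or start from random weights), train it, print the
-- learned decision surface, and optionally save the result.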
main :: IO ()
main = do
  FeedForwardOpts examples rate load save <- execParser (info (feedForward' <**> helper) idm)
  net0 <- case load of
    Just loadFile -> netLoad loadFile
    Nothing -> randomNet

  net <- netTrain net0 rate examples
  netScore net

  case save of
    Just saveFile -> B.writeFile saveFile $ runPut (put net)
    Nothing -> return ()
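
-- A possible invocation, assuming the compiled executable is named
-- `feedforward` and using `circle.net` as an example output path (both names
-- are illustrative, not fixed by this file):
--
--   ./feedforward --examples 100000 --save circle.net
--   ./feedforward --load circle.net -e 0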