💣 Machine learning which might blow up in your face 💣
at master 103 lines 3.7 kB view raw
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeOperators #-}
{-# LANGUAGE TupleSections #-}
{-# LANGUAGE TypeFamilies #-}
import Control.Monad
import Control.Monad.Random
import Data.List ( foldl' )

import qualified Data.ByteString as B
import Data.Serialize
#if ! MIN_VERSION_base(4,13,0)
import Data.Semigroup ( (<>) )
#endif
import GHC.TypeLits

import qualified Numeric.LinearAlgebra.Static as SA

import Options.Applicative

import Grenade


-- The definition of our simple feed forward network.
-- The first type level list names the layers; the second lists the shapes
-- passed between them (one more shape than there are layers).
-- One can see that for this demonstration we are using relu, tanh and logit
-- non-linear units, which can easily be substituted for each other.
--
-- With around 100000 examples, this should show two clear circles which have
-- been learned by the network.
type FFNet = Network '[ FullyConnected 2 40, Tanh, FullyConnected 40 10, Relu, FullyConnected 10 1, Logit ]
                     '[ 'D1 2, 'D1 40, 'D1 40, 'D1 10, 'D1 10, 'D1 1, 'D1 1]

-- | A freshly initialised network, built with Grenade's 'randomNetwork'.
randomNet :: MonadRandom m => m FFNet
randomNet = randomNetwork

-- | Train a network on @n@ randomly generated, labelled 2-D points.
--
-- Each input comes from 'SA.randomVector' with a fresh random seed and is
-- mapped through @* 2 - 1@ (presumably taking uniform samples in [0, 1) to
-- [-1, 1) -- confirm against the hmatrix documentation).  A point is labelled
-- 1 when it lies within distance 0.33 of either circle centre and 0
-- otherwise.  The centres are written @fromRational 0.33@ and
-- @fromRational (-0.33)@, presumably the constant vectors (0.33, 0.33) and
-- (-0.33, -0.33).
--
-- The examples are then fed to 'train' one at a time, in generation order,
-- starting from @net0@ and using the learning parameters @rate@.
netTrain :: FFNet -> LearningParameters -> Int -> IO FFNet
netTrain net0 rate n = do
    inps <- replicateM n $ do
      s <- getRandom
      return $ S1D $ SA.randomVector s SA.Uniform * 2 - 1
    let outs = flip map inps $ \(S1D v) ->
                 if v `inCircle` (fromRational 0.33, 0.33) || v `inCircle` (fromRational (-0.33), 0.33)
                   then S1D $ fromRational 1
                   else S1D $ fromRational 0

    return $ foldl' trainEach net0 (zip inps outs)

  where
    -- True when @v@ is no further than @r@ (Euclidean norm) from centre @o@.
    inCircle :: KnownNat n => SA.R n -> (SA.R n, Double) -> Bool
    v `inCircle` (o, r) = SA.norm_2 (v - o) <= r

    -- One training step; the bang keeps the accumulated network evaluated so
    -- the strict fold does not build a chain of thunks.
    trainEach !network (i,o) = train rate network i o

-- | Read a serialised network from @modelPath@.
--
-- Fails in 'IO' (via 'fail') when the bytes cannot be decoded; the message
-- names the offending file as well as the decoder's own error.
netLoad :: FilePath -> IO FFNet
netLoad modelPath = do
    modelData <- B.readFile modelPath
    case runGet (get :: Get FFNet) modelData of
      Left err  -> fail $ "Could not decode network from " ++ modelPath ++ ": " ++ err
      Right net -> return net

-- | Print a character rendering of the network's output over a grid.
--
-- The grid has 51 columns (@x@ in 0..50) and 21 rows (@y@ in 0..20); both
-- coordinates are rescaled to the range [-1, 1] before being run through the
-- network.  Each output is drawn as one of @' '@, @'.'@, @'-'@, @'='@ or
-- @'#'@ according to which fifth of [0, 1] it falls into.
netScore :: FFNet -> IO ()
netScore network = do
    let testIns = [ [ (x,y)  | x <- [0..50] ]
                             | y <- [0..20] ]
        outMat  = fmap (fmap (\(x,y) -> (render . normx) $ runNet network (S1D $ SA.vector [x / 25 - 1,y / 10 - 1]))) testIns
    putStrLn $ unlines outMat

  where
    -- Bucket a network output into a display character.
    render n'  | n' <= 0.2  = ' '
               | n' <= 0.4  = '.'
               | n' <= 0.6  = '-'
               | n' <= 0.8  = '='
               | otherwise = '#'

    -- Collapse the one-element output vector to a plain 'Double'.
    normx :: S ('D1 1) -> Double
    normx (S1D r) = SA.mean r

-- | Command line options, in order: number of training examples, learning
-- parameters, optional file to load a network from, optional file to save
-- the trained network to.
data FeedForwardOpts = FeedForwardOpts Int LearningParameters (Maybe FilePath) (Maybe FilePath)

-- | Command line parser for 'FeedForwardOpts'.
--
-- Every numeric option has a default, so the program can run with no
-- arguments at all.
feedForward' :: Parser FeedForwardOpts
feedForward' =
  FeedForwardOpts <$> option auto (long "examples" <> short 'e' <> value 100000
                                   <> help "Number of random training examples to generate")
                  <*> (LearningParameters
                      <$> option auto (long "train_rate" <> short 'r' <> value 0.01
                                       <> help "Training rate")
                      <*> option auto (long "momentum" <> value 0.9
                                       <> help "Momentum")
                      <*> option auto (long "l2" <> value 0.0005
                                       <> help "L2 regulariser")
                      )
                  <*> optional (strOption (long "load" <> help "Load an existing network from this file instead of starting from a random one"))
                  <*> optional (strOption (long "save" <> help "Write the trained network to this file"))

-- | Parse the options, obtain a starting network (loaded or random), train
-- it, print its rendering and optionally save it.
main :: IO ()
main = do
    FeedForwardOpts examples rate load save <- execParser (info (feedForward' <**> helper) idm)
    net0 <- maybe randomNet netLoad load

    net <- netTrain net0 rate examples
    netScore net

    -- Saving happens only when --save was given.
    forM_ save $ \saveFile -> B.writeFile saveFile $ runPut (put net)