diff options
author | stuebinm | 2022-09-10 22:40:38 +0200 |
---|---|---|
committer | stuebinm | 2022-09-10 22:40:38 +0200 |
commit | 1d8c2f078b4920c8813c48618bf443a7c8c767f3 (patch) | |
tree | eca4c792bec740c9f1e9a0285ed6aabcbe1cd429 /lib | |
parent | 676dfae3263799806da1a3cf5d4162b434b84259 (diff) |
use websockets for the on-board-unit
Diffstat (limited to '')
-rw-r--r-- | lib/Server.hs | 63 |
1 files changed, 31 insertions, 32 deletions
diff --git a/lib/Server.hs b/lib/Server.hs index 759080c..ef5663a 100644 --- a/lib/Server.hs +++ b/lib/Server.hs @@ -12,9 +12,9 @@ -- Implementation of the API. This module is the main point of the program. module Server (application) where import Control.Monad (forever, unless, void, when) -import Control.Monad.Extra (maybeM, unlessM, whenM) +import Control.Monad.Extra (maybeM, unlessM, whenM, ifM) import Control.Monad.IO.Class (MonadIO (liftIO)) -import Control.Monad.Logger (logWarnN) +import Control.Monad.Logger (logWarnN, LoggingT) import Control.Monad.Reader (forM) import Control.Monad.Trans (lift) import qualified Data.Aeson as A @@ -57,6 +57,7 @@ import Extrapolation (Extrapolator (..), import System.IO.Unsafe import Config (ServerConfig) +import Data.ByteString (ByteString) application :: GTFS -> Pool SqlBackend -> IO Application application gtfs dbpool = do @@ -73,7 +74,7 @@ doMigration pool = runSql pool $ server :: GTFS -> Pool SqlBackend -> Service CompleteAPI server gtfs@GTFS{..} dbpool = handleDebugAPI :<|> (handleStations :<|> handleTimetable :<|> handleTrip - :<|> handleRegister :<|> handleTrainPing :<|> handleWS + :<|> handleRegister :<|> handleTrainPing (throwError err401) :<|> handleWS :<|> handleDebugState :<|> handleDebugTrain :<|> handleDebugRegister :<|> gtfsRealtimeServer gtfs dbpool) :<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool))) @@ -99,22 +100,23 @@ server gtfs@GTFS{..} dbpool = handleDebugAPI expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod RunningKey token <- runSql dbpool $ insert (Running expires False tripID day Nothing "debug key") pure token - handleTrainPing ping = do - running@Running{..} <- lift $ checkTokenValid dbpool (coerce $ trainPingToken ping) - let anchor = extrapolateAnchorFromPing @LinearExtrapolator gtfs running ping - -- TODO: are these always inserted in order? - runSql dbpool $ do - insert ping - last <- selectFirst - [TrainAnchorTrip ==. runningTrip, TrainAnchorDay ==. runningDay] - [Desc TrainAnchorWhen] - -- only insert new estimates if they've actually changed anything - when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor)) - $ void $ insert anchor - - pure NoContent + handleTrainPing onError ping = isTokenValid dbpool (coerce $ trainPingToken ping) >>= \case + Nothing -> do + onError + pure NoContent + Just running@Running{..} -> do + let anchor = extrapolateAnchorFromPing @LinearExtrapolator gtfs running ping + -- TODO: are these always inserted in order? + runSql dbpool $ do + insert ping + last <- selectFirst + [TrainAnchorTrip ==. runningTrip, TrainAnchorDay ==. runningDay] + [Desc TrainAnchorWhen] + -- only insert new estimates if they've actually changed anything + when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor)) + $ void $ insert anchor + pure NoContent handleWS conn = do - -- TODO test this!! liftIO $ WS.forkPingThread conn 30 forever $ do msg <- liftIO $ WS.receiveData conn @@ -122,9 +124,10 @@ server gtfs@GTFS{..} dbpool = handleDebugAPI Left err -> do logWarnN ("stray websocket message: "+|show msg|+" (could not decode: "+|err|+")") liftIO $ WS.sendClose conn (C8.pack err) - Right ping -> do - lift $ checkTokenValid dbpool (coerce $ trainPingToken ping) - void $ runSql dbpool $ insert ping + Right ping -> + -- if invalid token, send a "polite" close request. Note that the client may + -- ignore this and continue sending messages, which will continue to be handled. + liftIO $ void $ handleTrainPing (WS.sendClose conn ("" :: ByteString)) ping handleDebugState = do now <- liftIO getCurrentTime runSql dbpool $ do @@ -145,17 +148,13 @@ server gtfs@GTFS{..} dbpool = handleDebugAPI -- TODO: proper debug logging for expired tokens -checkTokenValid :: Pool SqlBackend -> Token -> Handler Running -checkTokenValid dbpool token = do - trip <- try $ runSql dbpool $ get (coerce token) - when (runningBlocked trip) - $ throwError err401 - whenM (hasExpired (runningExpires trip)) - $ throwError err401 - pure trip - where try m = m >>= \case - Just a -> pure a - Nothing -> throwError err404 +isTokenValid :: MonadIO m => Pool SqlBackend -> Token -> m (Maybe Running) +isTokenValid dbpool token = runSql dbpool $ get (coerce token) >>= \case + Just trip | not (runningBlocked trip) -> do + ifM (hasExpired (runningExpires trip)) + (pure Nothing) + (pure (Just trip)) + _ -> pure Nothing hasExpired :: MonadIO m => UTCTime -> m Bool hasExpired limit = do |