aboutsummaryrefslogtreecommitdiff
path: root/lib/Server
diff options
context:
space:
mode:
Diffstat (limited to '')
-rw-r--r--lib/Server.hs63
1 files changed, 31 insertions, 32 deletions
diff --git a/lib/Server.hs b/lib/Server.hs
index 759080c..ef5663a 100644
--- a/lib/Server.hs
+++ b/lib/Server.hs
@@ -12,9 +12,9 @@
-- Implementation of the API. This module is the main point of the program.
module Server (application) where
import Control.Monad (forever, unless, void, when)
-import Control.Monad.Extra (maybeM, unlessM, whenM)
+import Control.Monad.Extra (maybeM, unlessM, whenM, ifM)
import Control.Monad.IO.Class (MonadIO (liftIO))
-import Control.Monad.Logger (logWarnN)
+import Control.Monad.Logger (logWarnN, LoggingT)
import Control.Monad.Reader (forM)
import Control.Monad.Trans (lift)
import qualified Data.Aeson as A
@@ -57,6 +57,7 @@ import Extrapolation (Extrapolator (..),
import System.IO.Unsafe
import Config (ServerConfig)
+import Data.ByteString (ByteString)
application :: GTFS -> Pool SqlBackend -> IO Application
application gtfs dbpool = do
@@ -73,7 +74,7 @@ doMigration pool = runSql pool $
server :: GTFS -> Pool SqlBackend -> Service CompleteAPI
server gtfs@GTFS{..} dbpool = handleDebugAPI
:<|> (handleStations :<|> handleTimetable :<|> handleTrip
- :<|> handleRegister :<|> handleTrainPing :<|> handleWS
+ :<|> handleRegister :<|> handleTrainPing (throwError err401) :<|> handleWS
:<|> handleDebugState :<|> handleDebugTrain :<|> handleDebugRegister
:<|> gtfsRealtimeServer gtfs dbpool)
:<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool)))
@@ -99,22 +100,23 @@ server gtfs@GTFS{..} dbpool = handleDebugAPI
expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
RunningKey token <- runSql dbpool $ insert (Running expires False tripID day Nothing "debug key")
pure token
- handleTrainPing ping = do
- running@Running{..} <- lift $ checkTokenValid dbpool (coerce $ trainPingToken ping)
- let anchor = extrapolateAnchorFromPing @LinearExtrapolator gtfs running ping
- -- TODO: are these always inserted in order?
- runSql dbpool $ do
- insert ping
- last <- selectFirst
- [TrainAnchorTrip ==. runningTrip, TrainAnchorDay ==. runningDay]
- [Desc TrainAnchorWhen]
- -- only insert new estimates if they've actually changed anything
- when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
- $ void $ insert anchor
-
- pure NoContent
+ handleTrainPing onError ping = isTokenValid dbpool (coerce $ trainPingToken ping) >>= \case
+ Nothing -> do
+ onError
+ pure NoContent
+ Just running@Running{..} -> do
+ let anchor = extrapolateAnchorFromPing @LinearExtrapolator gtfs running ping
+ -- TODO: are these always inserted in order?
+ runSql dbpool $ do
+ insert ping
+ last <- selectFirst
+ [TrainAnchorTrip ==. runningTrip, TrainAnchorDay ==. runningDay]
+ [Desc TrainAnchorWhen]
+ -- only insert new estimates if they've actually changed anything
+ when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
+ $ void $ insert anchor
+ pure NoContent
handleWS conn = do
- -- TODO test this!!
liftIO $ WS.forkPingThread conn 30
forever $ do
msg <- liftIO $ WS.receiveData conn
@@ -122,9 +124,10 @@ server gtfs@GTFS{..} dbpool = handleDebugAPI
Left err -> do
logWarnN ("stray websocket message: "+|show msg|+" (could not decode: "+|err|+")")
liftIO $ WS.sendClose conn (C8.pack err)
- Right ping -> do
- lift $ checkTokenValid dbpool (coerce $ trainPingToken ping)
- void $ runSql dbpool $ insert ping
+ Right ping ->
+ -- if invalid token, send a "polite" close request. Note that the client may
+ -- ignore this and continue sending messages, which will continue to be handled.
+ liftIO $ void $ handleTrainPing (WS.sendClose conn ("" :: ByteString)) ping
handleDebugState = do
now <- liftIO getCurrentTime
runSql dbpool $ do
@@ -145,17 +148,13 @@ server gtfs@GTFS{..} dbpool = handleDebugAPI
-- TODO: proper debug logging for expired tokens
-checkTokenValid :: Pool SqlBackend -> Token -> Handler Running
-checkTokenValid dbpool token = do
- trip <- try $ runSql dbpool $ get (coerce token)
- when (runningBlocked trip)
- $ throwError err401
- whenM (hasExpired (runningExpires trip))
- $ throwError err401
- pure trip
- where try m = m >>= \case
- Just a -> pure a
- Nothing -> throwError err404
+isTokenValid :: MonadIO m => Pool SqlBackend -> Token -> m (Maybe Running)
+isTokenValid dbpool token = runSql dbpool $ get (coerce token) >>= \case
+ Just trip | not (runningBlocked trip) -> do
+ ifM (hasExpired (runningExpires trip))
+ (pure Nothing)
+ (pure (Just trip))
+ _ -> pure Nothing
hasExpired :: MonadIO m => UTCTime -> m Bool
hasExpired limit = do