From a4045a5b0a898042cd78eba9b22550c965a1bbd9 Mon Sep 17 00:00:00 2001 From: stuebinm Date: Sat, 27 Aug 2022 01:45:12 +0200 Subject: controlroom: lots of pretty little knobs (also some database schema changes, for good measure) --- lib/Server.hs | 78 ++++++++++++++++++++++++++++------------------------------- 1 file changed, 37 insertions(+), 41 deletions(-) (limited to 'lib/Server.hs') diff --git a/lib/Server.hs b/lib/Server.hs index f7ee81b..75617bd 100644 --- a/lib/Server.hs +++ b/lib/Server.hs @@ -11,8 +11,8 @@ -- Implementation of the API. This module is the main point of the program. module Server (application) where -import Control.Monad (forever, void, when) -import Control.Monad.Extra (maybeM, whenM) +import Control.Monad (forever, unless, void, when) +import Control.Monad.Extra (maybeM, unlessM, whenM) import Control.Monad.IO.Class (MonadIO (liftIO)) import Control.Monad.Logger (logWarnN) import Control.Monad.Reader (forM) @@ -35,8 +35,9 @@ import Database.Persist import Database.Persist.Postgresql (SqlBackend, runMigration) import Fmt ((+|), (|+)) import qualified Network.WebSockets as WS -import Servant (Application, err401, err404, - serve, throwError) +import Servant (Application, + ServerError (errBody), err401, + err404, serve, throwError) import Servant.API (NoContent (..), (:<|>) (..)) import Servant.Server (Handler, hoistServer) import Servant.Swagger (toSwagger) @@ -46,7 +47,8 @@ import GTFS import Persist import Server.ControlRoom import Server.GTFS_RT (gtfsRealtimeServer) -import Server.Util (Service, ServiceM, runService) +import Server.Util (Service, ServiceM, runService, + sendErrorMsg) import Yesod (toWaiAppPlain) import System.IO.Unsafe @@ -64,11 +66,12 @@ doMigration pool = runSql pool $ runMigration migrateAll server :: GTFS -> Pool SqlBackend -> Service CompleteAPI -server gtfs@GTFS{..} dbpool = handleDebugAPI :<|> (handleStations :<|> handleTimetable :<|> handleTrip - :<|> handleRegister :<|> handleTripPing :<|> handleWS :<|> handleDebugState :<|> - gtfsRealtimeServer gtfs dbpool - :<|> adminServer gtfs dbpool) - :<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom "http://localhost:4000/cr" gtfs dbpool))) +server gtfs@GTFS{..} dbpool = handleDebugAPI + :<|> (handleStations :<|> handleTimetable :<|> handleTrip + :<|> handleRegister :<|> handleTripPing :<|> handleWS + :<|> handleDebugState :<|> handleDebugTrain :<|> handleDebugRegister + :<|> gtfsRealtimeServer gtfs dbpool) + :<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool))) where handleStations = pure stations handleTimetable station maybeDay = do -- TODO: resolve "overlay" trips (perhaps just additional CalendarDates?) @@ -80,13 +83,19 @@ server gtfs@GTFS{..} dbpool = handleDebugAPI :<|> (handleStations :<|> handleTim handleTrip trip = case M.lookup trip trips of Just res -> pure res Nothing -> throwError err404 - handleRegister tripID = do - -- TODO registration may carry extra information! + handleRegister tripID RegisterJson{..} = do + today <- liftIO getCurrentTime <&> utctDay + when (not $ runsOnDay gtfs tripID today) + $ sendErrorMsg "this trip does not run today." expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod - RunningTripKey token <- runSql dbpool $ insert (RunningTrip expires False tripID Nothing) + RunningKey token <- runSql dbpool $ insert (Running expires False tripID today Nothing registerAgent) + pure token + handleDebugRegister tripID day = do + expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod + RunningKey token <- runSql dbpool $ insert (Running expires False tripID day Nothing "debug key") pure token handleTripPing ping = do - lift $ checkTokenValid dbpool (coerce $ tripPingToken ping) + lift $ checkTokenValid dbpool (coerce $ trainPingToken ping) -- TODO: are these always inserted in order? runSql dbpool $ insert ping pure NoContent @@ -100,47 +109,34 @@ server gtfs@GTFS{..} dbpool = handleDebugAPI :<|> (handleStations :<|> handleTim logWarnN ("stray websocket message: "+|show msg|+" (could not decode: "+|err|+")") liftIO $ WS.sendClose conn (C8.pack err) Right ping -> do - lift $ checkTokenValid dbpool (coerce $ tripPingToken ping) + lift $ checkTokenValid dbpool (coerce $ trainPingToken ping) void $ runSql dbpool $ insert ping handleDebugState = do now <- liftIO getCurrentTime runSql dbpool $ do - running <- selectList [RunningTripBlocked ==. False, RunningTripExpires >=. now] [] - pairs <- forM running $ \(Entity token@(RunningTripKey uuid) _) -> do - entities <- selectList [TripPingToken ==. token] [] + running <- selectList [RunningBlocked ==. False, RunningExpires >=. now] [] + pairs <- forM running $ \(Entity token@(RunningKey uuid) _) -> do + entities <- selectList [TrainPingToken ==. token] [] pure (uuid, fmap entityVal entities) pure (M.fromList pairs) + handleDebugTrain tripId day = do + unless (runsOnDay gtfs tripId day) + $ sendErrorMsg ("this trip does not run on "+|day|+".") + runSql dbpool $ do + tokens <- selectList [RunningTrip ==. tripId, RunningDay ==. day] [] + pings <- forM tokens $ \(Entity token _) -> do + selectList [TrainPingToken ==. token] [] <&> fmap entityVal + pure (concat pings) handleDebugAPI = pure $ toSwagger (Proxy @API) -adminServer :: GTFS -> Pool SqlBackend -> Service AdminAPI -adminServer gtfs dbpool = - addAnnounce :<|> delAnnounce :<|> modTripDate Added Cancelled - :<|> modTripDate Cancelled Added :<|> extraTrip - where addAnnounce ann@Announcement{..} = runSql dbpool $ do - AnnouncementKey uuid <- insert ann - pure uuid - delAnnounce uuid = runSql dbpool $ do - delete (AnnouncementKey uuid) - pure NoContent - modTripDate one two tripId day = runSql dbpool $ do - getBy (TripAndDay tripId day) >>= \case - Just (Entity key (ScheduleAmendment _ _ status)) -> do - when (status == two) $ delete key - pure NoContent - Nothing -> do - insert (ScheduleAmendment tripId day one) - pure NoContent - extraTrip = error "unimplemented!" - - -- TODO: proper debug logging for expired tokens checkTokenValid :: Pool SqlBackend -> Token -> Handler () checkTokenValid dbpool token = do trip <- try $ runSql dbpool $ get (coerce token) - when (runningTripBlocked trip) + when (runningBlocked trip) $ throwError err401 - whenM (hasExpired (runningTripExpires trip)) + whenM (hasExpired (runningExpires trip)) $ throwError err401 where try m = m >>= \case Just a -> pure a -- cgit v1.2.3