diff options
Diffstat (limited to 'lib/Server.hs')
-rw-r--r-- | lib/Server.hs | 179 |
1 files changed, 112 insertions, 67 deletions
diff --git a/lib/Server.hs b/lib/Server.hs index 3922a7b..73c55cb 100644 --- a/lib/Server.hs +++ b/lib/Server.hs @@ -16,12 +16,14 @@ import Control.Concurrent.STM (TQueue, TVar, atomically, import Control.Monad (forever, unless, void, when) import Control.Monad.Catch (handle) -import Control.Monad.Extra (ifM, maybeM, unlessM, - whenJust, whenM) +import Control.Monad.Extra (ifM, mapMaybeM, maybeM, + unlessM, whenJust, whenM) import Control.Monad.IO.Class (MonadIO (liftIO)) -import Control.Monad.Logger (LoggingT, NoLoggingT, +import Control.Monad.Logger (LoggingT, MonadLogger, + NoLoggingT, logInfoN, logWarnN) -import Control.Monad.Reader (ReaderT, forM) +import Control.Monad.Reader (MonadReader, ReaderT, + forM) import Control.Monad.Trans (lift) import Data.Aeson ((.=)) import qualified Data.Aeson as A @@ -33,7 +35,7 @@ import Data.Pool (Pool) import Data.Proxy (Proxy (Proxy)) import Data.Swagger (toSchema) import Data.Text (Text) -import Data.Text.Encoding (decodeUtf8) +import Data.Text.Encoding (decodeASCII, decodeUtf8) import Data.Time (NominalDiffTime, UTCTime (utctDay), addUTCTime, diffUTCTime, @@ -48,7 +50,8 @@ import Fmt ((+|), (|+)) import qualified Network.WebSockets as WS import Servant (Application, ServerError (errBody), - err401, err404, serve, + err400, err401, err404, + serve, serveDirectoryFileServer, throwError) import Servant.API (NoContent (..), @@ -62,24 +65,33 @@ import Persist import Server.ControlRoom import Server.GTFS_RT (gtfsRealtimeServer) import Server.Util (Service, ServiceM, - runService, sendErrorMsg) + runService, sendErrorMsg, + utcToSeconds) import Yesod (toWaiAppPlain) -import Extrapolation (Extrapolator (..), - LinearExtrapolator (..)) -import System.IO.Unsafe - import Conduit (ResourceT) -import Config (ServerConfig (serverConfigAssets, serverConfigZoneinfoPath)) +import Config (LoggingConfig, + ServerConfig (..)) +import Control.Exception (throw) import Data.ByteString (ByteString) import Data.ByteString.Lazy (toStrict) +import Data.Foldable (minimumBy) +import Data.Function (on, (&)) +import Data.Maybe (fromMaybe) import qualified Data.Text as T import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile) import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries) import Data.UUID (UUID) +import qualified Data.UUID as UUID +import Extrapolation (Extrapolator (..), + LinearExtrapolator (..), + euclid) +import GTFS (Seconds (unSeconds), + seconds2Double) import Prometheus import Prometheus.Metric.GHC import System.FilePath ((</>)) +import System.IO.Unsafe application :: GTFS.GTFS -> Pool SqlBackend -> ServerConfig -> IO Application application gtfs dbpool settings = do @@ -94,65 +106,84 @@ application gtfs dbpool settings = do (serverConfigZoneinfoPath settings </> T.unpack tzname) subscribers <- newTVarIO mempty - pure $ serve (Proxy @CompleteAPI) $ hoistServer (Proxy @CompleteAPI) runService + pure $ serve (Proxy @CompleteAPI) + $ hoistServer (Proxy @CompleteAPI) (runService (serverConfigLogging settings)) $ server gtfs getTzseries metrics subscribers dbpool settings -- databaseMigration :: ConnectionString -> IO () -doMigration pool = runSql pool $ runMigration $ do +doMigration pool = runSqlWithoutLog pool $ runMigration $ do migrateEnableExtension "uuid-ossp" migrateAll server :: GTFS.GTFS -> (Text -> IO TimeZoneSeries) -> Metrics -> TVar (M.Map UUID [TQueue (Maybe TrainPing)]) -> Pool SqlBackend -> ServerConfig -> Service CompleteAPI server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI - :<|> (handleTimetable :<|> handleTimetableStops :<|> handleTrip - :<|> handleRegister :<|> handleTrainPing (throwError err401) :<|> handleWS + :<|> (handleTrackerRegister :<|> handleTrainPing (throwError err401) :<|> handleWS :<|> handleSubscribe :<|> handleDebugState :<|> handleDebugTrain - :<|> handleDebugRegister :<|> pure (GTFS.gtfsFile gtfs) :<|> gtfsRealtimeServer gtfs dbpool) + :<|> pure (GTFS.gtfsFile gtfs) :<|> gtfsRealtimeServer gtfs dbpool) :<|> metrics :<|> serveDirectoryFileServer (serverConfigAssets settings) :<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool settings))) - where handleTimetable station maybeDay = - M.filter isLastStop . GTFS.tripsOnDay gtfs <$> liftIO day - where isLastStop = (==) station . GTFS.stationId . GTFS.stopStation . V.last . GTFS.tripStops - day = maybeM (getCurrentTime <&> utctDay) pure (pure maybeDay) - handleTimetableStops day = - pure . A.toJSON . fmap mkJson . M.elems $ GTFS.tripsOnDay gtfs day - where mkJson :: GTFS.Trip GTFS.Deep GTFS.Deep -> A.Value - mkJson GTFS.Trip {..} = A.object - [ "trip" .= tripTripId - , "sequencelength" .= (GTFS.stopSequence . V.last) tripStops - , "stops" .= fmap (\GTFS.Stop{..} -> A.object - [ "departure" .= GTFS.toUTC stopDeparture (GTFS.tzseries gtfs) day - , "arrival" .= GTFS.toUTC stopArrival (GTFS.tzseries gtfs) day - , "station" .= GTFS.stationId stopStation - , "lat" .= GTFS.stationLat stopStation - , "lon" .= GTFS.stationLon stopStation - ]) tripStops - ] - handleTrip trip = case M.lookup trip (GTFS.trips gtfs) of - Just res -> pure res - Nothing -> throwError err404 - handleRegister (ticketId :: UUID) RegisterJson{..} = do + where handleTrackerRegister RegisterJson{..} = do today <- liftIO getCurrentTime <&> utctDay expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod runSql dbpool $ do - TrackerKey tracker <- insert (Tracker expires False registerAgent) - insert (TrackerTicket (TicketKey ticketId) (TrackerKey tracker)) - pure tracker - handleDebugRegister (ticketId :: UUID) = do - expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod - runSql dbpool $ do - TrackerKey tracker <- insert (Tracker expires False "debug key") - insert (TrackerTicket (TicketKey ticketId) (TrackerKey tracker)) + TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing) pure tracker handleTrainPing onError ping@TrainPing{..} = - let ticketId = trainPingTicket in - isTokenValid dbpool trainPingToken ticketId >>= \case - Nothing -> do - onError - pure Nothing - Just (tracker@Tracker{..}, ticket@Ticket{..}) -> do + isTokenValid dbpool trainPingToken >>= \case + Nothing -> onError >> pure Nothing + Just tracker@Tracker{..} -> do + + -- if the tracker is not associated with a ticket, it is probably + -- just starting out on a new trip, or has finished an old one. + maybeTicketId <- case trackerCurrentTicket of + Just ticketId -> pure (Just ticketId) + Nothing -> runSql dbpool $ do + now <- liftIO getCurrentTime + tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] [] + ticketsWithFirstStation <- flip mapMaybeM tickets + (\ticket@(Entity ticketId _) -> do + selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case + Nothing -> pure Nothing + Just (Entity _ stop) -> do + station <- getJust (stopStation stop) + tzseries <- liftIO $ getTzseries (GTFS.tzname (stopDeparture stop)) + pure (Just (ticket, station, stop, tzseries))) + + if null ticketsWithFirstStation then pure Nothing else do + let (closestTicket, _, _, _) = minimumBy + -- (compare `on` euclid trainPingGeopos . stationGeopos . snd) + (compare `on` + (\(Entity _ ticket, station, stop, tzseries) -> + let + runningDay = ticketDay ticket + spaceDistance = euclid trainPingGeopos (stationGeopos station) + timeDiff = + GTFS.toSeconds (stopDeparture stop) tzseries runningDay + - utcToSeconds now runningDay + in + euclid trainPingGeopos (stationGeopos station) + + abs (seconds2Double timeDiff / 3600))) + ticketsWithFirstStation + logInfoN + $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+ + " (trip "+|ticketTripName (entityVal closestTicket)|+")." + + update (coerce trainPingToken) + [TrackerCurrentTicket =. Just (entityKey closestTicket)] + + pure (Just (entityKey closestTicket)) + + ticketId <- case maybeTicketId of + Just ticketId -> pure ticketId + Nothing -> do + logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " sent a ping, but no trips are running today." + throwError err400 + runSql dbpool $ do + ticket@Ticket{..} <- getJust ticketId stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival] >>= mapM (\stop -> do @@ -165,17 +196,31 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI <&> (V.fromList . fmap entityVal) let anchor = extrapolateAnchorFromPing LinearExtrapolator - ticket stations shapePoints ping + ticketId ticket stations shapePoints ping insert ping - last <- selectFirst [TrainAnchorTicket ==. trainPingTicket] [Desc TrainAnchorWhen] + last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen] -- only insert new estimates if they've actually changed anything when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor)) $ void $ insert anchor + -- are we at the final stop? if so, mark this ticket as done + -- & the tracker as free + let maxSequence = V.last stations + & (\(stop, _, _) -> stopSequence stop) + & fromIntegral + when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do + update (coerce trainPingToken) + [TrackerCurrentTicket =. Nothing] + update ticketId + [TicketCompleted =. True] + logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " has completed ticket "+|UUID.toString (coerce ticketId)|+ + " (trip "+|ticketTripName|+")" + queues <- liftIO $ atomically $ do - queues <- readTVar subscribers <&> M.lookup (coerce trainPingTicket) + queues <- readTVar subscribers <&> M.lookup (coerce ticketId) whenJust queues $ mapM_ (\q -> writeTQueue q (Just ping)) pure queues @@ -187,14 +232,14 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI msg <- liftIO $ WS.receiveData conn case A.eitherDecode msg of Left err -> do - logWarnN ("stray websocket message: "+|show msg|+" (could not decode: "+|err|+")") + logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")") liftIO $ WS.sendClose conn (C8.pack err) -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge - Right ping -> + Right ping -> do -- if invalid token, send a "polite" close request. Note that the client may -- ignore this and continue sending messages, which will continue to be handled. - liftIO $ handleTrainPing (WS.sendClose conn ("" :: ByteString)) ping >>= \case - Just anchor -> WS.sendTextData conn (A.encode anchor) + handleTrainPing (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case + Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor) Nothing -> pure () handleSubscribe (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do queue <- atomically $ do @@ -204,7 +249,7 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI $ M.insertWith (<>) ticketId [queue] qs pure queue -- send most recent ping, if any (so we won't have to wait for movement) - lastPing <- runSql dbpool $ do + lastPing <- runSqlWithoutLog dbpool $ do trackers <- getTicketTrackers ticketId <&> fmap entityKey selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp] @@ -239,21 +284,21 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI handleDebugAPI = pure $ toSwagger (Proxy @API) metrics = exportMetricsAsText <&> (decodeUtf8 . toStrict) -getTicketTrackers :: UUID -> ReaderT SqlBackend (NoLoggingT (ResourceT IO)) [Entity Tracker] +getTicketTrackers :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO))) + => UUID -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker] getTicketTrackers ticketId = do joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] [] <&> fmap (trackerTicketTracker . entityVal) - selectList [TrackerId <-. joins] [] + selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) [] -- TODO: proper debug logging for expired tokens -isTokenValid :: MonadIO m => Pool SqlBackend -> TrackerId -> TicketId -> m (Maybe (Tracker, Ticket)) -isTokenValid dbpool token ticketId = runSql dbpool $ get token >>= \case +isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker) +isTokenValid dbpool token = runSql dbpool $ get token >>= \case Just tracker | not (trackerBlocked tracker) -> do ifM (hasExpired (trackerExpires tracker)) (pure Nothing) - $ runSql dbpool $ get ticketId - <&> (\case { Nothing -> Nothing; Just ticket -> Just (tracker, ticket) }) + (pure (Just tracker)) _ -> pure Nothing hasExpired :: MonadIO m => UTCTime -> m Bool |