author     stuebinm   2024-05-01 03:07:06 +0200
committer  stuebinm   2024-05-02 00:18:16 +0200
commit     80984549172d7de83564757de80996487ca2fb15 (patch)
tree       1e4bfe43fa9fc96fa5642fe34f502005775f257f /lib/Server.hs
parent     b26a3d617e90c9693a4ee8b7cc8bbba506cd4746 (diff)
restructure: get the tracker to work again
This should hopefully be the final (major) part of the restructuring: a tracker no longer has to know which trip it is on (and indeed it has no idea for now); instead, the server keeps state about which trips are currently running and will insert incoming pings in a hopefully reasonable manner, based on their geoposition & time.

There are lots of associated TODO items here (especially, there should be manual overrides for all this logic in the web ui), but that's work for a future me.

(Incidentally, this also adds support for sending all log messages out via ntfy-sh.)
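Concretely, the new matching logic in handleTrainPing scores every not-yet-completed ticket for the current day by the euclidean distance between the incoming ping and the ticket's first stop, plus the absolute offset from that stop's scheduled departure (scaled from seconds to hours so the two terms stay comparable), and picks the ticket with the lowest score. Below is a simplified, standalone sketch of that heuristic; the Candidate type and the helper names are hypothetical stand-ins for the project's persistent Ticket/Station/Stop entities.

{-# LANGUAGE RecordWildCards #-}

import Data.Function (on)
import Data.List     (minimumBy)

-- | (latitude, longitude), in degrees.
type Geopos = (Double, Double)

-- | A candidate ticket running today: its id, the geoposition of its
-- first stop, and that stop's scheduled departure in seconds since
-- midnight of the running day (hypothetical stand-in for the real
-- Ticket/Station/Stop entities).
data Candidate = Candidate
  { candId        :: String
  , candFirstStop :: Geopos
  , candDeparture :: Double
  } deriving (Show)

-- | Straight-line distance in coordinate space (the server imports an
-- equivalent helper from its Extrapolation module).
euclid :: Geopos -> Geopos -> Double
euclid (x1, y1) (x2, y2) = sqrt ((x1 - x2) ^ (2 :: Int) + (y1 - y2) ^ (2 :: Int))

-- | Score a candidate against an incoming ping: spatial distance to the
-- trip's first stop, plus the absolute offset from its scheduled
-- departure, scaled from seconds to hours so both terms stay comparable.
score :: Geopos -> Double -> Candidate -> Double
score pingPos now Candidate{..} =
  euclid pingPos candFirstStop + abs ((candDeparture - now) / 3600)

-- | Pick the lowest-scoring candidate, if any trips are running today.
matchPing :: Geopos -> Double -> [Candidate] -> Maybe Candidate
matchPing _ _ [] = Nothing
matchPing pingPos now cs = Just (minimumBy (compare `on` score pingPos now) cs)

main :: IO ()
main = print (matchPing (50.11, 8.68) 43200 candidates)
  where
    -- two made-up tickets: one departing nearby around noon, one far away
    candidates =
      [ Candidate "near-noon"   (50.10, 8.66) 43200
      , Candidate "far-morning" (52.52, 13.40) 28800
      ]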
Diffstat (limited to 'lib/Server.hs')
-rw-r--r--  lib/Server.hs | 179
1 file changed, 112 insertions(+), 67 deletions(-)
diff --git a/lib/Server.hs b/lib/Server.hs
index 3922a7b..73c55cb 100644
--- a/lib/Server.hs
+++ b/lib/Server.hs
@@ -16,12 +16,14 @@ import Control.Concurrent.STM (TQueue, TVar, atomically,
import Control.Monad (forever, unless, void,
when)
import Control.Monad.Catch (handle)
-import Control.Monad.Extra (ifM, maybeM, unlessM,
- whenJust, whenM)
+import Control.Monad.Extra (ifM, mapMaybeM, maybeM,
+ unlessM, whenJust, whenM)
import Control.Monad.IO.Class (MonadIO (liftIO))
-import Control.Monad.Logger (LoggingT, NoLoggingT,
+import Control.Monad.Logger (LoggingT, MonadLogger,
+ NoLoggingT, logInfoN,
logWarnN)
-import Control.Monad.Reader (ReaderT, forM)
+import Control.Monad.Reader (MonadReader, ReaderT,
+ forM)
import Control.Monad.Trans (lift)
import Data.Aeson ((.=))
import qualified Data.Aeson as A
@@ -33,7 +35,7 @@ import Data.Pool (Pool)
import Data.Proxy (Proxy (Proxy))
import Data.Swagger (toSchema)
import Data.Text (Text)
-import Data.Text.Encoding (decodeUtf8)
+import Data.Text.Encoding (decodeASCII, decodeUtf8)
import Data.Time (NominalDiffTime,
UTCTime (utctDay),
addUTCTime, diffUTCTime,
@@ -48,7 +50,8 @@ import Fmt ((+|), (|+))
import qualified Network.WebSockets as WS
import Servant (Application,
ServerError (errBody),
- err401, err404, serve,
+ err400, err401, err404,
+ serve,
serveDirectoryFileServer,
throwError)
import Servant.API (NoContent (..),
@@ -62,24 +65,33 @@ import Persist
import Server.ControlRoom
import Server.GTFS_RT (gtfsRealtimeServer)
import Server.Util (Service, ServiceM,
- runService, sendErrorMsg)
+ runService, sendErrorMsg,
+ utcToSeconds)
import Yesod (toWaiAppPlain)
-import Extrapolation (Extrapolator (..),
- LinearExtrapolator (..))
-import System.IO.Unsafe
-
import Conduit (ResourceT)
-import Config (ServerConfig (serverConfigAssets, serverConfigZoneinfoPath))
+import Config (LoggingConfig,
+ ServerConfig (..))
+import Control.Exception (throw)
import Data.ByteString (ByteString)
import Data.ByteString.Lazy (toStrict)
+import Data.Foldable (minimumBy)
+import Data.Function (on, (&))
+import Data.Maybe (fromMaybe)
import qualified Data.Text as T
import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile)
import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries)
import Data.UUID (UUID)
+import qualified Data.UUID as UUID
+import Extrapolation (Extrapolator (..),
+ LinearExtrapolator (..),
+ euclid)
+import GTFS (Seconds (unSeconds),
+ seconds2Double)
import Prometheus
import Prometheus.Metric.GHC
import System.FilePath ((</>))
+import System.IO.Unsafe
application :: GTFS.GTFS -> Pool SqlBackend -> ServerConfig -> IO Application
application gtfs dbpool settings = do
@@ -94,65 +106,84 @@ application gtfs dbpool settings = do
(serverConfigZoneinfoPath settings </> T.unpack tzname)
subscribers <- newTVarIO mempty
- pure $ serve (Proxy @CompleteAPI) $ hoistServer (Proxy @CompleteAPI) runService
+ pure $ serve (Proxy @CompleteAPI)
+ $ hoistServer (Proxy @CompleteAPI) (runService (serverConfigLogging settings))
$ server gtfs getTzseries metrics subscribers dbpool settings
-- databaseMigration :: ConnectionString -> IO ()
-doMigration pool = runSql pool $ runMigration $ do
+doMigration pool = runSqlWithoutLog pool $ runMigration $ do
migrateEnableExtension "uuid-ossp"
migrateAll
server :: GTFS.GTFS -> (Text -> IO TimeZoneSeries) -> Metrics -> TVar (M.Map UUID [TQueue (Maybe TrainPing)]) -> Pool SqlBackend -> ServerConfig -> Service CompleteAPI
server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
- :<|> (handleTimetable :<|> handleTimetableStops :<|> handleTrip
- :<|> handleRegister :<|> handleTrainPing (throwError err401) :<|> handleWS
+ :<|> (handleTrackerRegister :<|> handleTrainPing (throwError err401) :<|> handleWS
:<|> handleSubscribe :<|> handleDebugState :<|> handleDebugTrain
- :<|> handleDebugRegister :<|> pure (GTFS.gtfsFile gtfs) :<|> gtfsRealtimeServer gtfs dbpool)
+ :<|> pure (GTFS.gtfsFile gtfs) :<|> gtfsRealtimeServer gtfs dbpool)
:<|> metrics
:<|> serveDirectoryFileServer (serverConfigAssets settings)
:<|> pure (unsafePerformIO (toWaiAppPlain (ControlRoom gtfs dbpool settings)))
- where handleTimetable station maybeDay =
- M.filter isLastStop . GTFS.tripsOnDay gtfs <$> liftIO day
- where isLastStop = (==) station . GTFS.stationId . GTFS.stopStation . V.last . GTFS.tripStops
- day = maybeM (getCurrentTime <&> utctDay) pure (pure maybeDay)
- handleTimetableStops day =
- pure . A.toJSON . fmap mkJson . M.elems $ GTFS.tripsOnDay gtfs day
- where mkJson :: GTFS.Trip GTFS.Deep GTFS.Deep -> A.Value
- mkJson GTFS.Trip {..} = A.object
- [ "trip" .= tripTripId
- , "sequencelength" .= (GTFS.stopSequence . V.last) tripStops
- , "stops" .= fmap (\GTFS.Stop{..} -> A.object
- [ "departure" .= GTFS.toUTC stopDeparture (GTFS.tzseries gtfs) day
- , "arrival" .= GTFS.toUTC stopArrival (GTFS.tzseries gtfs) day
- , "station" .= GTFS.stationId stopStation
- , "lat" .= GTFS.stationLat stopStation
- , "lon" .= GTFS.stationLon stopStation
- ]) tripStops
- ]
- handleTrip trip = case M.lookup trip (GTFS.trips gtfs) of
- Just res -> pure res
- Nothing -> throwError err404
- handleRegister (ticketId :: UUID) RegisterJson{..} = do
+ where handleTrackerRegister RegisterJson{..} = do
today <- liftIO getCurrentTime <&> utctDay
expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
runSql dbpool $ do
- TrackerKey tracker <- insert (Tracker expires False registerAgent)
- insert (TrackerTicket (TicketKey ticketId) (TrackerKey tracker))
- pure tracker
- handleDebugRegister (ticketId :: UUID) = do
- expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod
- runSql dbpool $ do
- TrackerKey tracker <- insert (Tracker expires False "debug key")
- insert (TrackerTicket (TicketKey ticketId) (TrackerKey tracker))
+ TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing)
pure tracker
handleTrainPing onError ping@TrainPing{..} =
- let ticketId = trainPingTicket in
- isTokenValid dbpool trainPingToken ticketId >>= \case
- Nothing -> do
- onError
- pure Nothing
- Just (tracker@Tracker{..}, ticket@Ticket{..}) -> do
+ isTokenValid dbpool trainPingToken >>= \case
+ Nothing -> onError >> pure Nothing
+ Just tracker@Tracker{..} -> do
+
+ -- if the tracker is not associated with a ticket, it is probably
+ -- just starting out on a new trip, or has finished an old one.
+ maybeTicketId <- case trackerCurrentTicket of
+ Just ticketId -> pure (Just ticketId)
+ Nothing -> runSql dbpool $ do
+ now <- liftIO getCurrentTime
+ tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] []
+ ticketsWithFirstStation <- flip mapMaybeM tickets
+ (\ticket@(Entity ticketId _) -> do
+ selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case
+ Nothing -> pure Nothing
+ Just (Entity _ stop) -> do
+ station <- getJust (stopStation stop)
+ tzseries <- liftIO $ getTzseries (GTFS.tzname (stopDeparture stop))
+ pure (Just (ticket, station, stop, tzseries)))
+
+ if null ticketsWithFirstStation then pure Nothing else do
+ let (closestTicket, _, _, _) = minimumBy
+ -- (compare `on` euclid trainPingGeopos . stationGeopos . snd)
+ (compare `on`
+ (\(Entity _ ticket, station, stop, tzseries) ->
+ let
+ runningDay = ticketDay ticket
+ spaceDistance = euclid trainPingGeopos (stationGeopos station)
+ timeDiff =
+ GTFS.toSeconds (stopDeparture stop) tzseries runningDay
+ - utcToSeconds now runningDay
+ in
+ euclid trainPingGeopos (stationGeopos station)
+ + abs (seconds2Double timeDiff / 3600)))
+ ticketsWithFirstStation
+ logInfoN
+ $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+
+ " (trip "+|ticketTripName (entityVal closestTicket)|+")."
+
+ update (coerce trainPingToken)
+ [TrackerCurrentTicket =. Just (entityKey closestTicket)]
+
+ pure (Just (entityKey closestTicket))
+
+ ticketId <- case maybeTicketId of
+ Just ticketId -> pure ticketId
+ Nothing -> do
+ logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " sent a ping, but no trips are running today."
+ throwError err400
+
runSql dbpool $ do
+ ticket@Ticket{..} <- getJust ticketId
stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival]
>>= mapM (\stop -> do
@@ -165,17 +196,31 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
<&> (V.fromList . fmap entityVal)
let anchor = extrapolateAnchorFromPing LinearExtrapolator
- ticket stations shapePoints ping
+ ticketId ticket stations shapePoints ping
insert ping
- last <- selectFirst [TrainAnchorTicket ==. trainPingTicket] [Desc TrainAnchorWhen]
+ last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen]
-- only insert new estimates if they've actually changed anything
when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor))
$ void $ insert anchor
+ -- are we at the final stop? if so, mark this ticket as done
+ -- & the tracker as free
+ let maxSequence = V.last stations
+ & (\(stop, _, _) -> stopSequence stop)
+ & fromIntegral
+ when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do
+ update (coerce trainPingToken)
+ [TrackerCurrentTicket =. Nothing]
+ update ticketId
+ [TicketCompleted =. True]
+ logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+
+ " has completed ticket "+|UUID.toString (coerce ticketId)|+
+ " (trip "+|ticketTripName|+")"
+
queues <- liftIO $ atomically $ do
- queues <- readTVar subscribers <&> M.lookup (coerce trainPingTicket)
+ queues <- readTVar subscribers <&> M.lookup (coerce ticketId)
whenJust queues $
mapM_ (\q -> writeTQueue q (Just ping))
pure queues
@@ -187,14 +232,14 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
msg <- liftIO $ WS.receiveData conn
case A.eitherDecode msg of
Left err -> do
- logWarnN ("stray websocket message: "+|show msg|+" (could not decode: "+|err|+")")
+ logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")")
liftIO $ WS.sendClose conn (C8.pack err)
-- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge
- Right ping ->
+ Right ping -> do
-- if invalid token, send a "polite" close request. Note that the client may
-- ignore this and continue sending messages, which will continue to be handled.
- liftIO $ handleTrainPing (WS.sendClose conn ("" :: ByteString)) ping >>= \case
- Just anchor -> WS.sendTextData conn (A.encode anchor)
+ handleTrainPing (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case
+ Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor)
Nothing -> pure ()
handleSubscribe (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do
queue <- atomically $ do
@@ -204,7 +249,7 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
$ M.insertWith (<>) ticketId [queue] qs
pure queue
-- send most recent ping, if any (so we won't have to wait for movement)
- lastPing <- runSql dbpool $ do
+ lastPing <- runSqlWithoutLog dbpool $ do
trackers <- getTicketTrackers ticketId
<&> fmap entityKey
selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp]
@@ -239,21 +284,21 @@ server gtfs getTzseries Metrics{..} subscribers dbpool settings = handleDebugAPI
handleDebugAPI = pure $ toSwagger (Proxy @API)
metrics = exportMetricsAsText <&> (decodeUtf8 . toStrict)
-getTicketTrackers :: UUID -> ReaderT SqlBackend (NoLoggingT (ResourceT IO)) [Entity Tracker]
+getTicketTrackers :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO)))
+ => UUID -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker]
getTicketTrackers ticketId = do
joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] []
<&> fmap (trackerTicketTracker . entityVal)
- selectList [TrackerId <-. joins] []
+ selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) []
-- TODO: proper debug logging for expired tokens
-isTokenValid :: MonadIO m => Pool SqlBackend -> TrackerId -> TicketId -> m (Maybe (Tracker, Ticket))
-isTokenValid dbpool token ticketId = runSql dbpool $ get token >>= \case
+isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker)
+isTokenValid dbpool token = runSql dbpool $ get token >>= \case
Just tracker | not (trackerBlocked tracker) -> do
ifM (hasExpired (trackerExpires tracker))
(pure Nothing)
- $ runSql dbpool $ get ticketId
- <&> (\case { Nothing -> Nothing; Just ticket -> Just (tracker, ticket) })
+ (pure (Just tracker))
_ -> pure Nothing
hasExpired :: MonadIO m => UTCTime -> m Bool
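For context on the completion path shown in the middle of the diff: once the linear extrapolator places the train at (or within 0.1 of) the last stop's sequence number, the ticket is marked completed and the tracker's current-ticket assignment is cleared, so the next ping starts the matching over again. A minimal pure sketch of that check follows; the TrackerState record is a hypothetical stand-in for the persistent Tracker row, which the real code updates via two update calls inside runSql.

-- | The tracker row's current assignment (hypothetical stand-in for the
-- persistent Tracker entity).
newtype TrackerState = TrackerState
  { currentTicket :: Maybe String }
  deriving (Show)

-- | Once the extrapolated anchor is within 0.1 of the final stop's
-- sequence number, report the ticket as completed and free the tracker
-- so it can pick up a new trip with its next ping.
completeIfFinished
  :: Double                 -- ^ anchor sequence from the extrapolator
  -> Double                 -- ^ sequence number of the trip's last stop
  -> TrackerState
  -> (Bool, TrackerState)   -- ^ (ticket completed?, updated tracker)
completeIfFinished anchorSeq maxSeq st
  | anchorSeq + 0.1 >= maxSeq = (True, st { currentTicket = Nothing })
  | otherwise                 = (False, st)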