diff options
Diffstat (limited to 'lib/Server.hs')
-rw-r--r-- | lib/Server.hs | 91 |
1 files changed, 50 insertions, 41 deletions
diff --git a/lib/Server.hs b/lib/Server.hs index 016707b..c6d2d94 100644 --- a/lib/Server.hs +++ b/lib/Server.hs @@ -1,8 +1,9 @@ -{-# LANGUAGE DataKinds #-} -{-# LANGUAGE ExplicitNamespaces #-} -{-# LANGUAGE LambdaCase #-} -{-# LANGUAGE OverloadedLists #-} -{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE ExplicitNamespaces #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE OverloadedLists #-} +{-# LANGUAGE PartialTypeSignatures #-} +{-# LANGUAGE RecordWildCards #-} -- Implementation of the API. This module is the main point of the program. @@ -16,8 +17,8 @@ import Control.Monad.Catch (handle) import Control.Monad.Extra (ifM, maybeM, unlessM, whenJust, whenM) import Control.Monad.IO.Class (MonadIO (liftIO)) -import Control.Monad.Logger (LoggingT, logWarnN) -import Control.Monad.Reader (forM) +import Control.Monad.Logger (LoggingT, NoLoggingT, logWarnN) +import Control.Monad.Reader (ReaderT, forM) import Control.Monad.Trans (lift) import Data.Aeson ((.=)) import qualified Data.Aeson as A @@ -61,9 +62,11 @@ import Extrapolation (Extrapolator (..), LinearExtrapolator (..)) import System.IO.Unsafe +import Conduit (ResourceT) import Config (ServerConfig (serverConfigAssets)) import Data.ByteString (ByteString) import Data.ByteString.Lazy (toStrict) +import Data.UUID (UUID) import Prometheus import Prometheus.Metric.GHC @@ -83,7 +86,7 @@ doMigration pool = runSql pool $ -- returns an empty list runMigration migrateAll -server :: GTFS -> Metrics -> TVar (M.Map TripID [TQueue (Maybe TrainPing)]) -> Pool SqlBackend -> ServerConfig -> Service CompleteAPI +server :: GTFS -> Metrics -> TVar (M.Map UUID [TQueue (Maybe TrainPing)]) -> Pool SqlBackend -> ServerConfig -> Service CompleteAPI server gtfs@GTFS{..} Metrics{..} subscribers dbpool settings = handleDebugAPI :<|> (handleStations :<|> handleTimetable :<|> handleTimetableStops :<|> handleTrip :<|> handleRegister :<|> handleTrainPing (throwError err401) :<|> handleWS @@ -101,7 +104,7 @@ server gtfs@GTFS{..} Metrics{..} subscribers dbpool settings = handleDebugAPI pure . A.toJSON . fmap mkJson . M.elems $ tripsOnDay gtfs day where mkJson :: Trip Deep Deep -> A.Value mkJson Trip {..} = A.object - [ "trip" .= tripTripID + [ "trip" .= tripTripId , "sequencelength" .= (stopSequence . V.last) tripStops , "stops" .= fmap (\Stop{..} -> A.object [ "departure" .= toUTC stopDeparture tzseries day @@ -114,34 +117,35 @@ server gtfs@GTFS{..} Metrics{..} subscribers dbpool settings = handleDebugAPI handleTrip trip = case M.lookup trip trips of Just res -> pure res Nothing -> throwError err404 - handleRegister tripID RegisterJson{..} = do + handleRegister (ticketId :: UUID) RegisterJson{..} = do today <- liftIO getCurrentTime <&> utctDay - unless (runsOnDay gtfs tripID today) - $ sendErrorMsg "this trip does not run today." expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod - RunningKey token <- runSql dbpool $ insert (Running expires False tripID today Nothing registerAgent) - pure token - handleDebugRegister tripID day = do + runSql dbpool $ do + TrackerKey tracker <- insert (Tracker expires False registerAgent) + insert (TrackerTicket (TicketKey ticketId) (TrackerKey tracker)) + pure tracker + handleDebugRegister (ticketId :: UUID) = do expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod - RunningKey token <- runSql dbpool $ insert (Running expires False tripID day Nothing "debug key") - pure token - handleTrainPing onError ping = isTokenValid dbpool (coerce $ trainPingToken ping) >>= \case + runSql dbpool $ do + TrackerKey tracker <- insert (Tracker expires False "debug key") + insert (TrackerTicket (TicketKey ticketId) (TrackerKey tracker)) + pure tracker + handleTrainPing onError ping@TrainPing{..} = isTokenValid dbpool trainPingToken trainPingTicket + >>= \case Nothing -> do onError pure Nothing - Just running@Running{..} -> do - let anchor = extrapolateAnchorFromPing LinearExtrapolator gtfs running ping + Just (tracker@Tracker{..}, ticket@Ticket{..}) -> do + let anchor = extrapolateAnchorFromPing LinearExtrapolator gtfs ticket ping -- TODO: are these always inserted in order? runSql dbpool $ do insert ping - last <- selectFirst - [TrainAnchorTrip ==. runningTrip, TrainAnchorDay ==. runningDay] - [Desc TrainAnchorWhen] + last <- selectFirst [TrainAnchorTicket ==. trainPingTicket] [Desc TrainAnchorWhen] -- only insert new estimates if they've actually changed anything when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor)) $ void $ insert anchor queues <- liftIO $ atomically $ do - queues <- readTVar subscribers <&> M.lookup runningTrip + queues <- readTVar subscribers <&> M.lookup (coerce trainPingTicket) whenJust queues $ mapM_ (\q -> writeTQueue q (Just ping)) pure queues @@ -162,18 +166,18 @@ server gtfs@GTFS{..} Metrics{..} subscribers dbpool settings = handleDebugAPI liftIO $ handleTrainPing (WS.sendClose conn ("" :: ByteString)) ping >>= \case Just anchor -> WS.sendTextData conn (A.encode anchor) Nothing -> pure () - handleSubscribe tripId day conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do + handleSubscribe (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do queue <- atomically $ do queue <- newTQueue qs <- readTVar subscribers writeTVar subscribers - $ M.insertWith (<>) tripId [queue] qs + $ M.insertWith (<>) ticketId [queue] qs pure queue -- send most recent ping, if any (so we won't have to wait for movement) lastPing <- runSql dbpool $ do - tokens <- selectList [RunningDay ==. day, RunningTrip ==. tripId] [] + trackers <- getTicketTrackers ticketId <&> fmap entityKey - selectFirst [TrainPingToken <-. tokens] [Desc TrainPingTimestamp] + selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp] <&> fmap entityVal whenJust lastPing $ \ping -> WS.sendTextData conn (A.encode lastPing) @@ -187,34 +191,39 @@ server gtfs@GTFS{..} Metrics{..} subscribers dbpool settings = handleDebugAPI where removeSubscriber queue = atomically $ do qs <- readTVar subscribers writeTVar subscribers - $ M.adjust (filter (/= queue)) tripId qs + $ M.adjust (filter (/= queue)) ticketId qs handleDebugState = do now <- liftIO getCurrentTime runSql dbpool $ do - running <- selectList [RunningBlocked ==. False, RunningExpires >=. now] [] - pairs <- forM running $ \(Entity token@(RunningKey uuid) _) -> do + tracker <- selectList [TrackerBlocked ==. False, TrackerExpires >=. now] [] + pairs <- forM tracker $ \(Entity token@(TrackerKey uuid) _) -> do entities <- selectList [TrainPingToken ==. token] [] pure (uuid, fmap entityVal entities) pure (M.fromList pairs) - handleDebugTrain tripId day = do - unless (runsOnDay gtfs tripId day) - $ sendErrorMsg ("this trip does not run on "+|day|+".") + handleDebugTrain ticketId = do runSql dbpool $ do - tokens <- selectList [RunningTrip ==. tripId, RunningDay ==. day] [] - pings <- forM tokens $ \(Entity token _) -> do + trackers <- getTicketTrackers ticketId + pings <- forM trackers $ \(Entity token _) -> do selectList [TrainPingToken ==. token] [] <&> fmap entityVal pure (concat pings) handleDebugAPI = pure $ toSwagger (Proxy @API) metrics = exportMetricsAsText <&> (decodeUtf8 . toStrict) +getTicketTrackers :: UUID -> ReaderT SqlBackend (NoLoggingT (ResourceT IO)) [Entity Tracker] +getTicketTrackers ticketId = do + joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] [] + <&> fmap (trackerTicketTracker . entityVal) + selectList [TrackerId <-. joins] [] + -- TODO: proper debug logging for expired tokens -isTokenValid :: MonadIO m => Pool SqlBackend -> Token -> m (Maybe Running) -isTokenValid dbpool token = runSql dbpool $ get (coerce token) >>= \case - Just trip | not (runningBlocked trip) -> do - ifM (hasExpired (runningExpires trip)) +isTokenValid :: MonadIO m => Pool SqlBackend -> TrackerId -> TicketId -> m (Maybe (Tracker, Ticket)) +isTokenValid dbpool token ticketId = runSql dbpool $ get token >>= \case + Just tracker | not (trackerBlocked tracker) -> do + ifM (hasExpired (trackerExpires tracker)) (pure Nothing) - (pure (Just trip)) + $ runSql dbpool $ get ticketId + <&> (\case { Nothing -> Nothing; Just ticket -> Just (tracker, ticket) }) _ -> pure Nothing hasExpired :: MonadIO m => UTCTime -> m Bool |