diff options
Diffstat (limited to 'lib/Server')
-rw-r--r-- | lib/Server/Base.hs | 9 | ||||
-rw-r--r-- | lib/Server/Ingest.hs | 211 | ||||
-rw-r--r-- | lib/Server/Subscribe.hs | 63 | ||||
-rw-r--r-- | lib/Server/Util.hs | 64 |
4 files changed, 321 insertions, 26 deletions
diff --git a/lib/Server/Base.hs b/lib/Server/Base.hs new file mode 100644 index 0000000..14b77ca --- /dev/null +++ b/lib/Server/Base.hs @@ -0,0 +1,9 @@ + +module Server.Base (ServerState) where + +import Control.Concurrent.STM (TQueue, TVar) +import qualified Data.Map as M +import Data.UUID (UUID) +import Persist (TrainPing) + +type ServerState = TVar (M.Map UUID [TQueue (Maybe TrainPing)]) diff --git a/lib/Server/Ingest.hs b/lib/Server/Ingest.hs new file mode 100644 index 0000000..d774017 --- /dev/null +++ b/lib/Server/Ingest.hs @@ -0,0 +1,211 @@ +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE RecordWildCards #-} + +module Server.Ingest (handleTrackerRegister, handleTrainPing, handleWS) where +import API (Metrics (..), + RegisterJson (..)) +import Control.Concurrent.STM (atomically, readTVar, + writeTQueue) +import Control.Monad (forever, void, when) +import Control.Monad.Catch (handle) +import Control.Monad.Extra (ifM, mapMaybeM, whenJust) +import Control.Monad.IO.Class (MonadIO (liftIO)) +import Control.Monad.Logger (LoggingT, logInfoN, + logWarnN) +import Control.Monad.Reader (ReaderT) +import qualified Data.Aeson as A +import qualified Data.ByteString.Char8 as C8 +import Data.Coerce (coerce) +import Data.Functor ((<&>)) +import qualified Data.Map as M +import Data.Pool (Pool) +import Data.Text (Text) +import Data.Text.Encoding (decodeASCII, decodeUtf8) +import Data.Time (NominalDiffTime, + UTCTime (..), addUTCTime, + getCurrentTime, + nominalDay) +import qualified Data.Vector as V +import Database.Persist +import Database.Persist.Postgresql (SqlBackend) +import Fmt ((+|), (|+)) +import qualified GTFS +import qualified Network.WebSockets as WS +import Persist +import Servant (err400, throwError) +import Servant.Server (Handler) +import Server.Util (ServiceM, getTzseries, + utcToSeconds) + +import Config (LoggingConfig, + ServerConfig) +import Data.ByteString (ByteString) +import Data.ByteString.Lazy (toStrict) +import Data.Foldable (minimumBy) +import Data.Function (on, (&)) +import qualified Data.Text as T +import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries) +import qualified Data.UUID as UUID +import Extrapolation (Extrapolator (..), + LinearExtrapolator (..), + euclid) +import GTFS (seconds2Double) +import Prometheus (decGauge, incGauge) +import Server.Base (ServerState) + + + +handleTrackerRegister + :: Pool SqlBackend + -> RegisterJson + -> ServiceM Token +handleTrackerRegister dbpool RegisterJson{..} = do + today <- liftIO getCurrentTime <&> utctDay + expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod + runSql dbpool $ do + TrackerKey tracker <- insert (Tracker expires False registerAgent Nothing) + pure tracker + where + validityPeriod :: NominalDiffTime + validityPeriod = nominalDay + +handleTrainPing + :: Pool SqlBackend + -> ServerState + -> ServerConfig + -> LoggingT (ReaderT LoggingConfig Handler) a + -> TrainPing + -> LoggingT (ReaderT LoggingConfig Handler) (Maybe TrainAnchor) +handleTrainPing dbpool subscribers cfg onError ping@TrainPing{..} = + isTokenValid dbpool trainPingToken >>= \case + Nothing -> onError >> pure Nothing + Just tracker@Tracker{..} -> do + + -- if the tracker is not associated with a ticket, it is probably + -- just starting out on a new trip, or has finished an old one. + maybeTicketId <- case trackerCurrentTicket of + Just ticketId -> pure (Just ticketId) + Nothing -> runSql dbpool $ do + now <- liftIO getCurrentTime + tickets <- selectList [ TicketDay ==. utctDay now, TicketCompleted ==. False ] [] + ticketsWithFirstStation <- flip mapMaybeM tickets + (\ticket@(Entity ticketId _) -> do + selectFirst [StopTicket ==. ticketId] [Asc StopSequence] >>= \case + Nothing -> pure Nothing + Just (Entity _ stop) -> do + station <- getJust (stopStation stop) + tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopDeparture stop)) + pure (Just (ticket, station, stop, tzseries))) + + if null ticketsWithFirstStation then pure Nothing else do + let (closestTicket, _, _, _) = minimumBy + -- (compare `on` euclid trainPingGeopos . stationGeopos . snd) + (compare `on` + (\(Entity _ ticket, station, stop, tzseries) -> + let + runningDay = ticketDay ticket + spaceDistance = euclid trainPingGeopos (stationGeopos station) + timeDiff = + GTFS.toSeconds (stopDeparture stop) tzseries runningDay + - utcToSeconds now runningDay + in + euclid trainPingGeopos (stationGeopos station) + + abs (seconds2Double timeDiff / 3600))) + ticketsWithFirstStation + logInfoN + $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " is now handling ticket "+|UUID.toString (coerce (entityKey closestTicket))|+ + " (trip "+|ticketTripName (entityVal closestTicket)|+")." + + update trainPingToken + [TrackerCurrentTicket =. Just (entityKey closestTicket)] + + pure (Just (entityKey closestTicket)) + + ticketId <- case maybeTicketId of + Just ticketId -> pure ticketId + Nothing -> do + logWarnN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " sent a ping, but no trips are running today." + throwError err400 + + runSql dbpool $ do + ticket@Ticket{..} <- getJust ticketId + + stations <- selectList [ StopTicket ==. ticketId ] [Asc StopArrival] + >>= mapM (\stop -> do + station <- getJust (stopStation (entityVal stop)) + tzseries <- liftIO $ getTzseries cfg (GTFS.tzname (stopArrival (entityVal stop))) + pure (entityVal stop, station, tzseries)) + <&> V.fromList + + shapePoints <- selectList [ShapePointShape ==. ticketShape] [Asc ShapePointIndex] + <&> (V.fromList . fmap entityVal) + + let anchor = extrapolateAnchorFromPing LinearExtrapolator + ticketId ticket stations shapePoints ping + + insert ping + + last <- selectFirst [TrainAnchorTicket ==. ticketId] [Desc TrainAnchorWhen] + -- only insert new estimates if they've actually changed anything + when (fmap (trainAnchorDelay . entityVal) last /= Just (trainAnchorDelay anchor)) + $ void $ insert anchor + + -- are we at the final stop? if so, mark this ticket as done + -- & the tracker as free + let maxSequence = V.last stations + & (\(stop, _, _) -> stopSequence stop) + & fromIntegral + when (trainAnchorSequence anchor + 0.1 >= maxSequence) $ do + update trainPingToken + [TrackerCurrentTicket =. Nothing] + update ticketId + [TicketCompleted =. True] + logInfoN $ "Tracker "+|UUID.toString (coerce trainPingToken)|+ + " has completed ticket "+|UUID.toString (coerce ticketId)|+ + " (trip "+|ticketTripName|+")" + + queues <- liftIO $ atomically $ do + queues <- readTVar subscribers <&> M.lookup (coerce ticketId) + whenJust queues $ + mapM_ (\q -> writeTQueue q (Just ping)) + pure queues + pure (Just anchor) + +handleWS + :: Pool SqlBackend + -> ServerState + -> ServerConfig + -> Metrics + -> WS.Connection -> ServiceM () +handleWS dbpool subscribers cfg Metrics{..} conn = do + liftIO $ WS.forkPingThread conn 30 + incGauge metricsWSGauge + handle (\(e :: WS.ConnectionException) -> decGauge metricsWSGauge) $ forever $ do + msg <- liftIO $ WS.receiveData conn + case A.eitherDecode msg of + Left err -> do + logWarnN ("stray websocket message: "+|decodeASCII (toStrict msg)|+" (could not decode: "+|err|+")") + liftIO $ WS.sendClose conn (C8.pack err) + -- TODO: send a close msg (Nothing) to the subscribed queues? decGauge metricsWSGauge + Right ping -> do + -- if invalid token, send a "polite" close request. Note that the client may + -- ignore this and continue sending messages, which will continue to be handled. + handleTrainPing dbpool subscribers cfg (liftIO $ WS.sendClose conn ("" :: ByteString)) ping >>= \case + Just anchor -> liftIO $ WS.sendTextData conn (A.encode anchor) + Nothing -> pure () + +-- TODO: proper debug logging for expired tokens +isTokenValid :: Pool SqlBackend -> TrackerId -> ServiceM (Maybe Tracker) +isTokenValid dbpool token = runSql dbpool $ get token >>= \case + Just tracker | not (trackerBlocked tracker) -> do + ifM (hasExpired (trackerExpires tracker)) + (pure Nothing) + (pure (Just tracker)) + _ -> pure Nothing + +hasExpired :: MonadIO m => UTCTime -> m Bool +hasExpired limit = do + now <- liftIO getCurrentTime + pure (now > limit) diff --git a/lib/Server/Subscribe.hs b/lib/Server/Subscribe.hs new file mode 100644 index 0000000..fdc092b --- /dev/null +++ b/lib/Server/Subscribe.hs @@ -0,0 +1,63 @@ + +module Server.Subscribe where +import Conduit (MonadIO (..)) +import Control.Concurrent.STM (atomically, newTQueue, readTQueue, + readTVar, writeTVar) +import Control.Exception (handle) +import Control.Monad.Extra (forever, whenJust) +import qualified Data.Aeson as A +import qualified Data.ByteString.Char8 as C8 +import Data.Functor ((<&>)) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Pool +import Data.UUID (UUID) +import Database.Persist (Entity (entityKey), SelectOpt (Desc), + entityVal, selectFirst, selectList, + (<-.), (==.), (||.)) +import Database.Persist.Sql (SqlBackend) +import qualified Network.WebSockets as WS +import Persist +import Server.Base (ServerState) +import Server.Util (ServiceM) + + +handleSubscribe + :: Pool SqlBackend + -> ServerState + -> UUID + -> WS.Connection + -> ServiceM () +handleSubscribe dbpool subscribers (ticketId :: UUID) conn = liftIO $ WS.withPingThread conn 30 (pure ()) $ do + queue <- atomically $ do + queue <- newTQueue + qs <- readTVar subscribers + writeTVar subscribers + $ M.insertWith (<>) ticketId [queue] qs + pure queue + -- send most recent ping, if any (so we won't have to wait for movement) + lastPing <- runSqlWithoutLog dbpool $ do + trackers <- getTicketTrackers ticketId + <&> fmap entityKey + selectFirst [TrainPingToken <-. trackers] [Desc TrainPingTimestamp] + <&> fmap entityVal + whenJust lastPing $ \ping -> + WS.sendTextData conn (A.encode lastPing) + handle (\(e :: WS.ConnectionException) -> removeSubscriber queue) $ forever $ do + res <- atomically $ readTQueue queue + case res of + Just ping -> WS.sendTextData conn (A.encode ping) + Nothing -> do + removeSubscriber queue + WS.sendClose conn (C8.pack "train ended") + where removeSubscriber queue = atomically $ do + qs <- readTVar subscribers + writeTVar subscribers + $ M.adjust (filter (/= queue)) ticketId qs + +-- getTicketTrackers :: (MonadLogger (t (ResourceT IO)), MonadIO (t (ResourceT IO))) +-- => UUID -> ReaderT SqlBackend (t (ResourceT IO)) [Entity Tracker] +getTicketTrackers ticketId = do + joins <- selectList [TrackerTicketTicket ==. TicketKey ticketId] [] + <&> fmap (trackerTicketTracker . entityVal) + selectList ([TrackerId <-. joins] ||. [TrackerCurrentTicket ==. Just (TicketKey ticketId)]) [] diff --git a/lib/Server/Util.hs b/lib/Server/Util.hs index 0106428..290b9c5 100644 --- a/lib/Server/Util.hs +++ b/lib/Server/Util.hs @@ -1,33 +1,41 @@ {-# LANGUAGE BlockArguments #-} {-# LANGUAGE RecordWildCards #-} -- | mostly the monad the service runs in -module Server.Util (Service, ServiceM, runService, sendErrorMsg, secondsNow, utcToSeconds, runLogging) where +module Server.Util (Service, ServiceM, runService, sendErrorMsg, secondsNow, utcToSeconds, runLogging, getTzseries) where -import Config (LoggingConfig (..)) -import Control.Exception (handle, try) -import Control.Monad.Extra (void, whenJust) -import Control.Monad.IO.Class (MonadIO (liftIO)) -import Control.Monad.Logger (Loc, LogLevel (..), LogSource, LogStr, - LoggingT (..), defaultOutput, - fromLogStr, runStderrLoggingT) -import Control.Monad.Reader (ReaderT (..)) -import qualified Data.Aeson as A -import Data.ByteString (ByteString) -import qualified Data.ByteString as C8 -import Data.Text (Text) -import qualified Data.Text as T -import Data.Text.Encoding (decodeUtf8Lenient) -import Data.Time (Day, UTCTime (..), diffUTCTime, - getCurrentTime, - nominalDiffTimeToSeconds) -import Fmt ((+|), (|+)) -import GHC.IO.Exception (IOException (IOError)) -import GTFS (Seconds (..)) -import Prometheus (MonadMonitor (doIO)) -import Servant (Handler, ServerError, ServerT, err404, - errBody, errHeaders, throwError) -import System.IO (stderr) -import System.Process.Extra (callProcess) +import Config (LoggingConfig (..), + ServerConfig (..)) +import Control.Exception (handle, try) +import Control.Monad.Extra (void, whenJust) +import Control.Monad.IO.Class (MonadIO (liftIO)) +import Control.Monad.Logger (Loc, LogLevel (..), + LogSource, LogStr, + LoggingT (..), + defaultOutput, fromLogStr, + runStderrLoggingT) +import Control.Monad.Reader (ReaderT (..)) +import qualified Data.Aeson as A +import Data.ByteString (ByteString) +import qualified Data.ByteString as C8 +import Data.Text (Text) +import qualified Data.Text as T +import Data.Text.Encoding (decodeUtf8Lenient) +import Data.Time (Day, UTCTime (..), + diffUTCTime, + getCurrentTime, + nominalDiffTimeToSeconds) +import Data.Time.LocalTime.TimeZone.Olson (getTimeZoneSeriesFromOlsonFile) +import Data.Time.LocalTime.TimeZone.Series (TimeZoneSeries) +import Fmt ((+|), (|+)) +import GHC.IO.Exception (IOException (IOError)) +import GTFS (Seconds (..)) +import Prometheus (MonadMonitor (doIO)) +import Servant (Handler, ServerError, + ServerT, err404, errBody, + errHeaders, throwError) +import System.FilePath ((</>)) +import System.IO (stderr) +import System.Process.Extra (callProcess) type ServiceM = LoggingT (ReaderT LoggingConfig Handler) type Service api = ServerT api ServiceM @@ -77,3 +85,7 @@ secondsNow runningDay = do utcToSeconds :: UTCTime -> Day -> Seconds utcToSeconds time day = Seconds $ round $ nominalDiffTimeToSeconds $ diffUTCTime time (UTCTime day 0) + +getTzseries :: ServerConfig -> Text -> IO TimeZoneSeries +getTzseries settings tzname = getTimeZoneSeriesFromOlsonFile + (serverConfigZoneinfoPath settings </> T.unpack tzname) |