diff options
Diffstat (limited to '')
| -rw-r--r-- | lib/Server.hs | 199 | 
1 files changed, 103 insertions, 96 deletions
| diff --git a/lib/Server.hs b/lib/Server.hs index 1b79300..4a78735 100644 --- a/lib/Server.hs +++ b/lib/Server.hs @@ -6,11 +6,13 @@  {-# LANGUAGE FlexibleInstances          #-}  {-# LANGUAGE GADTs                      #-}  {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase                 #-}  {-# LANGUAGE MultiParamTypeClasses      #-}  {-# LANGUAGE QuasiQuotes                #-}  {-# LANGUAGE RecordWildCards            #-}  {-# LANGUAGE StandaloneDeriving         #-}  {-# LANGUAGE TemplateHaskell            #-} +{-# LANGUAGE TupleSections              #-}  {-# LANGUAGE TypeApplications           #-}  {-# LANGUAGE TypeFamilies               #-}  {-# LANGUAGE TypeOperators              #-} @@ -18,115 +20,120 @@  {-# LANGUAGE UndecidableInstances       #-}  module Server where -import           Servant                (Application, -                                         FromHttpApiData (parseUrlPiece), -                                         Server, err404, serve, throwError, -                                         type (:>)) -import           Servant.API            (Capture, FromHttpApiData, Get, JSON, -                                         Post, ReqBody, type (:<|>) ((:<|>))) -import           Servant.Docs           (DocCapture (..), DocQueryParam (..), -                                         ParamKind (..), ToCapture (..), -                                         ToParam (..)) - +import           Conduit                        (MonadTrans (lift), ResourceT)  import           Control.Concurrent.STM -import           Control.Monad.IO.Class (MonadIO (liftIO)) -import           Data.Aeson             (FromJSON (parseJSON), ToJSON (toJSON), -                                         ToJSONKey, genericParseJSON, -                                         genericToJSON) -import qualified Data.Aeson             as A -import           Data.Functor           ((<&>)) -import           Data.Map               (Map) -import qualified Data.Map               as M -import           Data.Proxy             (Proxy (Proxy)) -import           Data.Swagger -import           Data.Text              (Text) -import           Data.Time              (UTCTime (utctDay), dayOfWeek, -                                         getCurrentTime) -import           Data.UUID              (UUID) -import qualified Data.UUID              as UUID -import qualified Data.UUID.V4           as UUID -import           Data.Vector            (Vector) +import           Control.Monad                  (when) +import           Control.Monad.Extra            (whenM) +import           Control.Monad.IO.Class         (MonadIO (liftIO)) +import           Control.Monad.Logger.CallStack (NoLoggingT) +import           Control.Monad.Reader           (forM) +import           Control.Monad.Trans.Maybe      (MaybeT (..)) +import           Data.Aeson                     (FromJSON (parseJSON), +                                                 ToJSON (toJSON), ToJSONKey, +                                                 genericParseJSON, +                                                 genericToJSON) +import qualified Data.Aeson                     as A +import           Data.Coerce                    (coerce) +import           Data.Functor                   ((<&>)) +import           Data.Map                       (Map) +import qualified Data.Map                       as M +import           Data.Pool                      (Pool) +import           Data.Proxy                     (Proxy (Proxy)) +import           Data.Swagger                   hiding (get) +import           Data.Text                      (Text) +import           Data.Time                      (NominalDiffTime, +                                                 UTCTime (utctDay), addUTCTime, +                                                 dayOfWeek, diffUTCTime, +                                                 getCurrentTime, nominalDay) +import           Data.UUID                      (UUID) +import qualified Data.UUID                      as UUID +import qualified Data.UUID.V4                   as UUID +import           Data.Vector                    (Vector)  import           Database.Persist -import           Database.Persist.TH -import           GHC.Foreign            (withCStringsLen) -import           GHC.Generics           (Generic) +import           Database.Persist.Postgresql +import           GHC.Generics                   (Generic)  import           GTFS -import           PersistOrphans -import           Servant.Server         (Handler) -import           Servant.Swagger        (toSwagger) - -newtype Token = Token UUID -  deriving newtype (Show, ToJSON, Eq, Ord, FromHttpApiData, ToJSONKey) -instance ToSchema Token where -  declareNamedSchema _ = declareNamedSchema (Proxy @String) -instance ToParamSchema Token where -  toParamSchema _ = toParamSchema (Proxy @String) - -share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase| -TripToken sql=tt_tracker_token -  Id UUID default=uuid_generate_v4() -  issued UTCTime -  blocked Bool -  tripNumber Text -  deriving Eq Show Generic - -TripPing json sql=tt_trip_ping -  token UUID -  latitude Double -  longitude Double -  delay Double -  timestamp UTCTime -  deriving Show Generic Eq - -|] - - -instance ToSchema TripPing where -  declareNamedSchema = genericDeclareNamedSchema (swaggerOptions "ping") - -type KnownTrips = TVar (Map Token [TripPing]) - -type API = "stations" :> Get '[JSON] (Map StationID Station) -  :<|> "timetable" :> Capture "Station ID" StationID :> Get '[JSON] (Map TripID (Trip Deep)) -  :<|> "trip" :> Capture "Trip ID" TripID :> Get '[JSON] (Trip Deep) -  -- ingress API (put this behind BasicAuth?) -  -- TODO: perhaps require a first ping for registration? -  :<|> "trainregister" :> Capture "Trip ID" TripID :> Post '[JSON] Token -  -- TODO: perhaps a websocket instead? -  :<|> "trainping" :> Capture "Train Token" Token :> ReqBody '[JSON] TripPing :> Post '[JSON] () -  -- debug things -  :<|> "debug" :> "state" :> Get '[JSON] (Map Token [TripPing]) -type CompleteAPI = "debug" :> "openapi" :> Get '[JSON] Swagger -  :<|> API - - - - - -server :: GTFS -> KnownTrips -> Server CompleteAPI -server gtfs@GTFS{..} knownTrains = handleDebugAPI :<|> handleStations :<|> handleTimetable :<|> handleTrip +import           Servant                        (Application, +                                                 FromHttpApiData (parseUrlPiece), +                                                 Server, err401, err404, serve, +                                                 throwError, type (:>)) +import           Servant.API                    (Capture, FromHttpApiData, Get, +                                                 JSON, Post, ReqBody, +                                                 type (:<|>) ((:<|>))) +import           Servant.Docs                   (DocCapture (..), +                                                 DocQueryParam (..), +                                                 ParamKind (..), ToCapture (..), +                                                 ToParam (..)) +import           Servant.Server                 (Handler) +import           Servant.Swagger                (toSwagger) +import           Web.PathPieces                 (PathPiece) + +import           API +import           Persist + +application :: GTFS -> Pool SqlBackend -> IO Application +application gtfs dbpool = do +  doMigration dbpool +  pure $ serve (Proxy @CompleteAPI) $ server gtfs dbpool + + + +-- databaseMigration :: ConnectionString -> IO () +doMigration pool = runSql pool $ +  -- TODO: before that, check if the uuid module is enabled +  -- in sql: check if SELECT * FROM pg_extension WHERE extname = 'uuid-ossp'; +  -- returns an empty list +  runMigration migrateAll + +server :: GTFS -> Pool SqlBackend -> Server CompleteAPI +server gtfs@GTFS{..} dbpool = handleDebugAPI :<|> handleStations :<|> handleTimetable :<|> handleTrip    :<|> handleRegister :<|> handleTripPing :<|> handleDebugState    where handleStations = pure stations          handleTimetable station = do +          -- TODO: resolve "overlay" trips (perhaps just additional CalendarDates?)            today <- liftIO getCurrentTime <&> utctDay            pure $ tripsOnDay gtfs today          handleTrip trip = case M.lookup trip trips of            Just res -> pure res            Nothing  -> throwError err404 -        handleRegister tripID = liftIO $ do -          token <- UUID.nextRandom <&> Token -          atomically $ modifyTVar knownTrains (M.insert token []) -          pure token -        handleTripPing token ping = liftIO $ atomically $ do -            modifyTVar knownTrains (M.update (\history -> Just (ping : history)) token) -            pure () -        handleDebugState = liftIO $ readTVarIO knownTrains +        handleRegister tripID = do +          expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod +          RunningTripKey uuid <- runSql dbpool $ insert (RunningTrip expires False tripID) +          pure (Token uuid) +        handleTripPing ping = do +          checkTokenValid dbpool (tripPingToken ping) +          -- TODO: are these always inserted in order? +          runSql dbpool $ insert ping +          pure () +        handleDebugState = do +          now <- liftIO $ getCurrentTime +          runSql dbpool $ do +           running <- selectList [RunningTripBlocked ==. False, RunningTripExpires >=. now] [] +           pairs <- forM running $ \(Entity (RunningTripKey uuid) _) -> do +             entities <- selectList [TripPingToken ==. Token uuid] [] +             pure (Token uuid, fmap entityVal entities) +           pure (M.fromList pairs)          handleDebugAPI = pure $ toSwagger (Proxy @API) -application :: GTFS -> IO Application -application gtfs = do -  knownTrips <- newTVarIO mempty -  pure $ serve (Proxy @CompleteAPI) $ server gtfs knownTrips +checkTokenValid :: Pool SqlBackend -> Token -> Handler () +checkTokenValid dbpool token = do +  trip <- try $ runSql dbpool $ get (coerce token) +  when (runningTripBlocked trip) +    $ throwError err401 +  whenM (hasExpired (runningTripExpires trip)) +    $ throwError err401 +  where try m = m >>= \case +          Just a  -> pure a +          Nothing -> throwError err404 + +hasExpired :: MonadIO m => UTCTime -> m Bool +hasExpired limit = do +  now <- liftIO getCurrentTime +  pure (now > limit) + +validityPeriod :: NominalDiffTime +validityPeriod = nominalDay + | 
