From 9e89c93b3b84b5c82c186cff62c33218a0a4d298 Mon Sep 17 00:00:00 2001 From: stuebinm Date: Thu, 16 Jun 2022 00:25:15 +0200 Subject: actually use the database (at least for a few simple things) Also, more modules! --- app/Main.hs | 11 ++- haskell-gtfs.cabal | 5 ++ lib/API.hs | 35 ++++++++++ lib/Persist.hs | 85 +++++++++++++++++++++++ lib/Server.hs | 199 +++++++++++++++++++++++++++-------------------------- 5 files changed, 236 insertions(+), 99 deletions(-) create mode 100644 lib/API.hs create mode 100644 lib/Persist.hs diff --git a/app/Main.hs b/app/Main.hs index 0b3165e..7d3b5dc 100644 --- a/app/Main.hs +++ b/app/Main.hs @@ -26,16 +26,21 @@ import Network.Wai.Handler.Warp (run) import Network.Wai.Middleware.RequestLogger (OutputFormat (..), RequestLoggerSettings (..), mkRequestLogger) +import Database.Persist.Postgresql +import Control.Monad.Logger (runStderrLoggingT) +import Control.Monad.IO.Class (MonadIO (liftIO)) import GTFS import Server +connStr = "user=travelynx" main :: IO () main = do gtfs <- loadGtfs "./gtfs.zip" - app <- application gtfs loggerMiddleware <- mkRequestLogger $ def { outputFormat = Detailed True } - putStrLn "starting server …" - run 4000 (loggerMiddleware app) + runStderrLoggingT $ withPostgresqlPool connStr 10 $ \pool -> liftIO $ do + app <- application gtfs pool + putStrLn "starting server …" + run 4000 (loggerMiddleware app) diff --git a/haskell-gtfs.cabal b/haskell-gtfs.cabal index a54a9ff..33700b0 100644 --- a/haskell-gtfs.cabal +++ b/haskell-gtfs.cabal @@ -32,6 +32,8 @@ executable haskell-gtfs , wai-extra , warp , data-default-class >= 0.1.2 + , persistent-postgresql + , monad-logger hs-source-dirs: app default-language: Haskell2010 default-extensions: OverloadedStrings @@ -67,6 +69,9 @@ library , conduit , path-pieces , either + , resource-pool + , transformers + , extra hs-source-dirs: lib exposed-modules: GTFS, Server, PersistOrphans default-language: Haskell2010 diff --git a/lib/API.hs b/lib/API.hs new file mode 100644 index 0000000..3fb4c3c --- /dev/null +++ b/lib/API.hs @@ -0,0 +1,35 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + + +module API where + +import Data.Map (Map) +import Data.Swagger (Swagger) +import GTFS +import Persist +import Servant (Application, FromHttpApiData (parseUrlPiece), + Server, err401, err404, serve, throwError, + type (:>)) +import Servant.API (Capture, FromHttpApiData, Get, JSON, Post, + ReqBody, type (:<|>) ((:<|>))) + +type API = "stations" :> Get '[JSON] (Map StationID Station) + :<|> "timetable" :> Capture "Station ID" StationID :> Get '[JSON] (Map TripID (Trip Deep)) + :<|> "trip" :> Capture "Trip ID" TripID :> Get '[JSON] (Trip Deep) + -- ingress API (put this behind BasicAuth?) + -- TODO: perhaps require a first ping for registration? + :<|> "trip" :> "register" :> Capture "Trip ID" TripID :> Post '[JSON] Token + -- TODO: perhaps a websocket instead? + :<|> "trip" :> "ping" :> ReqBody '[JSON] TripPing :> Post '[JSON] () + -- debug things + :<|> "debug" :> "state" :> Get '[JSON] (Map Token [TripPing]) +type CompleteAPI = "debug" :> "openapi" :> Get '[JSON] Swagger + :<|> API diff --git a/lib/Persist.hs b/lib/Persist.hs new file mode 100644 index 0000000..b4df1fb --- /dev/null +++ b/lib/Persist.hs @@ -0,0 +1,85 @@ +{-# LANGUAGE DataKinds #-} +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} +{-# LANGUAGE DerivingStrategies #-} +{-# LANGUAGE FlexibleContexts #-} +{-# LANGUAGE FlexibleInstances #-} +{-# LANGUAGE GADTs #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} +{-# LANGUAGE MultiParamTypeClasses #-} +{-# LANGUAGE QuasiQuotes #-} +{-# LANGUAGE RecordWildCards #-} +{-# LANGUAGE StandaloneDeriving #-} +{-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} +{-# LANGUAGE TypeApplications #-} +{-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE TypeOperators #-} +{-# LANGUAGE UndecidableInstances #-} + + +module Persist where + +import Data.Aeson (FromJSON, ToJSON, ToJSONKey) +import Data.Swagger (ToParamSchema (..), ToSchema (..), + genericDeclareNamedSchema) +import Data.Text (Text) +import Data.UUID (UUID) +import Database.Persist +import Database.Persist.Sql (PersistFieldSql, + runSqlPersistMPool) +import Database.Persist.TH +import GTFS +import PersistOrphans +import Servant (FromHttpApiData) + +import Conduit (ResourceT) +import Control.Monad.IO.Class (MonadIO (liftIO)) +import Control.Monad.Logger (NoLoggingT) +import Control.Monad.Reader (ReaderT) +import Data.Data (Proxy (..)) +import Data.Pool (Pool) +import Data.Time (NominalDiffTime, + UTCTime (utctDay), addUTCTime, + dayOfWeek, diffUTCTime, + getCurrentTime, nominalDay) +import Database.Persist.Postgresql (SqlBackend) +import GHC.Generics (Generic) +import Web.PathPieces (PathPiece) + + + +newtype Token = Token UUID + deriving newtype + ( Show, ToJSON, FromJSON, Eq, Ord, FromHttpApiData + , ToJSONKey, PersistField, PersistFieldSql, PathPiece) +instance ToSchema Token where + declareNamedSchema _ = declareNamedSchema (Proxy @String) +instance ToParamSchema Token where + toParamSchema _ = toParamSchema (Proxy @String) + + +share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase| +RunningTrip sql=tt_tracker_token + Id UUID default=uuid_generate_v4() + expires UTCTime + blocked Bool + tripNumber Text + deriving Eq Show Generic + +TripPing json sql=tt_trip_ping + token Token + latitude Double + longitude Double + delay Double + timestamp UTCTime + deriving Show Generic Eq + +|] + +instance ToSchema TripPing where + declareNamedSchema = genericDeclareNamedSchema (swaggerOptions "ping") + +runSql :: MonadIO m => Pool SqlBackend -> (ReaderT SqlBackend (NoLoggingT (ResourceT IO)) a) -> m a +runSql pool = liftIO . flip runSqlPersistMPool pool diff --git a/lib/Server.hs b/lib/Server.hs index 1b79300..4a78735 100644 --- a/lib/Server.hs +++ b/lib/Server.hs @@ -6,11 +6,13 @@ {-# LANGUAGE FlexibleInstances #-} {-# LANGUAGE GADTs #-} {-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE LambdaCase #-} {-# LANGUAGE MultiParamTypeClasses #-} {-# LANGUAGE QuasiQuotes #-} {-# LANGUAGE RecordWildCards #-} {-# LANGUAGE StandaloneDeriving #-} {-# LANGUAGE TemplateHaskell #-} +{-# LANGUAGE TupleSections #-} {-# LANGUAGE TypeApplications #-} {-# LANGUAGE TypeFamilies #-} {-# LANGUAGE TypeOperators #-} @@ -18,115 +20,120 @@ {-# LANGUAGE UndecidableInstances #-} module Server where -import Servant (Application, - FromHttpApiData (parseUrlPiece), - Server, err404, serve, throwError, - type (:>)) -import Servant.API (Capture, FromHttpApiData, Get, JSON, - Post, ReqBody, type (:<|>) ((:<|>))) -import Servant.Docs (DocCapture (..), DocQueryParam (..), - ParamKind (..), ToCapture (..), - ToParam (..)) - +import Conduit (MonadTrans (lift), ResourceT) import Control.Concurrent.STM -import Control.Monad.IO.Class (MonadIO (liftIO)) -import Data.Aeson (FromJSON (parseJSON), ToJSON (toJSON), - ToJSONKey, genericParseJSON, - genericToJSON) -import qualified Data.Aeson as A -import Data.Functor ((<&>)) -import Data.Map (Map) -import qualified Data.Map as M -import Data.Proxy (Proxy (Proxy)) -import Data.Swagger -import Data.Text (Text) -import Data.Time (UTCTime (utctDay), dayOfWeek, - getCurrentTime) -import Data.UUID (UUID) -import qualified Data.UUID as UUID -import qualified Data.UUID.V4 as UUID -import Data.Vector (Vector) +import Control.Monad (when) +import Control.Monad.Extra (whenM) +import Control.Monad.IO.Class (MonadIO (liftIO)) +import Control.Monad.Logger.CallStack (NoLoggingT) +import Control.Monad.Reader (forM) +import Control.Monad.Trans.Maybe (MaybeT (..)) +import Data.Aeson (FromJSON (parseJSON), + ToJSON (toJSON), ToJSONKey, + genericParseJSON, + genericToJSON) +import qualified Data.Aeson as A +import Data.Coerce (coerce) +import Data.Functor ((<&>)) +import Data.Map (Map) +import qualified Data.Map as M +import Data.Pool (Pool) +import Data.Proxy (Proxy (Proxy)) +import Data.Swagger hiding (get) +import Data.Text (Text) +import Data.Time (NominalDiffTime, + UTCTime (utctDay), addUTCTime, + dayOfWeek, diffUTCTime, + getCurrentTime, nominalDay) +import Data.UUID (UUID) +import qualified Data.UUID as UUID +import qualified Data.UUID.V4 as UUID +import Data.Vector (Vector) import Database.Persist -import Database.Persist.TH -import GHC.Foreign (withCStringsLen) -import GHC.Generics (Generic) +import Database.Persist.Postgresql +import GHC.Generics (Generic) import GTFS -import PersistOrphans -import Servant.Server (Handler) -import Servant.Swagger (toSwagger) - -newtype Token = Token UUID - deriving newtype (Show, ToJSON, Eq, Ord, FromHttpApiData, ToJSONKey) -instance ToSchema Token where - declareNamedSchema _ = declareNamedSchema (Proxy @String) -instance ToParamSchema Token where - toParamSchema _ = toParamSchema (Proxy @String) - -share [mkPersist sqlSettings, mkMigrate "migrateAll"] [persistLowerCase| -TripToken sql=tt_tracker_token - Id UUID default=uuid_generate_v4() - issued UTCTime - blocked Bool - tripNumber Text - deriving Eq Show Generic - -TripPing json sql=tt_trip_ping - token UUID - latitude Double - longitude Double - delay Double - timestamp UTCTime - deriving Show Generic Eq - -|] - - -instance ToSchema TripPing where - declareNamedSchema = genericDeclareNamedSchema (swaggerOptions "ping") - -type KnownTrips = TVar (Map Token [TripPing]) - -type API = "stations" :> Get '[JSON] (Map StationID Station) - :<|> "timetable" :> Capture "Station ID" StationID :> Get '[JSON] (Map TripID (Trip Deep)) - :<|> "trip" :> Capture "Trip ID" TripID :> Get '[JSON] (Trip Deep) - -- ingress API (put this behind BasicAuth?) - -- TODO: perhaps require a first ping for registration? - :<|> "trainregister" :> Capture "Trip ID" TripID :> Post '[JSON] Token - -- TODO: perhaps a websocket instead? - :<|> "trainping" :> Capture "Train Token" Token :> ReqBody '[JSON] TripPing :> Post '[JSON] () - -- debug things - :<|> "debug" :> "state" :> Get '[JSON] (Map Token [TripPing]) -type CompleteAPI = "debug" :> "openapi" :> Get '[JSON] Swagger - :<|> API - - - - - -server :: GTFS -> KnownTrips -> Server CompleteAPI -server gtfs@GTFS{..} knownTrains = handleDebugAPI :<|> handleStations :<|> handleTimetable :<|> handleTrip +import Servant (Application, + FromHttpApiData (parseUrlPiece), + Server, err401, err404, serve, + throwError, type (:>)) +import Servant.API (Capture, FromHttpApiData, Get, + JSON, Post, ReqBody, + type (:<|>) ((:<|>))) +import Servant.Docs (DocCapture (..), + DocQueryParam (..), + ParamKind (..), ToCapture (..), + ToParam (..)) +import Servant.Server (Handler) +import Servant.Swagger (toSwagger) +import Web.PathPieces (PathPiece) + +import API +import Persist + +application :: GTFS -> Pool SqlBackend -> IO Application +application gtfs dbpool = do + doMigration dbpool + pure $ serve (Proxy @CompleteAPI) $ server gtfs dbpool + + + +-- databaseMigration :: ConnectionString -> IO () +doMigration pool = runSql pool $ + -- TODO: before that, check if the uuid module is enabled + -- in sql: check if SELECT * FROM pg_extension WHERE extname = 'uuid-ossp'; + -- returns an empty list + runMigration migrateAll + +server :: GTFS -> Pool SqlBackend -> Server CompleteAPI +server gtfs@GTFS{..} dbpool = handleDebugAPI :<|> handleStations :<|> handleTimetable :<|> handleTrip :<|> handleRegister :<|> handleTripPing :<|> handleDebugState where handleStations = pure stations handleTimetable station = do + -- TODO: resolve "overlay" trips (perhaps just additional CalendarDates?) today <- liftIO getCurrentTime <&> utctDay pure $ tripsOnDay gtfs today handleTrip trip = case M.lookup trip trips of Just res -> pure res Nothing -> throwError err404 - handleRegister tripID = liftIO $ do - token <- UUID.nextRandom <&> Token - atomically $ modifyTVar knownTrains (M.insert token []) - pure token - handleTripPing token ping = liftIO $ atomically $ do - modifyTVar knownTrains (M.update (\history -> Just (ping : history)) token) - pure () - handleDebugState = liftIO $ readTVarIO knownTrains + handleRegister tripID = do + expires <- liftIO $ getCurrentTime <&> addUTCTime validityPeriod + RunningTripKey uuid <- runSql dbpool $ insert (RunningTrip expires False tripID) + pure (Token uuid) + handleTripPing ping = do + checkTokenValid dbpool (tripPingToken ping) + -- TODO: are these always inserted in order? + runSql dbpool $ insert ping + pure () + handleDebugState = do + now <- liftIO $ getCurrentTime + runSql dbpool $ do + running <- selectList [RunningTripBlocked ==. False, RunningTripExpires >=. now] [] + pairs <- forM running $ \(Entity (RunningTripKey uuid) _) -> do + entities <- selectList [TripPingToken ==. Token uuid] [] + pure (Token uuid, fmap entityVal entities) + pure (M.fromList pairs) handleDebugAPI = pure $ toSwagger (Proxy @API) -application :: GTFS -> IO Application -application gtfs = do - knownTrips <- newTVarIO mempty - pure $ serve (Proxy @CompleteAPI) $ server gtfs knownTrips +checkTokenValid :: Pool SqlBackend -> Token -> Handler () +checkTokenValid dbpool token = do + trip <- try $ runSql dbpool $ get (coerce token) + when (runningTripBlocked trip) + $ throwError err401 + whenM (hasExpired (runningTripExpires trip)) + $ throwError err401 + where try m = m >>= \case + Just a -> pure a + Nothing -> throwError err404 + +hasExpired :: MonadIO m => UTCTime -> m Bool +hasExpired limit = do + now <- liftIO getCurrentTime + pure (now > limit) + +validityPeriod :: NominalDiffTime +validityPeriod = nominalDay + -- cgit v1.2.3