(.require
 [library
  [lux (.except)
   [abstract
    [monad (.only do)]]
   [control
    ["[0]" pipe]
    ["[0]" io (.only IO)]
    ["[0]" try (.only Try)]
    ["[0]" exception (.only exception)]]
   [data
    [text
     ["%" \\format (.only format)]]
    [collection
     ["[0]" queue (.only Queue)]]]
   [math
    [number
     ["n" nat]
     ["i" int]]]
   [meta
    [type
     ["[0]" primitive (.except)]
     ["[0]" refinement]]]]]
 [//
  ["[0]" atom (.only Atom)]
  ["[0]" async (.only Async Resolver)]])

... Internal state of a semaphore.
... #open_positions is an Int (not a Nat) because wait! decrements it
... unconditionally, so it may drop below zero while callers are queued.
(type State
  (Record
   [#max_positions Nat
    #open_positions Int
    #waiting_list (Queue (Resolver Any))]))

... A counting semaphore whose state lives inside an Atom.
(primitive .public Semaphore
  (Atom State)

  ... Upper bound for the number of positions: the top of the Int interval,
  ... converted to Nat, so that the count can always be stored as an Int.
  (def most_positions_possible
    (.nat (at i.interval top)))

  ... Creates a semaphore with the given number of open positions
  ... (capped at ..most_positions_possible) and an empty waiting list.
  (def .public (semaphore initial_open_positions)
    (-> Nat Semaphore)
    (let [max_positions (n.min initial_open_positions
                               ..most_positions_possible)]
      (abstraction (atom.atom [#max_positions max_positions
                               #open_positions (.int max_positions)
                               #waiting_list queue.empty]))))

  ... Takes one position from the semaphore.
  ... Returns an Async that is resolved immediately if a position was open,
  ... or later (by signal!) if the caller had to be put on the waiting list.
  (def .public (wait! semaphore)
    (Ex (_ k) (-> Semaphore (Async Any)))
    (let [semaphore (representation semaphore)
          [signal sink] (is [(Async Any) (Resolver Any)]
                            (async.async []))]
      (exec
        (io.run!
         ... <had_open_position?> is a pipeline fragment, applied to a State
         ... whose #open_positions has already been decremented.
         (with_expansions [<had_open_position?> (these (the #open_positions) (i.> -1))]
           (do io.monad
             [... Decrement the count; when no position was open, also
              ... enqueue this caller's resolver on the waiting list.
              [_ state'] (atom.update! (|>> (revised #open_positions --)
                                            (pipe.if [<had_open_position?>]
                                                     []
                                                     [(revised #waiting_list (queue.end sink))]))
                                       semaphore)]
             (with_expansions [<go_ahead> (sink [])
                               <get_in_line> (in false)]
               ... state' is the state after the update.
               (if (|> state' <had_open_position?>)
                 <go_ahead>
                 <get_in_line>)))))
        signal)))

  ... Raised by signal! when every position is already open.
  (exception .public (semaphore_is_maxed_out [max_positions Nat])
    (exception.report
     "Max Positions" (%.nat max_positions)))

  ... Gives one position back to the semaphore.
  ... On success, yields the number of open positions after the update,
  ... and resolves the resolver at the front of the waiting list (if any).
  ... Fails with ..semaphore_is_maxed_out if #open_positions already
  ... equals #max_positions.
  (def .public (signal! semaphore)
    (Ex (_ k) (-> Semaphore (Async (Try Int))))
    (let [semaphore (representation semaphore)]
      (async.future
       (do [! io.monad]
         [[pre post] (atom.update! (function (_ state)
                                     (if (i.= (.int (the #max_positions state))
                                              (the #open_positions state))
                                       ... Maxed out: hand back the very same state,
                                       ... which is detected below via same?.
                                       state
                                       (|> state
                                           (revised #open_positions ++)
                                           (revised #waiting_list queue.next))))
                                   semaphore)]
         (if (same? pre post)
           (in (exception.except ..semaphore_is_maxed_out [(the #max_positions pre)]))
           (do !
             [... The front of the waiting list is read from the state
              ... *before* the update, since the update already removed it.
              _ (case (queue.front (the #waiting_list pre))
                  {.#None}
                  (in true)

                  {.#Some sink}
                  (sink []))]
             (in {try.#Success (the #open_positions post)})))))))
  )

... A mutual-exclusion lock, implemented as a semaphore with one position.
(primitive .public Mutex
  Semaphore

  ... Creates a new mutex. The argument is ignored.
  (def .public (mutex _)
    (-> Any Mutex)
    (abstraction (semaphore 1)))

  ... Takes the mutex's single position (see ..wait!).
  (def acquire!
    (-> Mutex (Async Any))
    (|>> representation ..wait!))

  ... Gives the mutex's single position back (see ..signal!).
  (def release!
    (-> Mutex (Async (Try Int)))
    (|>> representation ..signal!))

  ... Acquires the mutex, runs the procedure, waits for its result,
  ... releases the mutex, and then yields the procedure's result.
  ... The outcome of the release is not inspected.
  (def .public (synchronize! mutex procedure)
    (All (_ a) (-> Mutex (IO (Async a)) (Async a)))
    (do async.monad
      [_ (..acquire! mutex)
       output (io.run! procedure)
       _ (..release! mutex)]
      (in output)))
  )

... Refiner for barrier limits: only Nat values greater than 0 qualify.
(def .public limit
  (refinement.refiner (n.> 0)))

... The type of values refined by ..limit.
(type .public Limit
  (~ (refinement.type limit)))

... A two-phase barrier built from a shared counter and two turnstile semaphores.
(primitive .public Barrier
  (Record
   [#limit Limit
    #count (Atom Nat)
    #start_turnstile Semaphore
    #end_turnstile Semaphore])

  ... Creates a barrier for the given limit.
  ... The counter starts at 0 and both turnstiles start with 0 positions.
  (def .public (barrier limit)
    (-> Limit Barrier)
    (abstraction [#limit limit
                  #count (atom.atom 0)
                  #start_turnstile (..semaphore 0)
                  #end_turnstile (..semaphore 0)]))

  ... Signals the turnstile the given number of times, one after the other.
  ... The Try outcome of each signal! is bound but not inspected.
  (def (un_block! times turnstile)
    (-> Nat Semaphore (Async Any))
    (loop (again [step 0])
      (if (n.< times step)
        (do async.monad
          [outcome (..signal! turnstile)]
          (again (++ step)))
        (at async.monad in []))))

  ... Generates the two phases of the barrier.
  ... Each phase applies <update> to the counter; the caller whose update makes
  ... the counter equal to <goal> signals <turnstile> (limit - 1) times,
  ... while every other caller waits on <turnstile>.
  (with_template [<phase> <update> <goal> <turnstile>]
    [(def (<phase> barrier)
       (-> Barrier (Async Any))
       (do async.monad
         [.let [barrier (representation barrier)
                limit (refinement.value (the #limit barrier))
                goal <goal>
                [_ count] (io.run! (atom.update! <update> (the #count barrier)))
                reached? (n.= goal count)]]
         (if reached?
           (..un_block! (-- limit) (the <turnstile> barrier))
           (..wait! (the <turnstile> barrier)))))]

    [start! ++ limit #start_turnstile]
    [end! -- 0 #end_turnstile]
    )

  ... Passes through the start phase and then the end phase of the barrier.
  (def .public (block! barrier)
    (-> Barrier (Async Any))
    (do async.monad
      [_ (..start! barrier)]
      (..end! barrier)))
  )