use crate::error::Error; use std::time::Duration; use tokio::sync::mpsc::error::TryRecvError; use tokio::sync::{mpsc, oneshot}; use tokio::time::timeout; const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(3); pub type ResponseSender = oneshot::Sender; #[derive(Debug)] pub enum Message { Notification(N), Request(Req, ResponseSender), } pub struct Channel { _phantom: std::marker::PhantomData<(N, Req, Res)>, } impl Channel { pub fn unbounded() -> (Sender, Receiver) { let (sender, receiver) = mpsc::unbounded_channel(); (Sender { chan: sender }, Receiver { chan: receiver }) } } pub struct Sender { chan: mpsc::UnboundedSender>, } impl Clone for Sender { fn clone(&self) -> Self { Sender { chan: self.chan.clone(), } } } impl Sender { pub fn notify(&self, msg: N) -> Result<(), Error> { self.chan .send(Message::Notification(msg)) .map_err(|e| Error::SendError(e)) } pub async fn request(&self, request: Req) -> Result> { let (sender, receiver) = oneshot::channel(); self.chan .send(Message::Request(request, sender)) .map_err(|e| Error::SendError(e))?; timeout(DEFAULT_REQUEST_TIMEOUT, receiver) .await .map_err(|_| Error::TimeoutError)? .map_err(|e| Error::RecvError(e)) } } pub struct Receiver { chan: mpsc::UnboundedReceiver>, } impl Receiver { pub async fn recv(&mut self) -> Option> { self.chan.recv().await } pub fn try_recv(&mut self) -> Result, TryRecvError> { self.chan.try_recv() } } #[cfg(test)] mod tests { use super::*; #[derive(Debug)] enum Notification {} #[derive(Debug)] enum Request { GetNumber, } #[derive(Debug, PartialEq, Eq)] enum Response { GetNumber(u32), } #[tokio::test] async fn request_response() { let (tx, mut rx) = Channel::::unbounded(); let task1 = async move { match rx.recv().await.expect("not dropped") { Message::Notification(_) => {} Message::Request(Request::GetNumber, sender) => { sender.send(Response::GetNumber(42)).expect("not dropped"); } } }; let task2 = async move { let result = tx.request(Request::GetNumber).await.expect("not dropped"); assert_eq!(result, Response::GetNumber(42)); }; tokio::join!(task1, task2); } }