// Typed in-process message channel: a sender either fires a notification, or issues a
// request and awaits its reply through a per-request oneshot channel with a fixed timeout.
//
// NOTE(review): in this copy the generic parameter/argument lists appear to have been
// stripped. For example, `N`, `Req` and `Res` are used but never declared; `Sender>`,
// `Arc>` and `Result>` have unmatched `>`; and the test calls `Channel::::unbounded`.
// As shown this cannot compile. Restore the `<...>` lists from the original source.
use crate::error::Error;
use crate::metrics::unbounded_channel;
// NOTE(review): `Counter` is presumably imported so its `inc` method is in scope for the
// timeout counter in `request`. Confirm.
use metrics::{Counter, CounterUsize};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::oneshot;
use tokio::time::timeout;

/// How long `Sender::request` waits for a reply before failing with `Error::TimeoutError`.
const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(3);

/// Reply handle carried inside a `Message::Request`.
///
/// The receiving side uses it to send exactly one response back to the caller of
/// `Sender::request`.
pub type ResponseSender = oneshot::Sender;

/// An item travelling through the channel.
#[derive(Debug)]
pub enum Message {
    /// Fire-and-forget payload; it carries no reply handle.
    Notification(N),
    /// A request together with the one-shot handle the receiver should use to reply.
    Request(Req, ResponseSender),
}

/// Zero-sized namespace type whose only job is to construct a connected
/// (`Sender`, `Receiver`) pair; it stores nothing but `PhantomData`.
pub struct Channel {
    _phantom: std::marker::PhantomData<(N, Req, Res)>,
}

impl Channel {
    /// Creates an unbounded channel and returns its sending and receiving halves.
    ///
    /// `name` is embedded in the metrics group `"common_channel_{name}"`.
    /// - That group is passed to `crate::metrics::unbounded_channel`. Presumably it records
    ///   queue metrics there; confirm.
    /// - A `"timeout"` counter is also registered in that group. `Sender::request`
    ///   increments this counter each time a request times out.
    pub fn unbounded(name: &str) -> (Sender, Receiver) {
        let metrics_group = format!("common_channel_{}", name);
        let (sender, receiver) = unbounded_channel(metrics_group.as_str());
        let metrics_timeout = CounterUsize::register_with_group(metrics_group.as_str(), "timeout");
        (
            Sender {
                chan: sender,
                metrics_timeout,
            },
            receiver,
        )
    }
}

/// Sending half of the channel.
///
/// Each clone holds a clone of the underlying sender and shares the same timeout
/// counter through the `Arc`.
pub struct Sender {
    /// Underlying metrics-instrumented sender that carries `Message`s.
    chan: crate::metrics::Sender>,
    /// Counter incremented by 1 for every `request` that times out.
    metrics_timeout: Arc>,
}

// Written by hand rather than derived.
// NOTE(review): presumably this avoids the `Clone` bounds on the type parameters that
// `#[derive(Clone)]` would add. Confirm.
impl Clone for Sender {
    fn clone(&self) -> Self {
        Sender {
            chan: self.chan.clone(),
            metrics_timeout: self.metrics_timeout.clone(),
        }
    }
}

impl Sender {
    /// Enqueues a notification without waiting for any reply.
    ///
    /// # Errors
    /// Returns `Error::SendError` wrapping the underlying send error if the channel
    /// rejects the message. NOTE(review): presumably this happens once the receiver is
    /// dropped. Confirm against `crate::metrics::Sender::send`.
    pub fn notify(&self, msg: N) -> Result<(), Error> {
        self.chan
            .send(Message::Notification(msg))
            .map_err(|e| Error::SendError(e))
    }

    /// Sends `request` and waits up to `DEFAULT_REQUEST_TIMEOUT` for its response.
    ///
    /// A fresh oneshot channel is created for every call. Its sending half travels with
    /// the request so that the receiver can reply.
    ///
    /// # Errors
    /// - `Error::SendError` if the request cannot be enqueued.
    /// - `Error::TimeoutError` if no reply arrives in time. The `"timeout"` counter is
    ///   incremented by 1 in that case.
    /// - `Error::RecvError` if the reply handle is dropped without a response being sent.
    pub async fn request(&self, request: Req) -> Result> {
        let (sender, receiver) = oneshot::channel();
        self.chan
            .send(Message::Request(request, sender))
            .map_err(|e| Error::SendError(e))?;
        // The first `map_err` + `?` handles an elapsed timeout. The trailing `map_err`
        // converts the oneshot receive error (reply handle dropped unanswered).
        // On timeout only the receiver is dropped; the request is not withdrawn from the
        // queue, so a late reply from the handler will simply fail to send.
        timeout(DEFAULT_REQUEST_TIMEOUT, receiver)
            .await
            .map_err(|_| {
                self.metrics_timeout.inc(1);
                Error::TimeoutError
            })?
            .map_err(|e| Error::RecvError(e))
    }
}

/// Receiving half: the metrics-instrumented receiver of `Message`s.
///
/// Whoever owns it should answer each `Message::Request` through its `ResponseSender`.
/// Otherwise the requester sees `Error::RecvError` or `Error::TimeoutError`.
pub type Receiver = crate::metrics::Receiver>;

#[cfg(test)]
mod tests {
    use super::*;

    // Uninhabited: this test never sends notifications.
    #[derive(Debug)]
    enum Notification {}

    #[derive(Debug)]
    enum Request {
        GetNumber,
    }

    #[derive(Debug, PartialEq, Eq)]
    enum Response {
        GetNumber(u32),
    }

    // Round trip: one future serves a single request, the other issues it and checks the reply.
    #[tokio::test]
    async fn request_response() {
        let (tx, mut rx) = Channel::::unbounded("test");
        // Handler side: receive one message and answer `GetNumber` with 42.
        let task1 = async move {
            match rx.recv().await.expect("not dropped") {
                Message::Notification(_) => {}
                Message::Request(Request::GetNumber, sender) => {
                    sender.send(Response::GetNumber(42)).expect("not dropped");
                }
            }
        };
        // Client side: issue the request and assert on the response.
        let task2 = async move {
            let result = tx.request(Request::GetNumber).await.expect("not dropped");
            assert_eq!(result, Response::GetNumber(42));
        };
        // Both futures are driven concurrently within the current task.
        tokio::join!(task1, task2);
    }
}