update error handling for connecting to websockets

This commit is contained in:
Diego Prats 2024-11-12 19:12:14 -08:00
parent 0e1288b2c6
commit 16a57fdae1

View File

@ -135,16 +135,60 @@ async fn main() {
json!({"prover_id": prover_id}), json!({"prover_id": prover_id}),
); );
let (mut client, _) = tokio_tungstenite::connect_async(&ws_addr_string) // This function connects to the Orchestrator via websockets
.await // and returns the connected client
.unwrap(); async fn connect_to_orchestrator(ws_addr: &str) -> Result<WebSocketStream<MaybeTlsStream<TcpStream>>, Box<dyn std::error::Error>> {
track( // Connect to the Orchestrator via websockets
"connected".into(), let (client, _) = tokio_tungstenite::connect_async(ws_addr)
"Connected.".into(), .await
&ws_addr_string, // If the connection fails, print an error and return the error
json!({"prover_id": prover_id}), .map_err(|e| {
); eprintln!("Failed to connect to orchestrator at {}: {}", ws_addr, e);
e
})?;
// Return the connected client
Ok(client)
}
/// This function wraps connect_to_orchestrator and retries
/// with exponential backoff if the connection fails
async fn connect_to_orchestrator_with_retry(ws_addr: &str) -> WebSocketStream<MaybeTlsStream<TcpStream>> {
let mut attempt = 1;
loop {
match connect_to_orchestrator(ws_addr).await {
Ok(client) => {
track(
"connected".into(),
"Connected.".into(),
&ws_addr_string,
json!({"prover_id": prover_id}),
);
return client;
},
Err(e) => {
eprintln!(
"Could not connect to orchestrator (attempt {}). Retrying in {} seconds...",
attempt,
2u64.pow(attempt.min(6)), // Cap exponential backoff at 64 seconds
);
// Exponential backoff
tokio::time::sleep(
tokio::time::Duration::from_secs(2u64.pow(attempt.min(6)))
).await;
attempt += 1;
}
}
}
}
// Connect to the Orchestrator with exponential backoff
let mut client = connect_to_orchestrator_with_retry(&ws_addr_string).await;
let registration = ProverRequest { let registration = ProverRequest {
contents: Some(prover_request::Contents::Registration( contents: Some(prover_request::Contents::Registration(