From 6685757a03b559b589a8f8909f9c4d635c0d6cef Mon Sep 17 00:00:00 2001 From: Bo QIU <35757521+boqiu@users.noreply.github.com> Date: Thu, 29 Aug 2024 18:35:35 +0800 Subject: [PATCH] notify file announcement to sync layer only if shard config matches (#174) * notify file announcement to sync layer only if shard config matches * use FromStr trait --- node/router/src/libp2p_event_handler.rs | 23 +++- node/src/config/convert.rs | 2 +- node/storage/src/config.rs | 171 +++++++++++++++++------- 3 files changed, 141 insertions(+), 55 deletions(-) diff --git a/node/router/src/libp2p_event_handler.rs b/node/router/src/libp2p_event_handler.rs index 05619b4..ae80c11 100644 --- a/node/router/src/libp2p_event_handler.rs +++ b/node/router/src/libp2p_event_handler.rs @@ -681,6 +681,12 @@ impl Libp2pEventHandler { return MessageAcceptance::Reject; } + // verify announced shard config + let announced_shard_config = match ShardConfig::new(msg.shard_id, msg.num_shard) { + Ok(v) => v, + Err(_) => return MessageAcceptance::Reject, + }; + // propagate gossip to peers let d = duration_since( msg.resend_timestamp, @@ -692,13 +698,16 @@ impl Libp2pEventHandler { return MessageAcceptance::Ignore; } - // notify sync layer - for tx_id in msg.tx_ids.iter() { - self.send_to_sync(SyncMessage::AnnounceFileGossip { - tx_id: *tx_id, - peer_id: msg.peer_id.clone().into(), - addr: addr.clone(), - }); + // notify sync layer if shard config matches + let my_shard_config = self.store.get_store().flow().get_shard_config(); + if my_shard_config.intersect(&announced_shard_config) { + for tx_id in msg.tx_ids.iter() { + self.send_to_sync(SyncMessage::AnnounceFileGossip { + tx_id: *tx_id, + peer_id: msg.peer_id.clone().into(), + addr: addr.clone(), + }); + } } // insert message to cache diff --git a/node/src/config/convert.rs b/node/src/config/convert.rs index dc0113b..450fbd0 100644 --- a/node/src/config/convert.rs +++ b/node/src/config/convert.rs @@ -253,6 +253,6 @@ impl ZgsConfig { } fn shard_config(&self) -> Result { - ShardConfig::new(&self.shard_position) + self.shard_position.clone().try_into() } } diff --git a/node/storage/src/config.rs b/node/storage/src/config.rs index 7b73075..deccf57 100644 --- a/node/storage/src/config.rs +++ b/node/storage/src/config.rs @@ -1,6 +1,6 @@ use serde::{Deserialize, Serialize}; use ssz_derive::{Decode, Encode}; -use std::{cell::RefCell, path::PathBuf, rc::Rc}; +use std::{cell::RefCell, path::PathBuf, rc::Rc, str::FromStr}; pub const SHARD_CONFIG_KEY: &str = "shard_config"; @@ -25,51 +25,11 @@ impl Default for ShardConfig { } } -impl ShardConfig { - pub fn new(shard_position: &Option) -> Result { - let (id, num) = if let Some(position) = shard_position { - Self::parse_position(position)? - } else { - (0, 1) - }; +impl FromStr for ShardConfig { + type Err = String; - if id >= num { - return Err(format!( - "Incorrect shard_id: expected [0, {}), actual {}", - num, id - )); - } - - if !num.is_power_of_two() { - return Err(format!( - "Incorrect shard group bytes: {}, should be power of two", - num - )); - } - Ok(ShardConfig { - shard_id: id, - num_shard: num, - }) - } - - pub fn miner_shard_mask(&self) -> u64 { - !(self.num_shard - 1) as u64 - } - - pub fn miner_shard_id(&self) -> u64 { - self.shard_id as u64 - } - - pub fn is_valid(&self) -> bool { - self.num_shard > 0 && self.num_shard.is_power_of_two() && self.shard_id < self.num_shard - } - - pub fn in_range(&self, segment_index: u64) -> bool { - segment_index as usize % self.num_shard == self.shard_id - } - - pub fn parse_position(input: &str) -> Result<(usize, usize), String> { - let parts: Vec<&str> = input.trim().split('/').map(|s| s.trim()).collect(); + fn from_str(s: &str) -> Result { + let parts: Vec<&str> = s.trim().split('/').map(|s| s.trim()).collect(); if parts.len() != 2 { return Err("Incorrect format, expected like: '0 / 8'".into()); @@ -82,7 +42,66 @@ impl ShardConfig { .parse::() .map_err(|e| format!("Cannot parse shard position {:?}", e))?; - Ok((numerator, denominator)) + Self::new(numerator, denominator) + } +} + +impl TryFrom> for ShardConfig { + type Error = String; + + fn try_from(value: Option) -> Result { + if let Some(position) = value { + Self::from_str(&position) + } else { + Ok(Self::default()) + } + } +} + +impl ShardConfig { + pub fn new(id: usize, num: usize) -> Result { + let config = ShardConfig { + shard_id: id, + num_shard: num, + }; + + config.validate()?; + + Ok(config) + } + + pub fn miner_shard_mask(&self) -> u64 { + !(self.num_shard - 1) as u64 + } + + pub fn miner_shard_id(&self) -> u64 { + self.shard_id as u64 + } + + pub fn validate(&self) -> Result<(), String> { + if self.shard_id >= self.num_shard { + return Err(format!( + "Incorrect shard_id: expected [0, {}), actual {}", + self.num_shard, self.shard_id + )); + } + + if self.num_shard == 0 { + return Err("Shard num is 0".into()); + } + + if !self.num_shard.is_power_of_two() { + return Err(format!( + "Incorrect shard group bytes: {}, should be power of two", + self.num_shard + )); + } + + Ok(()) + } + + pub fn in_range(&self, segment_index: u64) -> bool { + segment_index as usize % self.num_shard == self.shard_id } pub fn next_segment_index(&self, current: usize, start_index: usize) -> usize { @@ -90,6 +109,30 @@ impl ShardConfig { let shift = (start_index + current + self.num_shard - self.shard_id) % self.num_shard; current + self.num_shard - shift } + + /// Whether `self` intersect with the `other` shard config. + pub fn intersect(&self, other: &ShardConfig) -> bool { + let ShardConfig { + num_shard: mut left_num_shard, + shard_id: mut left_shard_id, + } = self; + let ShardConfig { + num_shard: mut right_num_shard, + shard_id: mut right_shard_id, + } = other; + + while left_num_shard != right_num_shard { + if left_num_shard < right_num_shard { + right_num_shard /= 2; + right_shard_id /= 2; + } else { + left_num_shard /= 2; + left_shard_id /= 2; + } + } + + left_shard_id == right_shard_id + } } struct ShardSegmentTreeNode { @@ -146,7 +189,7 @@ impl ShardSegmentTreeNode { pub fn all_shards_available(shard_configs: Vec) -> bool { let mut root = ShardSegmentTreeNode::new(1); for shard_config in shard_configs.iter() { - if !shard_config.is_valid() { + if shard_config.validate().is_err() { continue; } root.insert(shard_config.num_shard, shard_config.shard_id); @@ -163,6 +206,10 @@ mod tests { use super::ShardConfig; + fn new_config(id: usize, num: usize) -> ShardConfig { + ShardConfig::new(id, num).unwrap() + } + #[test] fn test_all_shards_available() { assert!(all_shards_available(vec![ @@ -210,4 +257,34 @@ mod tests { }, ])); } + + #[test] + fn test_shard_intersect() { + // 1 shard + assert_eq!(new_config(0, 1).intersect(&new_config(0, 1)), true); + + // either is 1 shard + assert_eq!(new_config(0, 1).intersect(&new_config(0, 2)), true); + assert_eq!(new_config(0, 1).intersect(&new_config(1, 2)), true); + assert_eq!(new_config(0, 2).intersect(&new_config(0, 1)), true); + assert_eq!(new_config(1, 2).intersect(&new_config(0, 1)), true); + + // same shards + assert_eq!(new_config(1, 4).intersect(&new_config(0, 4)), false); + assert_eq!(new_config(1, 4).intersect(&new_config(1, 4)), true); + assert_eq!(new_config(1, 4).intersect(&new_config(2, 4)), false); + assert_eq!(new_config(1, 4).intersect(&new_config(3, 4)), false); + + // left shards is less + assert_eq!(new_config(1, 2).intersect(&new_config(0, 4)), false); + assert_eq!(new_config(1, 2).intersect(&new_config(1, 4)), false); + assert_eq!(new_config(1, 2).intersect(&new_config(2, 4)), true); + assert_eq!(new_config(1, 2).intersect(&new_config(3, 4)), true); + + // right shards is less + assert_eq!(new_config(1, 4).intersect(&new_config(0, 2)), true); + assert_eq!(new_config(1, 4).intersect(&new_config(1, 2)), false); + assert_eq!(new_config(2, 4).intersect(&new_config(0, 2)), false); + assert_eq!(new_config(2, 4).intersect(&new_config(1, 2)), true); + } }