
341 lines
16 KiB

use tokio::sync::oneshot;
use crate::{AppData, AppDataResponse, AppError, NodeId, RaftNetwork, RaftStorage};
use crate::config::SnapshotPolicy;
use crate::error::RaftResult;
use crate::core::{ConsensusState, LeaderState, ReplicationState, SnapshotState, State, UpdateCurrentLeader};
use crate::replication::{RaftEvent, ReplicaEvent, ReplicationStream};
use crate::storage::CurrentSnapshotData;
// Replication-facing methods of the leader state: spawning replication streams and
// reacting to the events those streams report back.
impl<'a, D: AppData, R: AppDataResponse, E: AppError, N: RaftNetwork<D, E>, S: RaftStorage<D, R, E>> LeaderState<'a, D, R, E, N, S> {
/// Spawn a new replication stream for `target`, returning its replication state handle.
///
/// The stream is built from the leader's current term, a clone of its config, its last log
/// index/term, its commit index and a clone of the leader's replication event sender.
/// The returned handle starts with `match_index` set to the leader's last log index,
/// `is_at_line_rate` set to `false`, and no `remove_after_commit` threshold.
#[tracing::instrument(level="debug", skip(self))]
pub(super) fn spawn_replication_stream(&self, target: NodeId) -> ReplicationState<D> {
// NOTE(review): the argument list below has empty slots (`new(,` and `,,,`) and is never
// closed, so some arguments appear to be missing from this listing — restore them from the
// original file before building.
let replstream = ReplicationStream::new(, target, self.core.current_term, self.core.config.clone(),
self.core.last_log_index, self.core.last_log_term, self.core.commit_index,,, self.replicationtx.clone(),
// Wrap the stream in the bookkeeping record the leader keeps per replication target.
ReplicationState{match_index: self.core.last_log_index, is_at_line_rate: false, replstream, remove_after_commit: None}
/// Handle a replication event coming from one of the replication streams.
///
/// Each event variant is dispatched to its dedicated handler. This method returns nothing:
/// an error produced by a handler is logged via `tracing::error!` rather than propagated.
#[tracing::instrument(level="trace", skip(self, event))]
pub(super) async fn handle_replica_event(&mut self, event: ReplicaEvent<S::Snapshot>) {
// Every handler yields a `RaftResult<(), E>`, collected into `res` for the check below.
let res = match event {
ReplicaEvent::RateUpdate{target, is_line_rate} => self.handle_rate_update(target, is_line_rate).await,
ReplicaEvent::RevertToFollower{target, term} => self.handle_revert_to_follower(target, term).await,
ReplicaEvent::UpdateMatchIndex{target, match_index} => self.handle_update_match_index(target, match_index).await,
ReplicaEvent::NeedsSnapshot{target, tx} => self.handle_needs_snapshot(target, tx).await,
// NOTE(review): the body of the `Shutdown` arm and the end of the `match` are not visible
// in this listing — confirm against the original file.
ReplicaEvent::Shutdown => {
// Handler failures are reported in the log only.
if let Err(err) = res {
tracing::error!({error=%err}, "error while processing event from replication stream");
/// Handle events from replication streams updating their replication rate tracker.
///
/// - `target`: the node whose replication stream reported the change.
/// - `is_line_rate`: the new value to record in the target's `is_at_line_rate` flag.
///
/// A target found in `self.nodes` only has its flag updated. A target found in
/// `self.non_voters` additionally gets marked as ready to join, has any pending response
/// channel answered with `Ok(())`, and may trigger the pending membership change when the
/// leader is in the `NonVoterSync` consensus state.
#[tracing::instrument(level="trace", skip(self, target, is_line_rate))]
async fn handle_rate_update(&mut self, target: NodeId, is_line_rate: bool) -> RaftResult<(), E> {
// Get a handle to the target's replication state & update it as needed.
if let Some(state) = self.nodes.get_mut(&target) {
state.is_at_line_rate = is_line_rate;
return Ok(());
// Else, if this is a non-voter, then update as needed.
if let Some(state) = self.non_voters.get_mut(&target) {
state.state.is_at_line_rate = is_line_rate;
state.is_ready_to_join = true;
// Issue a response on the non-voter's response channel if needed.
// `take()` leaves `None` behind, so the channel is answered at most once.
if let Some(tx) = state.tx.take() {
let _ = tx.send(Ok(()));
// If we are in NonVoterSync state, and this is one of the nodes being awaited, then update.
// The state is moved out (temporarily leaving `Uniform` in its place) so that its fields can
// be taken by value; every arm below writes a state back.
match std::mem::replace(&mut self.consensus_state, ConsensusState::Uniform) {
// NOTE(review): `awaiting` is bound mutably but no statement modifying it is visible here;
// presumably a line removing `target` from it is missing from this listing — confirm.
ConsensusState::NonVoterSync{mut awaiting, members, tx} => {
if awaiting.is_empty() {
// We are ready to move forward with entering joint consensus.
self.consensus_state = ConsensusState::Uniform;
self.change_membership(members, tx).await;
} else {
// We are still awaiting additional nodes, so replace our original state.
self.consensus_state = ConsensusState::NonVoterSync{awaiting, members, tx};
other => self.consensus_state = other, // Set the original value back to what it was.
/// Handle events from replication streams for when this node needs to revert to follower state.
///
/// - the first (ignored) argument is the id of the node whose stream raised the event.
/// - `term`: the term reported by that stream.
///
/// The leader's current term is only updated when `term` is strictly greater than it.
#[tracing::instrument(level="trace", skip(self, term))]
async fn handle_revert_to_follower(&mut self, _: NodeId, term: u64) -> RaftResult<(), E> {
if &term > &self.core.current_term {
// Adopt the newer term; `None` is passed as the second argument of `update_current_term`.
// NOTE(review): the remainder of this branch (and the function's return value) is not
// visible in this listing — confirm against the original file.
self.core.update_current_term(term, None);
/// Handle events from a replication stream which updates the target node's match index.
///
/// - `target`: the node whose replication stream reported progress.
/// - `match_index`: the new match index to record for that node.
///
/// Outline of the visible logic:
/// 1. A non-voter only has its match index recorded, then the method returns.
/// 2. A voting node has its match index recorded; if it carries a `remove_after_commit`
///    threshold that `match_index` has reached, its entry is removed and its stream is sent
///    `RaftEvent::Terminate`.
/// 3. A candidate commit index is computed for the current membership (C0) and, during joint
///    consensus, for the upcoming membership (C1).
/// 4. Only when both candidates exceed the current commit index is the commit index advanced
///    (to the smaller of the two) and broadcast to every replication stream.
/// 5. Requests in `awaiting_committed` whose entry index is covered by the new commit index
///    are drained for processing.
///
/// NOTE(review): several statements (iterator `.collect()` calls, the terminal of the
/// `filter` chain, the body of the drain loop, closing braces) are not visible in this
/// listing — confirm against the original file.
#[tracing::instrument(level="trace", skip(self, target, match_index))]
async fn handle_update_match_index(&mut self, target: NodeId, match_index: u64) -> RaftResult<(), E> {
// If this is a non-voter, then update and return.
if let Some(state) = self.non_voters.get_mut(&target) {
state.state.match_index = match_index;
return Ok(());
// Update target's match index & check if it is awaiting removal.
let mut needs_removal = false;
match self.nodes.get_mut(&target) {
Some(state) => {
state.match_index = match_index;
// A node flagged with `remove_after_commit` is dropped once it has replicated up to
// (or past) that threshold.
if let Some(threshold) = &state.remove_after_commit {
if &match_index >= threshold {
needs_removal = true;
_ => return Ok(()), // Node not found.
// Drop replication stream if needed.
if needs_removal {
if let Some(node) = self.nodes.remove(&target) {
// The send result is deliberately ignored.
let _ = node.replstream.repltx.send(RaftEvent::Terminate);
// Determine the new commit index of the current membership config nodes.
// Only nodes listed in the current membership's `members` contribute their match index.
let mut indices_c0 = self.nodes.iter()
.filter(|(id, _)| self.core.membership.members.contains(id))
.map(|(_, node)| node.match_index)
// NOTE(review): the body of this branch is not visible; presumably the leader's own index
// is added to `indices_c0` unless it is stepping down — confirm.
if !self.is_stepping_down {
let commit_index_c0 = calculate_new_commit_index(indices_c0, self.core.commit_index);
// If we are in joint consensus, then calculate the new commit index of the new membership config nodes.
let mut commit_index_c1 = commit_index_c0; // Defaults to just matching C0.
if let Some(members) = &self.core.membership.members_after_consensus {
let indices_c1 = self.nodes.iter()
.filter(|(id, _)| members.contains(id))
.map(|(_, node)| node.match_index)
commit_index_c1 = calculate_new_commit_index(indices_c1, self.core.commit_index);
// Determine if we have a new commit index, accounting for joint consensus.
// If a new commit index has been established, then update a few needed elements.
// Both configurations must have advanced past the current commit index.
let has_new_commit_index = &commit_index_c0 > &self.core.commit_index && &commit_index_c1 > &self.core.commit_index;
if has_new_commit_index {
// Take the smaller candidate so the new index is covered by both configurations.
self.core.commit_index = std::cmp::min(commit_index_c0, commit_index_c1);
// Update all replication streams based on new commit index.
for node in self.nodes.values() {
let _ = node.replstream.repltx.send(RaftEvent::UpdateCommitIndex{commit_index: self.core.commit_index});
for node in self.non_voters.values() {
let _ = node.state.replstream.repltx.send(RaftEvent::UpdateCommitIndex{commit_index: self.core.commit_index});
// Check if there are any pending requests which need to be processed.
// Walks the queue from the front while entries are at or below the commit index, yielding
// their positions.
let filter = self.awaiting_committed.iter().enumerate()
.take_while(|(_idx, elem)| &elem.entry.index <= &self.core.commit_index)
.map(|(idx, _)| idx);
// `offset` is used as an inclusive upper bound for the drain below.
if let Some(offset) = filter {
// Build a new ApplyLogsTask from each of the given client requests.
for request in self.awaiting_committed.drain(..=offset).collect::<Vec<_>>() {
/// Handle events from replication streams requesting for snapshot info.
///
/// - the first (ignored) argument is the id of the node whose stream raised the request.
/// - `tx`: one-shot channel on which the current snapshot data is sent back to the stream.
///
/// Outline of the visible logic:
/// 1. Read the `LogsSinceLast` threshold from the configured snapshot policy.
/// 2. If a current snapshot exists and lies within half of that threshold from the leader's
///    last log index, send it on `tx` and return.
/// 3. Otherwise, if a snapshot is already being built, spawn a task that waits on the
///    snapshot's broadcast channel, put the snapshot state back, and return.
/// 4. Otherwise fall through to requesting a snapshot.
///
/// NOTE(review): the storage call producing `current_snapshot_opt`, the fallthrough arm of
/// the policy `match`, and the tail of the function are not visible in this listing —
/// confirm against the original file.
#[tracing::instrument(level="trace", skip(self, tx))]
async fn handle_needs_snapshot(&mut self, _: NodeId, tx: oneshot::Sender<CurrentSnapshotData<S::Snapshot>>) -> RaftResult<(), E> {
// Ensure snapshotting is configured, else do nothing.
let threshold = match &self.core.config.snapshot_policy {
SnapshotPolicy::LogsSinceLast(threshold) => *threshold,
// Check for existence of current snapshot.
// A storage failure is converted through `map_fatal_storage_result` and propagated with `?`.
let current_snapshot_opt =
.map_err(|err| self.core.map_fatal_storage_result(err))?;
if let Some(snapshot) = current_snapshot_opt {
// If snapshot exists, ensure its distance from the leader's last log index is <= half
// of the configured snapshot threshold, else create a new snapshot.
if snapshot_is_within_half_of_threshold(&snapshot.index, &self.core.last_log_index, &threshold) {
// The send result is deliberately ignored.
let _ = tx.send(snapshot);
return Ok(());
// Check if snapshot creation is already in progress.
// `take()` moves the state out; the `Snapshotting` arm below puts it back.
match self.core.snapshot_state.take() {
// If so, we spawn a task to await its completion (or cancellation), and respond to
// the replication stream. The repl stream will wait for the completion and will then
// send another request to fetch the finished snapshot.
Some(SnapshotState::Snapshotting{through, handle, sender}) => {
let mut chan = sender.subscribe();
tokio::spawn(async move {
// Only the arrival of a message matters here; its value is discarded.
let _ = chan.recv().await;
self.core.snapshot_state = Some(SnapshotState::Snapshotting{through, handle, sender});
return Ok(());
_ => (), // Else we just drop any other state and continue. Leaders never enter `Streaming` state.
// At this point, we just attempt to request a snapshot. Under normal circumstances, the
// leader will always be keeping up-to-date with its snapshotting, and the latest snapshot
// will always be found and this block will never even be executed.
// If this block is executed, and a snapshot is needed, the repl stream will submit another
// request here shortly, and will hit the above logic where it will await the snapshot completion.
/// Determine the value for `current_commit` based on all known indices of the cluster members.
///
/// - `entries`: is a vector of all of the highest known indices to be replicated on a target node,
/// one per node of the cluster, including the leader as long as the leader is not stepping down.
/// - `current_commit`: is the Raft node's `current_commit` value before invoking this function.
/// The output of this function will never be less than this value.
///
/// Returns the highest index which is known to be present on a majority of the given
/// nodes, clamped so that it never falls below `current_commit`. An empty `entries`
/// yields `current_commit` unchanged.
///
/// NOTE: there are a few edge cases accounted for in this routine which will never practically
/// be hit, but they are accounted for in the name of good measure.
fn calculate_new_commit_index(mut entries: Vec<u64>, current_commit: u64) -> u64 {
    // With no reporting nodes there is nothing to advance to.
    if entries.is_empty() {
        return current_commit;
    }

    // The majority element can only be picked by position once the indices are in ascending
    // order. Stability is irrelevant for plain integers, so the unstable sort is used.
    entries.sort_unstable();

    // In ascending order, the element at `(len - 1) / 2` is matched or exceeded by a majority
    // of the entries: e.g. 3 of 5 for an odd length, 4 of 6 for an even length. For a single
    // entry the offset is 0, i.e. that entry itself.
    let offset = (entries.len() - 1) / 2;

    // Never move the commit index backwards.
    entries[offset].max(current_commit)
}
/// Check if the given snapshot data is within half of the configured threshold.
///
/// - `snapshot_last_index`: the last log index covered by the snapshot.
/// - `last_log_index`: the node's last log index.
/// - `threshold`: the configured snapshot threshold (number of logs).
///
/// Returns `true` when the distance from `snapshot_last_index` up to `last_log_index` is at
/// most `threshold / 2` (integer division). A snapshot index ahead of the last log index is
/// treated as a distance of zero.
fn snapshot_is_within_half_of_threshold(snapshot_last_index: &u64, last_log_index: &u64, threshold: &u64) -> bool {
    // Saturating subtraction guards against underflow when the snapshot is ahead of the log.
    let distance_from_line = last_log_index.saturating_sub(*snapshot_last_index);
    let half_of_threshold = *threshold / 2;
    distance_from_line <= half_of_threshold
}
// Unit tests for the two free helper functions of this file.
// NOTE(review): the macro invocation wrappers (test names), any `#[test]` / `#[cfg(test)]`
// attributes and the closing braces are not visible in this listing — only the argument
// rows of each invocation remain. Confirm against the original file.
mod tests {
use super::*;
// snapshot_is_within_half_of_threshold //////////////////////////////////
mod snapshot_is_within_half_of_threshold {
use super::*;
// Generates one test function per invocation: calls the helper with the given arguments
// and asserts the boolean result equals `expected`.
macro_rules! test_snapshot_is_within_half_of_threshold {
({test=>$name:ident, snapshot_last_index=>$snapshot_last_index:expr, last_log_index=>$last_log:expr, threshold=>$thresh:expr, expected=>$exp:literal}) => {
fn $name() {
let res = snapshot_is_within_half_of_threshold($snapshot_last_index, $last_log, $thresh);
assert_eq!(res, $exp)
// Distance 50 against half-threshold 250: within range.
snapshot_last_index=>&50, last_log_index=>&100, threshold=>&500, expected=>true
// Distance 499 against half-threshold 50: out of range.
snapshot_last_index=>&1, last_log_index=>&500, threshold=>&100, expected=>false
// Snapshot index ahead of the last log index: distance is treated as 0.
snapshot_last_index=>&200, last_log_index=>&100, threshold=>&500, expected=>true
// calculate_new_commit_index ////////////////////////////////////////////
mod calculate_new_commit_index {
use super::*;
// Generates one test function per invocation. Argument order: test name, expected output,
// current commit index, vector of per-node indices.
macro_rules! test_calculate_new_commit_index {
($name:ident, $expected:literal, $current:literal, $entries:expr) => {
fn $name() {
let mut entries = $entries;
let output = calculate_new_commit_index(entries.clone(), $current);
assert_eq!(output, $expected, "Sorted values: {:?}", entries);
// Each row below reads: expected, current commit, entries.
10, 5, vec![20, 5, 0, 15, 10]
20, 20, vec![]
100, 0, vec![100]
100, 100, vec![50]
0, 0, vec![0, 100, 0, 100, 0, 100]
100, 0, vec![0, 100, 0, 100, 0, 100, 100]