use super::tss::{Tss, TssAction, TssPeerId, VerifiableSecretSharingCommitment};
use crate::admin::AdminMsg;
use crate::network::{Message, Network, PeerId, TssMessage};
use crate::runtime::Runtime;
use crate::tasks::{TaskExecutor, TaskParams};
use anyhow::Result;
use futures::future::join_all;
use futures::SinkExt;
use futures::{
channel::{mpsc, oneshot},
future::poll_fn,
stream::FuturesUnordered,
Future, FutureExt, Stream, StreamExt,
};
use polkadot_sdk::sp_runtime::BoundedVec;
use std::sync::Arc;
use std::{
collections::{BTreeMap, BTreeSet, HashMap},
path::PathBuf,
pin::Pin,
task::Poll,
};
use time_primitives::{
BlockHash, BlockNumber, Commitment, ShardId, ShardStatus, TaskId, TssSignature,
TssSigningRequest,
};
use tracing::{event, span, Level, Span};
/// Fallback heartbeat period, expressed in blocks.
///
/// `TimeWorker::run` uses this value when the heartbeat timeout cannot be
/// fetched from the runtime, or when the fetched timeout is not greater than 1.
const DEFAULT_HEARTBEAT_PERIOD: BlockNumber = 100;
/// Everything needed to construct a [`TimeWorker`].
///
/// `Tx` is the outbound network handle and `Rx` is the inbound message
/// stream; their trait bounds are declared on the `TimeWorker` impl.
pub struct TimeWorkerParams<Tx, Rx> {
/// Runtime handle used for chain queries and for submitting extrinsics.
pub substrate: Arc<dyn Runtime>,
/// Parameters cloned into every per-shard `TaskExecutor`.
pub task_params: TaskParams,
/// Outbound network handle used to send TSS messages to peers.
pub network: Tx,
/// Receiver of signing requests.
pub tss_request: mpsc::Receiver<TssSigningRequest>,
/// Stream of inbound `(PeerId, Message)` pairs from the network.
pub net_request: Rx,
/// Path handed to the TSS state machine for its key share cache.
pub tss_keyshare_cache: PathBuf,
/// Channel used to report new blocks and the current shard set as `AdminMsg`s.
pub admin_request: mpsc::Sender<AdminMsg>,
}
/// Event loop state that drives the per-shard TSS state machines and task
/// executors in response to block, finality, network and signing-request events.
pub struct TimeWorker<Tx, Rx> {
/// Runtime handle used for chain queries and for submitting extrinsics.
substrate: Arc<dyn Runtime>,
/// Outbound network handle; cloned for every outgoing message.
network: Tx,
/// Receiver of signing requests.
tss_request: mpsc::Receiver<TssSigningRequest>,
/// Stream of inbound `(PeerId, Message)` pairs from the network.
net_request: Rx,
/// Parameters cloned into every per-shard `TaskExecutor`.
task_params: TaskParams,
/// TSS state machine for every shard this worker currently belongs to.
tss_states: HashMap<ShardId, Tss>,
/// Task executor for every shard that has been seen online.
executor_states: HashMap<ShardId, TaskExecutor>,
/// Inbound TSS messages, keyed by the block number carried in the message.
/// A batch is processed once a finalized block with at least that number is seen.
messages: BTreeMap<BlockNumber, Vec<(ShardId, PeerId, TssMessage)>>,
/// Pending signing requests, keyed by the block number carried in the request.
/// A batch is processed once a finalized block with at least that number is seen.
requests: BTreeMap<BlockNumber, Vec<(ShardId, TaskId, Vec<u8>)>>,
/// Reply channels for pending signing requests, keyed by task id. Each one
/// receives the 32-byte hash and the signature once signing completes.
channels: HashMap<TaskId, oneshot::Sender<([u8; 32], TssSignature)>>,
/// In-flight network sends. Each future resolves to the send result together
/// with the span it should be logged under.
#[allow(clippy::type_complexity)]
outgoing_requests:
FuturesUnordered<Pin<Box<dyn Future<Output = (Result<()>, Span)> + Send + 'static>>>,
/// Path handed to the TSS state machine for its key share cache.
tss_keyshare_cache: PathBuf,
/// Channel used to report new blocks and the current shard set as `AdminMsg`s.
admin_request: mpsc::Sender<AdminMsg>,
}
/// Renders a peer id for log output.
///
/// Returns the `TssPeerId` string form when the raw id converts successfully,
/// and the hex encoding of the raw id otherwise.
fn display_peer_id(peer_id: PeerId) -> String {
    match TssPeerId::new(peer_id) {
        Ok(tss_peer_id) => tss_peer_id.to_string(),
        Err(_) => hex::encode(peer_id),
    }
}
impl<Tx, Rx> TimeWorker<Tx, Rx>
where
Tx: Network + Clone,
Rx: Stream<Item = (PeerId, Message)> + Send + Unpin,
{
pub fn new(worker_params: TimeWorkerParams<Tx, Rx>) -> Self {
let TimeWorkerParams {
substrate,
task_params,
network,
tss_request,
net_request,
tss_keyshare_cache,
admin_request,
} = worker_params;
Self {
substrate,
task_params,
network,
tss_request,
net_request,
tss_states: Default::default(),
executor_states: Default::default(),
messages: Default::default(),
requests: Default::default(),
channels: Default::default(),
outgoing_requests: Default::default(),
tss_keyshare_cache,
admin_request,
}
}
async fn on_finality(
&mut self,
span: &Span,
block_hash: BlockHash,
block: BlockNumber,
) -> Result<()> {
let span = span!(
parent: span,
Level::DEBUG,
"on_finality",
tc_block = block,
tc_block_hash = format!("{block_hash:?}"),
);
event!(parent: &span, Level::DEBUG, "on_finality");
let account_id = self.substrate.account_id();
let shards = self.substrate.shards(account_id, block_hash).await?;
self.tss_states.retain(|shard_id, _| shards.contains(shard_id));
self.executor_states.retain(|shard_id, _| shards.contains(shard_id));
for shard_id in shards.iter().copied() {
if self.tss_states.contains_key(&shard_id) {
continue;
}
let span = span!(parent: &span, Level::DEBUG, "joining", gmp_shard_id = shard_id);
let members = self.substrate.shard_members(shard_id, block_hash).await?;
let threshold = self.substrate.shard_threshold(shard_id, block_hash).await?;
let futures: Vec<_> = members
.into_iter()
.map(|(account, _)| {
let substrate = self.substrate.clone();
async move {
match substrate.member_peer_id(&account, block_hash).await {
Ok(Some(peer_id)) => Some(peer_id),
Ok(None) | Err(_) => None,
}
}
})
.collect();
let members =
join_all(futures).await.into_iter().flatten().collect::<BTreeSet<PeerId>>();
let commitment = if let Some(commitment) =
self.substrate.shard_commitment(shard_id, block_hash).await?
{
let commitment =
VerifiableSecretSharingCommitment::deserialize(commitment.0.to_vec())?;
Some(commitment)
} else {
None
};
self.tss_states.insert(
shard_id,
Tss::new(
self.network.peer_id(),
members,
threshold,
commitment,
&self.tss_keyshare_cache,
&span,
)?,
);
self.poll_actions(&span, shard_id, block).await;
}
for shard_id in shards.iter().copied() {
let Some(tss) = self.tss_states.get_mut(&shard_id) else {
continue;
};
if tss.committed() {
continue;
}
if self.substrate.shard_status(shard_id, block_hash).await? != ShardStatus::Committed {
continue;
}
let span = span!(parent: &span, Level::DEBUG, "committing", gmp_shard_id = shard_id);
let commitment = self.substrate.shard_commitment(shard_id, block_hash).await?.unwrap();
let commitment = VerifiableSecretSharingCommitment::deserialize(commitment.0.to_vec())?;
tss.on_commit(commitment, &span);
self.poll_actions(&span, shard_id, block).await;
}
while let Some(n) = self.requests.keys().copied().next() {
if n > block {
break;
}
for (shard_id, task_id, data) in self.requests.remove(&n).unwrap() {
let span = span!(
parent: &span,
Level::DEBUG,
"received signing request from task executor",
gmp_shard_id = shard_id,
gmp_task_id = task_id,
);
let Some(tss) = self.tss_states.get_mut(&shard_id) else {
event!(
parent: &span,
Level::ERROR,
"trying to run task on unknown shard, dropping channel",
);
self.channels.remove(&task_id);
continue;
};
event!(parent: &span, Level::DEBUG, "signing");
tss.on_sign(task_id, data.to_vec(), &span);
self.poll_actions(&span, shard_id, block).await;
}
}
for shard_id in shards.iter().copied() {
if self.substrate.shard_status(shard_id, block_hash).await? != ShardStatus::Online {
continue;
}
let executor = self
.executor_states
.entry(shard_id)
.or_insert(TaskExecutor::new(self.task_params.clone()));
let span = span!(
parent: &span,
Level::DEBUG,
"shard",
gmp_shard_id = shard_id,
);
let (start_sessions, complete_sessions, _failed_tasks) =
match executor.process_tasks(block_hash, block, shard_id, &span).await {
Ok((start_sessions, complete_sessions, failed_tasks)) => {
(start_sessions, complete_sessions, failed_tasks)
},
Err(error) => {
event!(
parent: &span,
Level::INFO,
"failed to start tasks: {:?}",
error,
);
continue;
},
};
let Some(tss) = self.tss_states.get_mut(&shard_id) else {
continue;
};
for session in complete_sessions {
let span = span!(parent: &span, Level::DEBUG, "completing", gmp_task_id = session);
tss.on_complete(session, &span);
}
for session in start_sessions {
let span = span!(parent: &span, Level::DEBUG, "starting", gmp_task_id = session);
tss.on_start(session, &span);
}
}
while let Some(n) = self.messages.keys().copied().next() {
if n > block {
break;
}
for (shard_id, peer_id, msg) in self.messages.remove(&n).unwrap() {
let span = span!(parent: &span, Level::DEBUG, "messages", gmp_shard_id = shard_id,
net_from = display_peer_id(peer_id), net_message = msg.to_string());
let Some(tss) = self.tss_states.get_mut(&shard_id) else {
event!(
parent: &span,
Level::INFO,
"dropping message",
);
continue;
};
tss.on_message(peer_id, msg, &span)?;
self.poll_actions(&span, shard_id, n).await;
}
}
if let Err(e) = self.admin_request.send(AdminMsg::SetShards(shards)).await {
event!(parent: &span, Level::ERROR, "admin request failed: {:?}", e);
};
Ok(())
}
/// Drains and executes every action the TSS state machine of `shard_id`
/// currently has pending.
///
/// * `Send` — queues one outgoing network request per recipient. Responses
///   are tagged with block `0`; every other payload is tagged with `block`.
/// * `Commit` — submits the commitment and proof of knowledge to the runtime.
/// * `PublicKey` — logs the key and reports the shard as online.
/// * `Signature` — hands the hash and signature to the waiting requester, if
///   a reply channel is still registered for the task.
///
/// # Panics
///
/// Panics if there is no TSS state for `shard_id`, if the public key cannot
/// be converted to bytes, or if submitting the commitment or the online
/// notification to the runtime fails.
async fn poll_actions(&mut self, span: &Span, shard_id: ShardId, block: BlockNumber) {
    loop {
        let pending = self
            .tss_states
            .get_mut(&shard_id)
            .unwrap()
            .next_action(&self.tss_keyshare_cache, span);
        let Some(action) = pending else {
            break;
        };
        match action {
            TssAction::Send(msgs) => {
                for (peer_id, payload) in msgs {
                    let tx_span = span!(
                        parent: span,
                        Level::DEBUG,
                        "tx",
                        net_to = display_peer_id(peer_id),
                        net_message = payload.to_string(),
                    );
                    let msg_block = if payload.is_response() { 0 } else { block };
                    let msg = Message {
                        shard_id,
                        block: msg_block,
                        payload,
                    };
                    let network = self.network.clone();
                    // The send itself happens later, when `run` polls the queue.
                    let request = async move {
                        event!(parent: &tx_span, Level::DEBUG, "send");
                        let result = network.send(peer_id, msg).await;
                        (result, tx_span)
                    };
                    self.outgoing_requests.push(Box::pin(request));
                }
            },
            TssAction::Commit(commitment, proof_of_knowledge) => {
                event!(parent: span, Level::DEBUG, "commit");
                let commitment = Commitment(BoundedVec::truncate_from(commitment.serialize()));
                let proof = proof_of_knowledge.serialize();
                self.substrate
                    .submit_commitment(shard_id, commitment, proof)
                    .await
                    .unwrap();
            },
            TssAction::PublicKey(tss_public_key) => {
                let key_bytes = tss_public_key.to_bytes().unwrap();
                event!(
                    parent: span,
                    Level::DEBUG,
                    "public key 0x{}",
                    hex::encode(key_bytes),
                );
                self.substrate.submit_online(shard_id).await.unwrap();
            },
            TssAction::Signature(task_id, hash, tss_signature) => {
                let signature_bytes = tss_signature.to_bytes();
                event!(
                    parent: span,
                    Level::DEBUG,
                    gmp_task_id = task_id,
                    "signature 0x{}",
                    hex::encode(signature_bytes),
                );
                if let Some(reply) = self.channels.remove(&task_id) {
                    // The requester may have gone away; that is not an error.
                    let _ = reply.send((hash, signature_bytes));
                }
            },
        }
    }
}
/// Runs the worker's event loop; the loop has no exit path.
///
/// Each iteration waits for whichever of these is ready first:
/// * a new block — reported to the admin channel and used to drive heartbeat
///   submission,
/// * a finalized block — handled by `on_finality`,
/// * a signing request — queued under its block number, with its reply
///   channel registered under the task id,
/// * an inbound network message — queued under its block number,
/// * a completed outgoing send — its result is logged.
///
/// The heartbeat period is half of the runtime's heartbeat timeout, or
/// `DEFAULT_HEARTBEAT_PERIOD` when the timeout cannot be fetched or is not
/// greater than 1.
///
/// Spans are attached to awaited futures with `tracing::Instrument` rather
/// than by holding a `Span::enter` guard across an `.await`, which would leave
/// the span entered while this task is suspended.
///
/// # Panics
///
/// Panics if the finality notification stream ends before yielding a first
/// block.
pub async fn run(mut self, span: &Span) {
    event!(parent: span, Level::DEBUG, "starting tss");
    // Keep one never-completing future in the set so that polling it never
    // yields `None`, which would otherwise make the `select!` below spin.
    self.outgoing_requests.push(Box::pin(poll_fn(|_| Poll::Pending)));
    let mut block_notifications = self.substrate.block_notification_stream();
    let mut finality_notifications = self.substrate.finality_notification_stream();
    let block = finality_notifications.next().await.expect("Finality stream is not active");
    let heartbeat_period = self
        .substrate
        .heartbeat_timeout(block.0)
        .await
        .ok()
        // Halving a timeout of 0 or 1 would give a period of 0 and a modulo
        // by zero below, so fall back to the default in that case.
        .and_then(|t| (t > 1).then_some(t / 2))
        .unwrap_or(DEFAULT_HEARTBEAT_PERIOD);
    event!(parent: span, Level::INFO, "Started chronicle loop");
    let mut send_heartbeat = true;
    // NOTE(review): if one of the streams polled below terminates, `next()`
    // presumably keeps yielding `None` and the `continue`s would busy-loop.
    // Confirm the streams are never expected to end.
    loop {
        futures::select! {
            notification = block_notifications.next().fuse() => {
                let Some((block_hash, block)) = notification else {
                    event!(
                        parent: span,
                        Level::DEBUG,
                        "no new block notifications"
                    );
                    continue;
                };
                let new_block = self.admin_request.send(AdminMsg::NewBlock(block as _));
                if let Err(e) = tracing::Instrument::instrument(new_block, span.clone()).await {
                    event!(
                        parent: span,
                        Level::ERROR,
                        "Admin request error: {e:?}",
                    );
                }
                if block % heartbeat_period == 0 {
                    // Still pending from the previous period: it was never sent.
                    if send_heartbeat {
                        event!(
                            parent: span,
                            Level::ERROR,
                            "missed heartbeat period",
                        );
                    }
                    send_heartbeat = true;
                }
                if send_heartbeat {
                    let account_id = self.substrate.account_id();
                    let submitted = tracing::Instrument::instrument(
                        self.substrate.is_heartbeat_submitted(account_id, block_hash),
                        span.clone(),
                    )
                    .await;
                    match submitted {
                        Ok(true) => {
                            send_heartbeat = false;
                            event!(parent: span, Level::INFO, "heartbeat already submitted");
                        },
                        Ok(false) => {
                            event!(
                                parent: span,
                                Level::INFO,
                                "submitting heartbeat",
                            );
                            let submission = tracing::Instrument::instrument(
                                self.substrate.submit_heartbeat(),
                                span.clone(),
                            )
                            .await;
                            match submission {
                                Ok(()) => {
                                    send_heartbeat = false;
                                    event!(parent: span, Level::INFO, "submitted heartbeat");
                                }
                                // Leave the flag set so the next block retries.
                                Err(e) => {
                                    event!(
                                        parent: span,
                                        Level::INFO,
                                        "Error submitting heartbeat: {:?}",
                                        e
                                    );
                                }
                            }
                        },
                        Err(e) => {
                            event!(
                                parent: span,
                                Level::ERROR,
                                "Error checking heartbeat status: {:?}",
                                e
                            );
                        }
                    }
                }
            },
            notification = finality_notifications.next().fuse() => {
                let Some((block_hash, block)) = notification else {
                    event!(
                        parent: span,
                        Level::DEBUG,
                        "no new finality notifications"
                    );
                    continue;
                };
                let finality = self.on_finality(span, block_hash, block);
                if let Err(e) = tracing::Instrument::instrument(finality, span.clone()).await {
                    event!(
                        parent: span,
                        Level::ERROR,
                        "Error running on_finality {:?}",
                        e
                    );
                }
            },
            tss_request = self.tss_request.next().fuse() => {
                // No `.await` in this arm, so holding the guard is fine.
                let _enter = span.enter();
                let Some(TssSigningRequest { task_id, shard_id, data, tx, block }) = tss_request else {
                    continue;
                };
                event!(
                    parent: span,
                    Level::DEBUG,
                    gmp_shard_id = shard_id,
                    gmp_task_id = task_id,
                    tc_block = block,
                    "received signing request",
                );
                self.requests.entry(block).or_default().push((shard_id, task_id, data));
                self.channels.insert(task_id, tx);
            },
            msg = self.net_request.next().fuse() => {
                // No `.await` in this arm, so holding the guard is fine.
                let _enter = span.enter();
                let Some((peer, Message { shard_id, block, payload })) = msg else {
                    continue;
                };
                event!(
                    parent: span,
                    Level::DEBUG,
                    gmp_shard_id = shard_id,
                    tc_block = block,
                    net_from = display_peer_id(peer),
                    "rx {}",
                    payload,
                );
                self.messages.entry(block).or_default().push((shard_id, peer, payload));
            },
            outgoing_request = self.outgoing_requests.next().fuse() => {
                let Some((result, span)) = outgoing_request else {
                    continue;
                };
                // No `.await` in this arm, so holding the guard is fine.
                let _enter = span.enter();
                if let Err(error) = result {
                    event!(
                        parent: &span,
                        Level::ERROR,
                        "error {:?}",
                        error,
                    );
                } else {
                    event!(
                        parent: &span,
                        Level::DEBUG,
                        "sent",
                    );
                }
            }
        }
    }
}
}