chronicle/network/
protocol.rs

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
use super::{Message, Network, NetworkConfig, PeerId, PROTOCOL_NAME};
use anyhow::Result;
use futures::channel::mpsc;
use futures::{Future, FutureExt, SinkExt};
use peernet::{Endpoint, NotificationHandler, Protocol, ProtocolHandler};
use std::pin::Pin;
use std::time::Duration;
use tracing::{Level, Span};

/// Wrapper around a `peernet` [`Endpoint`] that provides this module's
/// [`Network`] implementation.
pub struct TssEndpoint {
	/// Underlying `peernet` endpoint used for the local peer id and for sending.
	endpoint: Endpoint,
}

/// Zero-sized marker type that names the protocol implemented below.
struct TssProtocol;

/// Protocol parameters handed to `peernet` for [`TssProtocol`].
impl Protocol for TssProtocol {
	/// Numeric protocol identifier.
	const ID: u16 = 0;
	// NOTE(review): presumably the request/response buffer sizes in bytes —
	// confirm against the `peernet::Protocol` documentation.
	const REQ_BUF: usize = 4096;
	const RES_BUF: usize = 4096;
	/// Requests and responses both use the module's `Message` type.
	type Request = Message;
	type Response = Message;
}

/// Handler registered with `peernet` that hands incoming notifications over
/// to the rest of the application through an mpsc channel.
#[derive(Clone)]
struct TssProtocolHandler {
	/// Sending half of the channel carrying `(sender peer id, message)` pairs.
	tx: mpsc::Sender<(PeerId, Message)>,
}

impl TssProtocolHandler {
	/// Builds a handler that delivers every received message on `tx`.
	pub fn new(tx: mpsc::Sender<(PeerId, Message)>) -> Self {
		TssProtocolHandler { tx }
	}
}

impl NotificationHandler<TssProtocol> for TssProtocolHandler {
	fn notify(&self, peer: peernet::PeerId, req: Message) -> Result<()> {
		let mut tx = self.tx.clone();
		tokio::spawn(async move {
			tx.send((*peer.as_bytes(), req)).await.ok();
		});
		Ok(())
	}
}

impl TssEndpoint {
	/// Builds the `peernet` endpoint and waits until the local peer id resolves
	/// to the endpoint's own address.
	///
	/// * `config` - network configuration; its `secret` is passed to the endpoint builder.
	/// * `tx` - channel on which received `(peer id, message)` pairs are delivered.
	/// * `span` - parent span for the `network` span created here.
	///
	/// # Errors
	///
	/// Returns an error if building the endpoint fails or if querying the
	/// endpoint's own address fails. Failures to resolve the peer id are not
	/// fatal: they are logged and retried indefinitely.
	pub async fn new(
		config: NetworkConfig,
		tx: mpsc::Sender<(PeerId, Message)>,
		span: &Span,
	) -> Result<Self> {
		// Pause between two registration checks.
		const RETRY_DELAY: Duration = Duration::from_secs(1);

		let mut builder = ProtocolHandler::builder();
		builder.register_notification_handler(TssProtocolHandler::new(tx));
		let handler = builder.build();

		let mut builder = Endpoint::builder(PROTOCOL_NAME.as_bytes().to_vec());
		builder.secret(config.secret);
		builder.handler(handler);
		// Republish every 5 minutes with a TTL of four republish intervals.
		builder.republish_interval(Duration::from_secs(60 * 5));
		builder.publish_ttl(Duration::from_secs(60 * 5 * 4));
		builder.relay_map(None);
		let endpoint = builder.build().await?;
		let peer_id = endpoint.peer_id();
		let span =
			tracing::span!(parent: span, Level::INFO, "network", net_peer_id = peer_id.to_string());
		loop {
			tracing::info!(
				parent: &span,
				"waiting for peer id to be registered",
			);
			// Every log line below is attached to `span` so that warnings carry
			// the same `net_peer_id` context as the info messages.
			match endpoint.resolve(peer_id).await {
				Ok(addr) => {
					let local = endpoint.addr().await?;
					if addr == local {
						break;
					}
					tracing::warn!(parent: &span, "addr: {addr:?} != endpoint.addr(): {local:?}");
				},
				Err(e) => {
					tracing::warn!(parent: &span, "FAILED to resolve peer_id: {e}");
				},
			}
			tokio::time::sleep(RETRY_DELAY).await;
		}
		tracing::info!(parent: &span, "peer id registered");
		Ok(Self { endpoint })
	}
}

impl Network for TssEndpoint {
	/// Returns the local peer id as raw bytes.
	fn peer_id(&self) -> PeerId {
		*self.endpoint.peer_id().as_bytes()
	}

	/// Renders `peer` using the `peernet` peer id string representation.
	///
	/// If the bytes are not a valid `peernet` peer id, a descriptive
	/// placeholder containing the raw bytes is returned instead of panicking,
	/// so formatting an id for a log line can never bring the caller down.
	fn format_peer_id(&self, peer: PeerId) -> String {
		match peernet::PeerId::from_bytes(&peer) {
			Ok(id) => id.to_string(),
			Err(err) => format!("invalid peer id {peer:?}: {err}"),
		}
	}

	/// Sends `msg` to `peer` as a notification on the TSS protocol.
	///
	/// The returned future resolves to an error if `peer` is not a valid
	/// `peernet` peer id or if the notification itself fails.
	fn send(&self, peer: PeerId, msg: Message) -> Pin<Box<dyn Future<Output = Result<()>> + Send>> {
		// Clone the endpoint so the future owns it and does not borrow `self`.
		let endpoint = self.endpoint.clone();
		async move {
			let peer = peernet::PeerId::from_bytes(&peer)?;
			endpoint.notify::<TssProtocol>(peer, &msg).await
		}
		.boxed()
	}
}