#![cfg_attr(not(feature = "std"), no_std)]
#![allow(clippy::manual_inspect)]
#![doc = simple_mermaid::mermaid!("../docs/elections_flow.mmd")]
#![doc = simple_mermaid::mermaid!("../docs/elections_flow_2.mmd")]
pub use pallet::*;
#[cfg(feature = "runtime-benchmarks")]
mod benchmarking;
#[cfg(test)]
mod mock;
#[cfg(test)]
mod tests;
#[polkadot_sdk::frame_support::pallet]
pub mod pallet {
	use polkadot_sdk::{frame_support, frame_system, sp_std};
	use frame_support::pallet_prelude::*;
	use frame_system::pallet_prelude::*;
	use sp_std::vec;
	use sp_std::vec::Vec;
	use time_primitives::{
		AccountId, ElectionsInterface, MembersInterface, NetworkId, NetworksInterface,
		ShardsInterface,
	};
	/// Weight functions used by this pallet.
	pub trait WeightInfo {
		/// Weight charged by `on_initialize` when `b` shards were elected in the block.
		fn try_elect_shards(b: u32) -> Weight;
	}

	/// Fallback implementation that charges no weight at all.
	impl WeightInfo for () {
		fn try_elect_shards(_elected: u32) -> Weight {
			Weight::zero()
		}
	}
	/// The elections pallet: groups online members of each network that are not yet
	/// in a shard into new shards (see `try_elect_shards`).
	#[pallet::pallet]
	// Storage holds unbounded `Vec`s (`Unassigned`), so no max-encoded-len info is derived.
	#[pallet::without_storage_info]
	pub struct Pallet<T>(_);
	/// Configuration trait of the elections pallet.
	#[pallet::config]
	pub trait Config: polkadot_sdk::frame_system::Config<AccountId = AccountId> {
		/// The overarching runtime event type.
		type RuntimeEvent: From<Event<Self>>
			+ IsType<<Self as polkadot_sdk::frame_system::Config>::RuntimeEvent>;
		/// Weight information for the `on_initialize` election hook.
		type WeightInfo: WeightInfo;
		/// Shard management: creates shards, answers shard-membership queries and
		/// receives forwarded member online/offline notifications.
		type Shards: ShardsInterface;
		/// Answers whether a member is currently online.
		type Members: MembersInterface;
		/// Supplies the list of networks and each network's shard size and threshold.
		type Networks: NetworksInterface;
		/// Upper bound on the number of shards created across all networks in one block.
		#[pallet::constant]
		type MaxElectionsPerBlock: Get<u32>;
	}
	/// Index into `T::Networks::networks()` of the network at which the next block's
	/// round-robin election pass starts.
	#[pallet::storage]
	pub type NetworkCounter<T: Config> = StorageValue<_, u32, ValueQuery>;
	/// Per network, the online accounts that are waiting to be placed into a shard.
	///
	/// `shard_offline` and `member_online` re-sort this list in descending order;
	/// `try_elect_shards` takes new shard members from the front of it.
	#[pallet::storage]
	pub type Unassigned<T: Config> =
		StorageMap<_, Blake2_128Concat, NetworkId, Vec<AccountId>, ValueQuery>;
	/// Events of this pallet. None are currently defined.
	#[pallet::event]
	pub enum Event<T: Config> {}
	#[pallet::hooks]
	impl<T: Config> Hooks<BlockNumberFor<T>> for Pallet<T> {
		/// Runs shard elections at the start of every block.
		///
		/// Networks returned by `T::Networks::networks()` are visited round-robin,
		/// starting at the index stored in [`NetworkCounter`]. The pass stops when either
		/// `T::MaxElectionsPerBlock` shards have been created or every network has been
		/// visited once. The index of the next network to visit is persisted, so the
		/// following block resumes where this one stopped.
		///
		/// Returns `T::WeightInfo::try_elect_shards(num_elections)`, where
		/// `num_elections` is the number of shards created in this block.
		// NOTE(review): the returned weight only scales with the shards created; confirm
		// the benchmark also covers the per-network `Unassigned` reads.
		fn on_initialize(_: BlockNumberFor<T>) -> Weight {
			log::info!("on_initialize begin");
			let max_elections = T::MaxElectionsPerBlock::get();
			let networks = T::Networks::networks();
			let num_networks = networks.len();
			let stored_cursor = NetworkCounter::<T>::get();
			// A cursor that no longer points at a network (e.g. after networks were
			// removed) restarts from the first network instead of wasting this block.
			let start = if (stored_cursor as usize) < num_networks {
				stored_cursor as usize
			} else {
				0
			};
			let mut cursor = start;
			let mut num_elections = 0u32;
			while num_elections < max_elections {
				let Some(&network) = networks.get(cursor) else {
					// Only reachable when there are no networks at all.
					break;
				};
				let elected =
					Self::try_elect_shards(network, max_elections.saturating_sub(num_elections));
				num_elections = num_elections.saturating_add(elected);
				// `num_networks >= 1` here because `networks.get(cursor)` succeeded.
				cursor = (cursor + 1) % num_networks;
				if cursor == start {
					// Every network has been visited once in this block.
					break;
				}
			}
			let next_cursor = cursor as u32;
			// Avoid a storage write when the cursor did not move.
			if next_cursor != stored_cursor {
				NetworkCounter::<T>::put(next_cursor);
			}
			log::info!("on_initialize end");
			T::WeightInfo::try_elect_shards(num_elections)
		}
	}
	impl<T: Config> ElectionsInterface for Pallet<T> {
		/// Exposes this pallet's per-block election cap to users of `ElectionsInterface`.
		type MaxElectionsPerBlock = T::MaxElectionsPerBlock;

		/// Returns the members of an offline shard to `network`'s unassigned pool.
		///
		/// Only members for which `T::Members::is_member_online` returns `true` are
		/// re-queued. The others are dropped.
		fn shard_offline(network: NetworkId, members: Vec<AccountId>) {
			let online: Vec<AccountId> =
				members.into_iter().filter(|m| T::Members::is_member_online(m)).collect();
			if online.is_empty() {
				// Nothing to re-queue: skip the storage write entirely.
				return;
			}
			Unassigned::<T>::mutate(network, |unassigned| {
				unassigned.extend(online);
				// Keep the pool in descending order and free of duplicates, so the same
				// account can never be elected into more than one new shard.
				unassigned.sort_unstable_by(|a, b| b.cmp(a));
				unassigned.dedup();
			});
		}

		/// Handles `member` coming online for `network`.
		///
		/// If the member does not already belong to a shard, it is added to the unassigned
		/// pool (at most once). In every case the notification is forwarded to
		/// `T::Shards::member_online`.
		fn member_online(member: &AccountId, network: NetworkId) {
			if !T::Shards::is_shard_member(member) {
				Unassigned::<T>::mutate(network, |unassigned| {
					unassigned.push(member.clone());
					// Descending order plus dedup: a repeated online notification must not
					// queue the same account twice.
					unassigned.sort_unstable_by(|a, b| b.cmp(a));
					unassigned.dedup();
				});
			}
			T::Shards::member_online(member, network);
		}

		/// Removes `members` from `network`'s unassigned pool and forwards the
		/// notification to `T::Shards::members_offline`.
		fn members_offline(members: Vec<AccountId>, network: NetworkId) {
			Unassigned::<T>::mutate(network, |unassigned| {
				unassigned.retain(|m| !members.contains(m));
			});
			T::Shards::members_offline(members);
		}
	}
	impl<T: Config> Pallet<T> {
		pub(crate) fn try_elect_shards(network: NetworkId, max_elections: u32) -> u32 {
			let shard_size = T::Networks::shard_size(network);
			let shard_threshold = T::Networks::shard_threshold(network);
			let mut unassigned = Unassigned::<T>::get(network);
			let num_elected =
				sp_std::cmp::min((unassigned.len() as u32) / shard_size as u32, max_elections)
					* shard_size as u32;
			let mut members = Vec::with_capacity(num_elected as usize);
			members.extend(unassigned.drain(..(num_elected as usize)));
			let mut num_elections = 0u32;
			for (i, next_shard) in members.chunks(shard_size as usize).enumerate() {
				if T::Shards::create_shard(network, next_shard.to_vec(), shard_threshold).is_err() {
					unassigned
						.extend(members.chunks(shard_size as usize).skip(i).flatten().cloned());
					break;
				} else {
					num_elections += 1;
				}
			}
			Unassigned::<T>::insert(network, unassigned);
			num_elections
		}
	}
}