From 91a823103e234e40753af343f15f37d6099c8013 Mon Sep 17 00:00:00 2001 From: Kevin Axel Date: Tue, 9 Apr 2024 20:33:25 +0800 Subject: [PATCH] feat: add spawn and try_current for madsim-tokio (#197) * add spawn for Handle and add dummy interface for building Signed-off-by: Kevin Axel * remove mutex Signed-off-by: Kevin Axel * cargo check Signed-off-by: Kevin Axel * cargo check Signed-off-by: Kevin Axel * Update madsim-tokio/src/sim/runtime.rs * Update madsim/src/sim/runtime/mod.rs --------- Signed-off-by: Kevin Axel Co-authored-by: Runji Wang --- madsim-etcd-client/src/service.rs | 1 - madsim-macros/src/service.rs | 1 - madsim-tokio/src/sim/runtime.rs | 5 ++++ madsim/src/sim/net/endpoint.rs | 10 ++++--- madsim/src/sim/net/rpc.rs | 2 +- madsim/src/sim/net/tcp/listener.rs | 2 +- madsim/src/sim/net/tcp/stream.rs | 34 +++++++++++++++------ madsim/src/sim/net/unix/stream.rs | 7 ++++- madsim/src/sim/runtime/mod.rs | 47 ++++++++++++++++++++++++++++-- madsim/src/sim/task/mod.rs | 6 ++-- madsim/src/sim/time/interval.rs | 11 ++++--- madsim/src/sim/time/sleep.rs | 2 +- 12 files changed, 99 insertions(+), 29 deletions(-) diff --git a/madsim-etcd-client/src/service.rs b/madsim-etcd-client/src/service.rs index 5b8d9d9d..9d7faf09 100644 --- a/madsim-etcd-client/src/service.rs +++ b/madsim-etcd-client/src/service.rs @@ -6,7 +6,6 @@ use spin::Mutex; use std::collections::btree_map::Entry; use std::collections::{btree_map::Range, BTreeMap, HashMap, HashSet}; use std::sync::Arc; -use std::time::Duration; use tokio::sync::mpsc; #[derive(Debug)] diff --git a/madsim-macros/src/service.rs b/madsim-macros/src/service.rs index 18af9b1d..b6b5cc7b 100644 --- a/madsim-macros/src/service.rs +++ b/madsim-macros/src/service.rs @@ -2,7 +2,6 @@ use darling::FromMeta; use proc_macro::TokenStream as TokenStream1; use proc_macro2::TokenStream; use quote::quote; -use std::convert::TryFrom; use syn::{spanned::Spanned, *}; pub fn service(_args: TokenStream1, input: TokenStream1) -> TokenStream1 { diff --git a/madsim-tokio/src/sim/runtime.rs b/madsim-tokio/src/sim/runtime.rs index 23fae84e..9fec218c 100644 --- a/madsim-tokio/src/sim/runtime.rs +++ b/madsim-tokio/src/sim/runtime.rs @@ -1,3 +1,4 @@ +pub use madsim::runtime::Handle; use madsim::task::{AbortHandle, JoinHandle}; use spin::Mutex; use std::{future::Future, io}; @@ -72,6 +73,10 @@ impl Runtime { handle } + pub fn block_on(&self, _future: F) -> F::Output { + unimplemented!("blocking the current thread is not allowed in madsim"); + } + pub fn enter(&self) -> EnterGuard<'_> { // Madsim runtime is entered by default. No-op here. EnterGuard(self) diff --git a/madsim/src/sim/net/endpoint.rs b/madsim/src/sim/net/endpoint.rs index 9650b77f..b1d8b646 100644 --- a/madsim/src/sim/net/endpoint.rs +++ b/madsim/src/sim/net/endpoint.rs @@ -1,5 +1,5 @@ use super::{IpProtocol::Udp, *}; -use futures_util::{Stream, StreamExt}; +use futures_util::Stream; use std::{ fmt, pin::Pin, @@ -48,7 +48,8 @@ impl Endpoint { Ok(self.guard.addr) } - /// Returns the socket address of the remote peer this socket was connected to. + /// Returns the socket address of the remote peer this socket was connected + /// to. pub fn peer_addr(&self) -> io::Result { (self.peer.lock()) .ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "not connected")) @@ -98,8 +99,9 @@ impl Endpoint { self.send_to(peer, tag, buf).await } - /// Receives a single datagram message on the socket from the remote address to which it is connected. - /// On success, returns the number of bytes read. + /// Receives a single datagram message on the socket from the remote address + /// to which it is connected. On success, returns the number of bytes + /// read. pub async fn recv(&self, tag: u64, buf: &mut [u8]) -> io::Result { let peer = self.peer_addr()?; let (len, from) = self.recv_from(tag, buf).await?; diff --git a/madsim/src/sim/net/rpc.rs b/madsim/src/sim/net/rpc.rs index db7e1b02..290ea6d0 100644 --- a/madsim/src/sim/net/rpc.rs +++ b/madsim/src/sim/net/rpc.rs @@ -67,7 +67,7 @@ pub use bytes::Bytes; use futures_util::FutureExt; #[doc(no_inline)] pub use serde::{de::DeserializeOwned, Deserialize, Serialize}; -use std::{any::Any, future::Future}; +use std::future::Future; /// A RPC request. pub trait Request: Serialize + DeserializeOwned + Any + Send + Sync { diff --git a/madsim/src/sim/net/tcp/listener.rs b/madsim/src/sim/net/tcp/listener.rs index d166e937..f6db764a 100644 --- a/madsim/src/sim/net/tcp/listener.rs +++ b/madsim/src/sim/net/tcp/listener.rs @@ -1,4 +1,4 @@ -use std::{fmt, io::Result, net::SocketAddr, sync::Arc}; +use std::{fmt, io::Result}; use tracing::instrument; use crate::net::{IpProtocol::Tcp, *}; diff --git a/madsim/src/sim/net/tcp/stream.rs b/madsim/src/sim/net/tcp/stream.rs index cae3cb50..8a87fefd 100644 --- a/madsim/src/sim/net/tcp/stream.rs +++ b/madsim/src/sim/net/tcp/stream.rs @@ -1,17 +1,11 @@ -use crate::{ - net::{IpProtocol::Tcp, *}, - plugin, -}; -use bytes::{Buf, Bytes, BytesMut}; -use futures_util::StreamExt; +use crate::net::{IpProtocol::Tcp, *}; +use bytes::{Buf, BufMut, BytesMut}; #[cfg(unix)] use std::os::unix::io::{AsRawFd, RawFd}; use std::{ fmt, io::Result, - net::SocketAddr, pin::Pin, - sync::Arc, task::{Context, Poll}, }; use tokio::io::{AsyncRead, AsyncWrite, ReadBuf}; @@ -106,6 +100,27 @@ impl TcpStream { pub fn peer_addr(&self) -> Result { Ok(self.peer) } + + /// Tries to read data from the stream into the provided buffer, advancing + /// the buffer's internal cursor, returning how many bytes were read. + /// + /// Receives any pending data from the socket but does not wait for new data + /// to arrive. On success, returns the number of bytes read. Because + /// `try_read_buf()` is non-blocking, the buffer does not have to be stored + /// by the async task and can exist entirely on the stack. + pub fn try_read_buf(&mut self, buf: &mut B) -> io::Result { + // read the buffer if not empty + if !self.read_buf.is_empty() { + let len = self.read_buf.len().min(buf.remaining_mut()); + buf.put_slice(&self.read_buf[..len]); + self.read_buf.advance(len); + return Ok(len); + } + Err(io::Error::new( + io::ErrorKind::WouldBlock, + "read buffer is empty", + )) + } } #[cfg(unix)] @@ -129,7 +144,8 @@ impl AsyncRead for TcpStream { return Poll::Ready(Ok(())); } // otherwise wait on channel - match self.rx.poll_next_unpin(cx) { + let poll_res = { self.rx.poll_next_unpin(cx) }; + match poll_res { Poll::Pending => Poll::Pending, Poll::Ready(Some(data)) => { self.read_buf = *data.downcast::().unwrap(); diff --git a/madsim/src/sim/net/unix/stream.rs b/madsim/src/sim/net/unix/stream.rs index a17e209b..9f539415 100644 --- a/madsim/src/sim/net/unix/stream.rs +++ b/madsim/src/sim/net/unix/stream.rs @@ -1,5 +1,6 @@ +use bytes::BufMut; use std::{ - io::Result, + io::{self, Result}, os::unix::{ io::{AsRawFd, RawFd}, net::SocketAddr, @@ -43,6 +44,10 @@ impl UnixStream { { todo!(); } + + pub fn try_read_buf(&mut self, buf: &mut B) -> io::Result { + unimplemented!(); + } } impl AsRawFd for UnixStream { diff --git a/madsim/src/sim/runtime/mod.rs b/madsim/src/sim/runtime/mod.rs index 100f39ba..57b4a805 100644 --- a/madsim/src/sim/runtime/mod.rs +++ b/madsim/src/sim/runtime/mod.rs @@ -98,7 +98,8 @@ impl Runtime { self.handle.create_node() } - /// Run a future to completion on the runtime. This is the runtime’s entry point. + /// Run a future to completion on the runtime. This is the runtime’s entry + /// point. /// /// This runs the given future on the current thread until it is complete. /// @@ -212,6 +213,10 @@ pub struct Handle { /// A collection of simulators. pub(crate) type Simulators = Mutex>>; +/// `TryCurrentError` indicates there is no runtime has been started +#[derive(Debug)] +pub struct TryCurrentError; + impl Handle { /// Returns a [`Handle`] view over the currently running [`Runtime`]. /// @@ -226,6 +231,40 @@ impl Handle { context::current(|h| h.clone()) } + /// Returns a [`Handle`] view over the currently running [`Runtime`] + /// + /// Returns an error if no Runtime has been started + /// + /// Contrary to `current`, this never panics + pub fn try_current() -> Result { + context::try_current(|h| h.clone()).ok_or(TryCurrentError) + } + + /// spawn a task + pub fn spawn(&self, future: F) -> JoinHandle + where + F: Future + Send + 'static, + F::Output: Send + 'static, + { + self.create_node() + .name("spawn a task") + .build() + .spawn(future) + } + + /// spawn a blocking task + #[deprecated( + since = "0.3.0", + note = "blocking function is not allowed in simulation" + )] + pub fn spawn_blocking(&self, _f: F) -> JoinHandle + where + F: FnOnce() -> R + Send + 'static, + R: Send + 'static, + { + unimplemented!("blocking function is not allowed in simulation") + } + /// Returns the random seed of the current runtime. /// /// ``` @@ -281,7 +320,8 @@ impl Handle { self.task.get_node(id).map(|task| NodeHandle { task }) } - /// Returns a view that lets you get information about how the runtime is performing. + /// Returns a view that lets you get information about how the runtime is + /// performing. pub fn metrics(&self) -> RuntimeMetrics { RuntimeMetrics { task: self.task.clone(), @@ -347,7 +387,8 @@ impl<'a> NodeBuilder<'a> { self } - /// Automatically restart the node when it panics with a message containing the given string. + /// Automatically restart the node when it panics with a message containing + /// the given string. pub fn restart_on_panic_matching(mut self, msg: impl Into) -> Self { self.restart_on_panic_matching.push(msg.into()); self diff --git a/madsim/src/sim/task/mod.rs b/madsim/src/sim/task/mod.rs index 2537ec2c..4a163b08 100644 --- a/madsim/src/sim/task/mod.rs +++ b/madsim/src/sim/task/mod.rs @@ -318,7 +318,7 @@ impl Deref for Executor { #[derive(Clone)] #[doc(hidden)] pub struct TaskHandle { - sender: mpsc::Sender, + pub(crate) sender: mpsc::Sender, nodes: Arc>>, next_node_id: Arc, /// Info of the main node. @@ -564,8 +564,8 @@ impl ToNodeId for &T { /// A handle to spawn tasks on a node. #[derive(Clone)] pub struct Spawner { - sender: mpsc::Sender, - info: Arc, + pub(crate) sender: mpsc::Sender, + pub(crate) info: Arc, } /// A handle to spawn tasks on a node. diff --git a/madsim/src/sim/time/interval.rs b/madsim/src/sim/time/interval.rs index 643a7ab0..4caf6592 100644 --- a/madsim/src/sim/time/interval.rs +++ b/madsim/src/sim/time/interval.rs @@ -30,9 +30,9 @@ use crate::time::{sleep_until, Duration, Instant, Sleep}; use futures_util::future::poll_fn; use futures_util::ready; +use std::future::Future; use std::pin::Pin; use std::task::{Context, Poll}; -use std::{convert::TryInto, future::Future}; /// Creates new [`Interval`] that yields with interval of `period`. pub fn interval(period: Duration) -> Interval { @@ -62,14 +62,17 @@ fn internal_interval_at(start: Instant, period: Duration) -> Interval { pub enum MissedTickBehavior { /// Ticks as fast as possible until caught up. Burst, - /// Tick at multiples of `period` from when `tick` was called, rather than from `start`. + /// Tick at multiples of `period` from when `tick` was called, rather than + /// from `start`. Delay, - /// Skips missed ticks and tick on the next multiple of `period` from `start`. + /// Skips missed ticks and tick on the next multiple of `period` from + /// `start`. Skip, } impl MissedTickBehavior { - /// If a tick is missed, this method is called to determine when the next tick should happen. + /// If a tick is missed, this method is called to determine when the next + /// tick should happen. fn next_timeout(&self, timeout: Instant, now: Instant, period: Duration) -> Instant { match self { Self::Burst => timeout + period, diff --git a/madsim/src/sim/time/sleep.rs b/madsim/src/sim/time/sleep.rs index 7b83d606..3009ec1c 100644 --- a/madsim/src/sim/time/sleep.rs +++ b/madsim/src/sim/time/sleep.rs @@ -1,5 +1,5 @@ use super::*; -use std::{fmt, future::Future, pin::Pin, task::Poll}; +use std::{fmt, pin::Pin, task::Poll}; /// Waits until `duration` has elapsed. pub fn sleep(duration: Duration) -> Sleep {