Skip to content

Commit

Permalink
feat: add spawn and try_current for madsim-tokio (#197)
Browse files Browse the repository at this point in the history
* add spawn for Handle and add dummy interface for building

Signed-off-by: Kevin Axel <[email protected]>

* remove mutex

Signed-off-by: Kevin Axel <[email protected]>

* cargo check

Signed-off-by: Kevin Axel <[email protected]>

* cargo check

Signed-off-by: Kevin Axel <[email protected]>

* Update madsim-tokio/src/sim/runtime.rs

* Update madsim/src/sim/runtime/mod.rs

---------

Signed-off-by: Kevin Axel <[email protected]>
Co-authored-by: Runji Wang <[email protected]>
  • Loading branch information
KveinAxel and wangrunji0408 authored Apr 9, 2024
1 parent ff89182 commit 91a8231
Show file tree
Hide file tree
Showing 12 changed files with 99 additions and 29 deletions.
1 change: 0 additions & 1 deletion madsim-etcd-client/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ use spin::Mutex;
use std::collections::btree_map::Entry;
use std::collections::{btree_map::Range, BTreeMap, HashMap, HashSet};
use std::sync::Arc;
use std::time::Duration;
use tokio::sync::mpsc;

#[derive(Debug)]
Expand Down
1 change: 0 additions & 1 deletion madsim-macros/src/service.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@ use darling::FromMeta;
use proc_macro::TokenStream as TokenStream1;
use proc_macro2::TokenStream;
use quote::quote;
use std::convert::TryFrom;
use syn::{spanned::Spanned, *};

pub fn service(_args: TokenStream1, input: TokenStream1) -> TokenStream1 {
Expand Down
5 changes: 5 additions & 0 deletions madsim-tokio/src/sim/runtime.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
pub use madsim::runtime::Handle;
use madsim::task::{AbortHandle, JoinHandle};
use spin::Mutex;
use std::{future::Future, io};
Expand Down Expand Up @@ -72,6 +73,10 @@ impl Runtime {
handle
}

pub fn block_on<F: Future>(&self, _future: F) -> F::Output {
unimplemented!("blocking the current thread is not allowed in madsim");
}

pub fn enter(&self) -> EnterGuard<'_> {
// Madsim runtime is entered by default. No-op here.
EnterGuard(self)
Expand Down
10 changes: 6 additions & 4 deletions madsim/src/sim/net/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::{IpProtocol::Udp, *};
use futures_util::{Stream, StreamExt};
use futures_util::Stream;
use std::{
fmt,
pin::Pin,
Expand Down Expand Up @@ -48,7 +48,8 @@ impl Endpoint {
Ok(self.guard.addr)
}

/// Returns the socket address of the remote peer this socket was connected to.
/// Returns the socket address of the remote peer this socket was connected
/// to.
pub fn peer_addr(&self) -> io::Result<SocketAddr> {
(self.peer.lock())
.ok_or_else(|| io::Error::new(io::ErrorKind::NotConnected, "not connected"))
Expand Down Expand Up @@ -98,8 +99,9 @@ impl Endpoint {
self.send_to(peer, tag, buf).await
}

/// Receives a single datagram message on the socket from the remote address to which it is connected.
/// On success, returns the number of bytes read.
/// Receives a single datagram message on the socket from the remote address
/// to which it is connected. On success, returns the number of bytes
/// read.
pub async fn recv(&self, tag: u64, buf: &mut [u8]) -> io::Result<usize> {
let peer = self.peer_addr()?;
let (len, from) = self.recv_from(tag, buf).await?;
Expand Down
2 changes: 1 addition & 1 deletion madsim/src/sim/net/rpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -67,7 +67,7 @@ pub use bytes::Bytes;
use futures_util::FutureExt;
#[doc(no_inline)]
pub use serde::{de::DeserializeOwned, Deserialize, Serialize};
use std::{any::Any, future::Future};
use std::future::Future;

/// A RPC request.
pub trait Request: Serialize + DeserializeOwned + Any + Send + Sync {
Expand Down
2 changes: 1 addition & 1 deletion madsim/src/sim/net/tcp/listener.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
use std::{fmt, io::Result, net::SocketAddr, sync::Arc};
use std::{fmt, io::Result};
use tracing::instrument;

use crate::net::{IpProtocol::Tcp, *};
Expand Down
34 changes: 25 additions & 9 deletions madsim/src/sim/net/tcp/stream.rs
Original file line number Diff line number Diff line change
@@ -1,17 +1,11 @@
use crate::{
net::{IpProtocol::Tcp, *},
plugin,
};
use bytes::{Buf, Bytes, BytesMut};
use futures_util::StreamExt;
use crate::net::{IpProtocol::Tcp, *};
use bytes::{Buf, BufMut, BytesMut};
#[cfg(unix)]
use std::os::unix::io::{AsRawFd, RawFd};
use std::{
fmt,
io::Result,
net::SocketAddr,
pin::Pin,
sync::Arc,
task::{Context, Poll},
};
use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
Expand Down Expand Up @@ -106,6 +100,27 @@ impl TcpStream {
pub fn peer_addr(&self) -> Result<SocketAddr> {
Ok(self.peer)
}

/// Tries to read data from the stream into the provided buffer, advancing
/// the buffer's internal cursor, returning how many bytes were read.
///
/// Receives any pending data from the socket but does not wait for new data
/// to arrive. On success, returns the number of bytes read. Because
/// `try_read_buf()` is non-blocking, the buffer does not have to be stored
/// by the async task and can exist entirely on the stack.
pub fn try_read_buf<B: BufMut>(&mut self, buf: &mut B) -> io::Result<usize> {
// read the buffer if not empty
if !self.read_buf.is_empty() {
let len = self.read_buf.len().min(buf.remaining_mut());
buf.put_slice(&self.read_buf[..len]);
self.read_buf.advance(len);
return Ok(len);
}
Err(io::Error::new(
io::ErrorKind::WouldBlock,
"read buffer is empty",
))
}
}

#[cfg(unix)]
Expand All @@ -129,7 +144,8 @@ impl AsyncRead for TcpStream {
return Poll::Ready(Ok(()));
}
// otherwise wait on channel
match self.rx.poll_next_unpin(cx) {
let poll_res = { self.rx.poll_next_unpin(cx) };
match poll_res {
Poll::Pending => Poll::Pending,
Poll::Ready(Some(data)) => {
self.read_buf = *data.downcast::<Bytes>().unwrap();
Expand Down
7 changes: 6 additions & 1 deletion madsim/src/sim/net/unix/stream.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
use bytes::BufMut;
use std::{
io::Result,
io::{self, Result},
os::unix::{
io::{AsRawFd, RawFd},
net::SocketAddr,
Expand Down Expand Up @@ -43,6 +44,10 @@ impl UnixStream {
{
todo!();
}

pub fn try_read_buf<B: BufMut>(&mut self, buf: &mut B) -> io::Result<usize> {
unimplemented!();
}
}

impl AsRawFd for UnixStream {
Expand Down
47 changes: 44 additions & 3 deletions madsim/src/sim/runtime/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -98,7 +98,8 @@ impl Runtime {
self.handle.create_node()
}

/// Run a future to completion on the runtime. This is the runtime’s entry point.
/// Run a future to completion on the runtime. This is the runtime’s entry
/// point.
///
/// This runs the given future on the current thread until it is complete.
///
Expand Down Expand Up @@ -212,6 +213,10 @@ pub struct Handle {
/// A collection of simulators.
pub(crate) type Simulators = Mutex<HashMap<TypeId, Arc<dyn plugin::Simulator>>>;

/// `TryCurrentError` indicates there is no runtime has been started
#[derive(Debug)]
pub struct TryCurrentError;

impl Handle {
/// Returns a [`Handle`] view over the currently running [`Runtime`].
///
Expand All @@ -226,6 +231,40 @@ impl Handle {
context::current(|h| h.clone())
}

/// Returns a [`Handle`] view over the currently running [`Runtime`]
///
/// Returns an error if no Runtime has been started
///
/// Contrary to `current`, this never panics
pub fn try_current() -> Result<Self, TryCurrentError> {
context::try_current(|h| h.clone()).ok_or(TryCurrentError)
}

/// spawn a task
pub fn spawn<F>(&self, future: F) -> JoinHandle<F::Output>
where
F: Future + Send + 'static,
F::Output: Send + 'static,
{
self.create_node()
.name("spawn a task")
.build()
.spawn(future)
}

/// spawn a blocking task
#[deprecated(
since = "0.3.0",
note = "blocking function is not allowed in simulation"
)]
pub fn spawn_blocking<F, R>(&self, _f: F) -> JoinHandle<R>
where
F: FnOnce() -> R + Send + 'static,
R: Send + 'static,
{
unimplemented!("blocking function is not allowed in simulation")
}

/// Returns the random seed of the current runtime.
///
/// ```
Expand Down Expand Up @@ -281,7 +320,8 @@ impl Handle {
self.task.get_node(id).map(|task| NodeHandle { task })
}

/// Returns a view that lets you get information about how the runtime is performing.
/// Returns a view that lets you get information about how the runtime is
/// performing.
pub fn metrics(&self) -> RuntimeMetrics {
RuntimeMetrics {
task: self.task.clone(),
Expand Down Expand Up @@ -347,7 +387,8 @@ impl<'a> NodeBuilder<'a> {
self
}

/// Automatically restart the node when it panics with a message containing the given string.
/// Automatically restart the node when it panics with a message containing
/// the given string.
pub fn restart_on_panic_matching(mut self, msg: impl Into<String>) -> Self {
self.restart_on_panic_matching.push(msg.into());
self
Expand Down
6 changes: 3 additions & 3 deletions madsim/src/sim/task/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -318,7 +318,7 @@ impl Deref for Executor {
#[derive(Clone)]
#[doc(hidden)]
pub struct TaskHandle {
sender: mpsc::Sender<Runnable>,
pub(crate) sender: mpsc::Sender<Runnable>,
nodes: Arc<Mutex<HashMap<NodeId, Node>>>,
next_node_id: Arc<AtomicU64>,
/// Info of the main node.
Expand Down Expand Up @@ -564,8 +564,8 @@ impl<T: ToNodeId> ToNodeId for &T {
/// A handle to spawn tasks on a node.
#[derive(Clone)]
pub struct Spawner {
sender: mpsc::Sender<Runnable>,
info: Arc<NodeInfo>,
pub(crate) sender: mpsc::Sender<Runnable>,
pub(crate) info: Arc<NodeInfo>,
}

/// A handle to spawn tasks on a node.
Expand Down
11 changes: 7 additions & 4 deletions madsim/src/sim/time/interval.rs
Original file line number Diff line number Diff line change
Expand Up @@ -30,9 +30,9 @@ use crate::time::{sleep_until, Duration, Instant, Sleep};
use futures_util::future::poll_fn;
use futures_util::ready;

use std::future::Future;
use std::pin::Pin;
use std::task::{Context, Poll};
use std::{convert::TryInto, future::Future};

/// Creates new [`Interval`] that yields with interval of `period`.
pub fn interval(period: Duration) -> Interval {
Expand Down Expand Up @@ -62,14 +62,17 @@ fn internal_interval_at(start: Instant, period: Duration) -> Interval {
pub enum MissedTickBehavior {
/// Ticks as fast as possible until caught up.
Burst,
/// Tick at multiples of `period` from when `tick` was called, rather than from `start`.
/// Tick at multiples of `period` from when `tick` was called, rather than
/// from `start`.
Delay,
/// Skips missed ticks and tick on the next multiple of `period` from `start`.
/// Skips missed ticks and tick on the next multiple of `period` from
/// `start`.
Skip,
}

impl MissedTickBehavior {
/// If a tick is missed, this method is called to determine when the next tick should happen.
/// If a tick is missed, this method is called to determine when the next
/// tick should happen.
fn next_timeout(&self, timeout: Instant, now: Instant, period: Duration) -> Instant {
match self {
Self::Burst => timeout + period,
Expand Down
2 changes: 1 addition & 1 deletion madsim/src/sim/time/sleep.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::*;
use std::{fmt, future::Future, pin::Pin, task::Poll};
use std::{fmt, pin::Pin, task::Poll};

/// Waits until `duration` has elapsed.
pub fn sleep(duration: Duration) -> Sleep {
Expand Down

0 comments on commit 91a8231

Please sign in to comment.