Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 0 additions & 4 deletions Cargo.lock
Original file line number Diff line number Diff line change
Expand Up @@ -9005,7 +9005,6 @@ dependencies = [
"event-listener",
"guestmem",
"inspect",
"pal_async",
"plan9",
"task_control",
"tracing",
Expand All @@ -9025,7 +9024,6 @@ dependencies = [
"fs-err",
"guestmem",
"inspect",
"pal_async",
"sparse_mmap",
"task_control",
"tracing",
Expand All @@ -9052,14 +9050,12 @@ dependencies = [
"async-trait",
"event-listener",
"fuse",
"futures",
"guestmem",
"inspect",
"lx",
"lxutil",
"ntapi",
"pal",
"pal_async",
"parking_lot",
"task_control",
"tracing",
Expand Down
85 changes: 63 additions & 22 deletions support/task_control/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -517,43 +517,51 @@ impl<T: AsyncRun<S>, S: 'static + Send> TaskControl<T, S> {
.await
}

/// Stops the task, waiting for it to be cancelled.
/// Poll variant of [`stop`](Self::stop). Signals the task to stop and polls
/// for completion.
///
/// Returns true if the task was previously running. Returns false if the
/// task was not running, not inserted, or had already completed.
pub async fn stop(&mut self) -> bool {
/// Returns `Poll::Ready(true)` if the task was previously running and has
/// now stopped. Returns `Poll::Ready(false)` if the task was not running,
/// not inserted, or had already completed. Returns `Poll::Pending` if the
/// task has not yet stopped.
pub fn poll_stop(&mut self, cx: &mut Context<'_>) -> Poll<bool> {
match &mut self.inner {
Inner::WithState {
activity, shared, ..
} => match activity {
Activity::Running => {
let task_and_state = poll_fn(|cx| {
let mut shared = shared.lock();
shared.stop = true;
if shared.task_and_state.is_none() || !shared.calls.is_empty() {
shared.outer_waker = Some(cx.waker().clone());
let waker = shared.inner_waker.take();
drop(shared);
if let Some(waker) = waker {
waker.wake();
}
return Poll::Pending;
let mut shared = shared.lock();
shared.stop = true;
if shared.task_and_state.is_none() || !shared.calls.is_empty() {
shared.outer_waker = Some(cx.waker().clone());
let waker = shared.inner_waker.take();
drop(shared);
if let Some(waker) = waker {
waker.wake();
}
Poll::Ready(shared.task_and_state.take().unwrap())
})
.await;

return Poll::Pending;
}
let task_and_state = shared.task_and_state.take().unwrap();
drop(shared);
let done = task_and_state.done;
*activity = Activity::Stopped(task_and_state);
!done
Poll::Ready(!done)
}
_ => false,
_ => Poll::Ready(false),
},
Inner::NoState(_) => false,
Inner::NoState(_) => Poll::Ready(false),
Inner::Invalid => unreachable!(),
}
}

/// Stops the task, waiting for it to be cancelled.
///
/// Returns true if the task was previously running. Returns false if the
/// task was not running, not inserted, or had already completed.
pub async fn stop(&mut self) -> bool {
poll_fn(|cx| self.poll_stop(cx)).await
}

/// Removes the task state.
///
/// Panics if the task is not stopped.
Expand Down Expand Up @@ -655,4 +663,37 @@ mod tests {
assert!(t.stop().await);
assert_eq!(t.task_mut().0, 8);
}

#[async_test]
async fn test_poll_stop(driver: DefaultDriver) {
let mut t = TaskControl::new(Foo(5));

// poll_stop on a task without state returns Ready(false).
assert_eq!(
std::future::poll_fn(|cx| Poll::Ready(t.poll_stop(cx))).await,
Poll::Ready(false)
);

t.insert(&driver, "test", false);

// poll_stop on a stopped (not started) task returns Ready(false).
assert_eq!(
std::future::poll_fn(|cx| Poll::Ready(t.poll_stop(cx))).await,
Poll::Ready(false)
);

assert!(t.start());
yield_once().await;

// poll_stop drives the task to stop, equivalent to stop().await.
let result = std::future::poll_fn(|cx| t.poll_stop(cx)).await;
assert!(result); // was running
assert_eq!(t.task().0, 6);

// poll_stop after already stopped returns Ready(false).
assert_eq!(
std::future::poll_fn(|cx| Poll::Ready(t.poll_stop(cx))).await,
Poll::Ready(false)
);
}
}
10 changes: 9 additions & 1 deletion vm/devices/virtio/virtio/src/common.rs
Original file line number Diff line number Diff line change
Expand Up @@ -442,7 +442,15 @@ pub trait VirtioDevice: inspect::InspectMut + Send {
fn read_registers_u32(&self, offset: u16) -> u32;
fn write_registers_u32(&mut self, offset: u16, val: u32);
fn enable(&mut self, resources: Resources);
fn disable(&mut self);
/// Poll the device to complete a disable/reset operation.
///
/// This is called when the guest writes status=0 (device reset). The device
/// should stop workers and drain any in-flight IO. Returns `Poll::Ready(())`
/// when the disable is complete, or `Poll::Pending` if more work is needed.
///
/// Devices that don't need async cleanup can return `Poll::Ready(())`
/// immediately.
fn poll_disable(&mut self, cx: &mut Context<'_>) -> Poll<()>;
}

pub struct QueueResources {
Expand Down
20 changes: 6 additions & 14 deletions vm/devices/virtio/virtio/src/tests.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@ use guestmem::GuestMemoryBackingError;
use inspect::InspectMut;
use pal_async::DefaultDriver;
use pal_async::async_test;
use pal_async::task::Spawn;
use pal_async::timer::PolledTimer;
use pal_event::Event;
use parking_lot::Mutex;
Expand Down Expand Up @@ -846,20 +845,13 @@ impl VirtioDevice for TestDevice {
.collect();
}

fn disable(&mut self) {
if self.workers.is_empty() {
return;
}
fn poll_disable(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<()> {
self.exit_event.notify(usize::MAX);
let mut workers = self.workers.drain(..).collect::<Vec<_>>();
self.driver
.spawn("shutdown-test-virtio-queues".to_owned(), async move {
futures::future::join_all(workers.iter_mut().map(async |worker| {
worker.stop().await;
}))
.await;
})
.detach();
for worker in &mut self.workers {
std::task::ready!(worker.poll_stop(cx));
}
self.workers.clear();
std::task::Poll::Ready(())
}
}

Expand Down
59 changes: 50 additions & 9 deletions vm/devices/virtio/virtio/src/transport/mmio.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ use crate::spec::*;
use chipset_device::ChipsetDevice;
use chipset_device::io::IoResult;
use chipset_device::mmio::MmioIntercept;
use chipset_device::poll_device::PollDevice;
use device_emulators::ReadWriteRequestType;
use device_emulators::read_as_u32_chunks;
use device_emulators::write_as_u32_chunks;
Expand Down Expand Up @@ -53,6 +54,9 @@ pub struct VirtioMmioDevice {
#[inspect(skip)]
queues: Vec<QueueParams>,
device_status: VirtioDeviceStatus,
disabling: bool,
#[inspect(skip)]
poll_waker: Option<std::task::Waker>,
config_generation: u32,
#[inspect(skip)]
doorbells: VirtioDoorbells,
Expand Down Expand Up @@ -129,6 +133,8 @@ impl VirtioMmioDevice {
events,
queues,
device_status: VirtioDeviceStatus::new(),
disabling: false,
poll_waker: None,
config_generation: 0,
doorbells: VirtioDoorbells::new(doorbell_registration),
interrupt_state,
Expand All @@ -145,12 +151,6 @@ impl VirtioMmioDevice {
}
}

impl Drop for VirtioMmioDevice {
fn drop(&mut self) {
self.device.disable();
}
}

impl VirtioMmioDevice {
pub(crate) fn read_u32(&self, address: u64) -> u32 {
let offset = (address & 0xfff) as u16;
Expand Down Expand Up @@ -360,14 +360,31 @@ impl VirtioMmioDevice {
// Device status
112 => {
if val == 0 {
if self.disabling {
return;
}
let started = self.device_status.driver_ok();
self.device_status = VirtioDeviceStatus::new();
self.config_generation = 0;
if started {
self.doorbells.clear();
self.device.disable();
// Try the fast path: poll with a noop waker to see if
// the device can disable synchronously.
let waker = std::task::Waker::noop();
let mut cx = std::task::Context::from_waker(waker);
if self.device.poll_disable(&mut cx).is_pending() {
self.disabling = true;
// Wake the real poll waker so that poll_device will
// re-poll with a real waker, replacing the noop one.
if let Some(waker) = self.poll_waker.take() {
waker.wake();
}
return;
}
}
// Fast path: disable completed synchronously.
self.device_status = VirtioDeviceStatus::new();
self.interrupt_state.lock().update(false, !0);
return;
}

let new_status = VirtioDeviceStatus::from(val as u8);
Expand Down Expand Up @@ -480,14 +497,38 @@ impl ChangeDeviceState for VirtioMmioDevice {
async fn stop(&mut self) {}

async fn reset(&mut self) {
// TODO
if self.device_status.driver_ok() || self.disabling {
self.doorbells.clear();
std::future::poll_fn(|cx| self.device.poll_disable(cx)).await;
}
self.device_status = VirtioDeviceStatus::new();
self.disabling = false;
self.config_generation = 0;
self.interrupt_state.lock().update(false, !0);
}
}

impl PollDevice for VirtioMmioDevice {
fn poll_device(&mut self, cx: &mut std::task::Context<'_>) {
self.poll_waker = Some(cx.waker().clone());
if self.disabling {
if self.device.poll_disable(cx).is_ready() {
self.device_status = VirtioDeviceStatus::new();
self.disabling = false;
self.interrupt_state.lock().update(false, !0);
}
}
}
}

impl ChipsetDevice for VirtioMmioDevice {
fn supports_mmio(&mut self) -> Option<&mut dyn MmioIntercept> {
Some(self)
}

fn supports_poll_device(&mut self) -> Option<&mut dyn PollDevice> {
Some(self)
}
}

impl SaveRestore for VirtioMmioDevice {
Expand Down
Loading