diff --git a/Cargo.lock b/Cargo.lock index 0a31c69ff0..2ab00bc703 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -9005,7 +9005,6 @@ dependencies = [ "event-listener", "guestmem", "inspect", - "pal_async", "plan9", "task_control", "tracing", @@ -9025,7 +9024,6 @@ dependencies = [ "fs-err", "guestmem", "inspect", - "pal_async", "sparse_mmap", "task_control", "tracing", @@ -9052,14 +9050,12 @@ dependencies = [ "async-trait", "event-listener", "fuse", - "futures", "guestmem", "inspect", "lx", "lxutil", "ntapi", "pal", - "pal_async", "parking_lot", "task_control", "tracing", diff --git a/support/task_control/src/lib.rs b/support/task_control/src/lib.rs index 18b85eb992..8fc78c773c 100644 --- a/support/task_control/src/lib.rs +++ b/support/task_control/src/lib.rs @@ -517,43 +517,51 @@ impl, S: 'static + Send> TaskControl { .await } - /// Stops the task, waiting for it to be cancelled. + /// Poll variant of [`stop`](Self::stop). Signals the task to stop and polls + /// for completion. /// - /// Returns true if the task was previously running. Returns false if the - /// task was not running, not inserted, or had already completed. - pub async fn stop(&mut self) -> bool { + /// Returns `Poll::Ready(true)` if the task was previously running and has + /// now stopped. Returns `Poll::Ready(false)` if the task was not running, + /// not inserted, or had already completed. Returns `Poll::Pending` if the + /// task has not yet stopped. + pub fn poll_stop(&mut self, cx: &mut Context<'_>) -> Poll { match &mut self.inner { Inner::WithState { activity, shared, .. } => match activity { Activity::Running => { - let task_and_state = poll_fn(|cx| { - let mut shared = shared.lock(); - shared.stop = true; - if shared.task_and_state.is_none() || !shared.calls.is_empty() { - shared.outer_waker = Some(cx.waker().clone()); - let waker = shared.inner_waker.take(); - drop(shared); - if let Some(waker) = waker { - waker.wake(); - } - return Poll::Pending; + let mut shared = shared.lock(); + shared.stop = true; + if shared.task_and_state.is_none() || !shared.calls.is_empty() { + shared.outer_waker = Some(cx.waker().clone()); + let waker = shared.inner_waker.take(); + drop(shared); + if let Some(waker) = waker { + waker.wake(); } - Poll::Ready(shared.task_and_state.take().unwrap()) - }) - .await; - + return Poll::Pending; + } + let task_and_state = shared.task_and_state.take().unwrap(); + drop(shared); let done = task_and_state.done; *activity = Activity::Stopped(task_and_state); - !done + Poll::Ready(!done) } - _ => false, + _ => Poll::Ready(false), }, - Inner::NoState(_) => false, + Inner::NoState(_) => Poll::Ready(false), Inner::Invalid => unreachable!(), } } + /// Stops the task, waiting for it to be cancelled. + /// + /// Returns true if the task was previously running. Returns false if the + /// task was not running, not inserted, or had already completed. + pub async fn stop(&mut self) -> bool { + poll_fn(|cx| self.poll_stop(cx)).await + } + /// Removes the task state. /// /// Panics if the task is not stopped. @@ -655,4 +663,37 @@ mod tests { assert!(t.stop().await); assert_eq!(t.task_mut().0, 8); } + + #[async_test] + async fn test_poll_stop(driver: DefaultDriver) { + let mut t = TaskControl::new(Foo(5)); + + // poll_stop on a task without state returns Ready(false). + assert_eq!( + std::future::poll_fn(|cx| Poll::Ready(t.poll_stop(cx))).await, + Poll::Ready(false) + ); + + t.insert(&driver, "test", false); + + // poll_stop on a stopped (not started) task returns Ready(false). + assert_eq!( + std::future::poll_fn(|cx| Poll::Ready(t.poll_stop(cx))).await, + Poll::Ready(false) + ); + + assert!(t.start()); + yield_once().await; + + // poll_stop drives the task to stop, equivalent to stop().await. + let result = std::future::poll_fn(|cx| t.poll_stop(cx)).await; + assert!(result); // was running + assert_eq!(t.task().0, 6); + + // poll_stop after already stopped returns Ready(false). + assert_eq!( + std::future::poll_fn(|cx| Poll::Ready(t.poll_stop(cx))).await, + Poll::Ready(false) + ); + } } diff --git a/vm/devices/virtio/virtio/src/common.rs b/vm/devices/virtio/virtio/src/common.rs index d309e0514f..72784d2c9f 100644 --- a/vm/devices/virtio/virtio/src/common.rs +++ b/vm/devices/virtio/virtio/src/common.rs @@ -442,7 +442,15 @@ pub trait VirtioDevice: inspect::InspectMut + Send { fn read_registers_u32(&self, offset: u16) -> u32; fn write_registers_u32(&mut self, offset: u16, val: u32); fn enable(&mut self, resources: Resources); - fn disable(&mut self); + /// Poll the device to complete a disable/reset operation. + /// + /// This is called when the guest writes status=0 (device reset). The device + /// should stop workers and drain any in-flight IO. Returns `Poll::Ready(())` + /// when the disable is complete, or `Poll::Pending` if more work is needed. + /// + /// Devices that don't need async cleanup can return `Poll::Ready(())` + /// immediately. + fn poll_disable(&mut self, cx: &mut Context<'_>) -> Poll<()>; } pub struct QueueResources { diff --git a/vm/devices/virtio/virtio/src/tests.rs b/vm/devices/virtio/virtio/src/tests.rs index 7dee922185..c43d92be3d 100644 --- a/vm/devices/virtio/virtio/src/tests.rs +++ b/vm/devices/virtio/virtio/src/tests.rs @@ -33,7 +33,6 @@ use guestmem::GuestMemoryBackingError; use inspect::InspectMut; use pal_async::DefaultDriver; use pal_async::async_test; -use pal_async::task::Spawn; use pal_async::timer::PolledTimer; use pal_event::Event; use parking_lot::Mutex; @@ -846,20 +845,13 @@ impl VirtioDevice for TestDevice { .collect(); } - fn disable(&mut self) { - if self.workers.is_empty() { - return; - } + fn poll_disable(&mut self, cx: &mut std::task::Context<'_>) -> std::task::Poll<()> { self.exit_event.notify(usize::MAX); - let mut workers = self.workers.drain(..).collect::>(); - self.driver - .spawn("shutdown-test-virtio-queues".to_owned(), async move { - futures::future::join_all(workers.iter_mut().map(async |worker| { - worker.stop().await; - })) - .await; - }) - .detach(); + for worker in &mut self.workers { + std::task::ready!(worker.poll_stop(cx)); + } + self.workers.clear(); + std::task::Poll::Ready(()) } } diff --git a/vm/devices/virtio/virtio/src/transport/mmio.rs b/vm/devices/virtio/virtio/src/transport/mmio.rs index 2eaecef9d0..6e987a84a7 100644 --- a/vm/devices/virtio/virtio/src/transport/mmio.rs +++ b/vm/devices/virtio/virtio/src/transport/mmio.rs @@ -11,6 +11,7 @@ use crate::spec::*; use chipset_device::ChipsetDevice; use chipset_device::io::IoResult; use chipset_device::mmio::MmioIntercept; +use chipset_device::poll_device::PollDevice; use device_emulators::ReadWriteRequestType; use device_emulators::read_as_u32_chunks; use device_emulators::write_as_u32_chunks; @@ -53,6 +54,9 @@ pub struct VirtioMmioDevice { #[inspect(skip)] queues: Vec, device_status: VirtioDeviceStatus, + disabling: bool, + #[inspect(skip)] + poll_waker: Option, config_generation: u32, #[inspect(skip)] doorbells: VirtioDoorbells, @@ -129,6 +133,8 @@ impl VirtioMmioDevice { events, queues, device_status: VirtioDeviceStatus::new(), + disabling: false, + poll_waker: None, config_generation: 0, doorbells: VirtioDoorbells::new(doorbell_registration), interrupt_state, @@ -145,12 +151,6 @@ impl VirtioMmioDevice { } } -impl Drop for VirtioMmioDevice { - fn drop(&mut self) { - self.device.disable(); - } -} - impl VirtioMmioDevice { pub(crate) fn read_u32(&self, address: u64) -> u32 { let offset = (address & 0xfff) as u16; @@ -360,14 +360,31 @@ impl VirtioMmioDevice { // Device status 112 => { if val == 0 { + if self.disabling { + return; + } let started = self.device_status.driver_ok(); - self.device_status = VirtioDeviceStatus::new(); self.config_generation = 0; if started { self.doorbells.clear(); - self.device.disable(); + // Try the fast path: poll with a noop waker to see if + // the device can disable synchronously. + let waker = std::task::Waker::noop(); + let mut cx = std::task::Context::from_waker(waker); + if self.device.poll_disable(&mut cx).is_pending() { + self.disabling = true; + // Wake the real poll waker so that poll_device will + // re-poll with a real waker, replacing the noop one. + if let Some(waker) = self.poll_waker.take() { + waker.wake(); + } + return; + } } + // Fast path: disable completed synchronously. + self.device_status = VirtioDeviceStatus::new(); self.interrupt_state.lock().update(false, !0); + return; } let new_status = VirtioDeviceStatus::from(val as u8); @@ -480,7 +497,27 @@ impl ChangeDeviceState for VirtioMmioDevice { async fn stop(&mut self) {} async fn reset(&mut self) { - // TODO + if self.device_status.driver_ok() || self.disabling { + self.doorbells.clear(); + std::future::poll_fn(|cx| self.device.poll_disable(cx)).await; + } + self.device_status = VirtioDeviceStatus::new(); + self.disabling = false; + self.config_generation = 0; + self.interrupt_state.lock().update(false, !0); + } +} + +impl PollDevice for VirtioMmioDevice { + fn poll_device(&mut self, cx: &mut std::task::Context<'_>) { + self.poll_waker = Some(cx.waker().clone()); + if self.disabling { + if self.device.poll_disable(cx).is_ready() { + self.device_status = VirtioDeviceStatus::new(); + self.disabling = false; + self.interrupt_state.lock().update(false, !0); + } + } } } @@ -488,6 +525,10 @@ impl ChipsetDevice for VirtioMmioDevice { fn supports_mmio(&mut self) -> Option<&mut dyn MmioIntercept> { Some(self) } + + fn supports_poll_device(&mut self) -> Option<&mut dyn PollDevice> { + Some(self) + } } impl SaveRestore for VirtioMmioDevice { diff --git a/vm/devices/virtio/virtio/src/transport/pci.rs b/vm/devices/virtio/virtio/src/transport/pci.rs index dc7b4a8f04..87bf1ed66b 100644 --- a/vm/devices/virtio/virtio/src/transport/pci.rs +++ b/vm/devices/virtio/virtio/src/transport/pci.rs @@ -17,6 +17,7 @@ use chipset_device::io::IoResult; use chipset_device::mmio::MmioIntercept; use chipset_device::mmio::RegisterMmioIntercept; use chipset_device::pci::PciConfigSpace; +use chipset_device::poll_device::PollDevice; use device_emulators::ReadWriteRequestType; use device_emulators::read_as_u32_chunks; use device_emulators::write_as_u32_chunks; @@ -84,6 +85,9 @@ pub struct VirtioPciDevice { interrupt_status: Arc>, #[inspect(hex)] device_status: VirtioDeviceStatus, + disabling: bool, + #[inspect(skip)] + poll_waker: Option, config_generation: u32, config_space: ConfigSpaceType0Emulator, @@ -233,6 +237,8 @@ impl VirtioPciDevice { msix_vectors, interrupt_status: Arc::new(Mutex::new(0)), device_status: VirtioDeviceStatus::new(), + disabling: false, + poll_waker: None, config_generation: 0, interrupt_kind, config_space, @@ -400,14 +406,31 @@ impl VirtioPciDevice { self.queue_select = val >> 16; let val = val & 0xff; if val == 0 { + if self.disabling { + return; + } let started = self.device_status.driver_ok(); - self.device_status = VirtioDeviceStatus::new(); self.config_generation = 0; if started { self.doorbells.clear(); - self.device.disable(); + // Try the fast path: poll with a noop waker to see if + // the device can disable synchronously. + let waker = std::task::Waker::noop(); + let mut cx = std::task::Context::from_waker(waker); + if self.device.poll_disable(&mut cx).is_pending() { + self.disabling = true; + // Wake the real poll waker so that poll_device will + // re-poll with a real waker, replacing the noop one. + if let Some(waker) = self.poll_waker.take() { + waker.wake(); + } + return; + } } + // Fast path: disable completed synchronously. + self.device_status = VirtioDeviceStatus::new(); *self.interrupt_status.lock() = 0; + return; } let new_status = VirtioDeviceStatus::from(val as u8); @@ -558,13 +581,6 @@ impl VirtioPciDevice { } } -impl Drop for VirtioPciDevice { - fn drop(&mut self) { - // TODO conditionalize - self.device.disable(); - } -} - impl VirtioPciDevice { fn read_bar_u32(&mut self, bar: u8, offset: u16) -> u32 { match bar { @@ -599,7 +615,27 @@ impl ChangeDeviceState for VirtioPciDevice { async fn stop(&mut self) {} async fn reset(&mut self) { - // TODO + if self.device_status.driver_ok() || self.disabling { + self.doorbells.clear(); + std::future::poll_fn(|cx| self.device.poll_disable(cx)).await; + } + self.device_status = VirtioDeviceStatus::new(); + self.disabling = false; + self.config_generation = 0; + *self.interrupt_status.lock() = 0; + } +} + +impl PollDevice for VirtioPciDevice { + fn poll_device(&mut self, cx: &mut std::task::Context<'_>) { + self.poll_waker = Some(cx.waker().clone()); + if self.disabling { + if self.device.poll_disable(cx).is_ready() { + self.device_status = VirtioDeviceStatus::new(); + self.disabling = false; + *self.interrupt_status.lock() = 0; + } + } } } @@ -611,6 +647,10 @@ impl ChipsetDevice for VirtioPciDevice { fn supports_pci(&mut self) -> Option<&mut dyn PciConfigSpace> { Some(self) } + + fn supports_poll_device(&mut self) -> Option<&mut dyn PollDevice> { + Some(self) + } } impl SaveRestore for VirtioPciDevice { diff --git a/vm/devices/virtio/virtio_net/src/lib.rs b/vm/devices/virtio/virtio_net/src/lib.rs index eda8620d6a..c808274b4c 100644 --- a/vm/devices/virtio/virtio_net/src/lib.rs +++ b/vm/devices/virtio/virtio_net/src/lib.rs @@ -41,6 +41,7 @@ use pal_async::wait::PolledWait; use std::future::pending; use std::mem::offset_of; use std::sync::Arc; +use std::task::Context; use std::task::Poll; use task_control::AsyncRun; use task_control::InspectTaskMut; @@ -346,10 +347,11 @@ impl VirtioDevice for Device { self.coordinator.start(); } - fn disable(&mut self) { + fn poll_disable(&mut self, _cx: &mut Context<'_>) -> Poll<()> { if let Some(send) = self.coordinator_send.take() { send.send(CoordinatorMessage::Disable); } + Poll::Ready(()) } } diff --git a/vm/devices/virtio/virtio_p9/Cargo.toml b/vm/devices/virtio/virtio_p9/Cargo.toml index 0ef2f340aa..f3876a192e 100644 --- a/vm/devices/virtio/virtio_p9/Cargo.toml +++ b/vm/devices/virtio/virtio_p9/Cargo.toml @@ -20,7 +20,6 @@ task_control.workspace = true anyhow.workspace = true async-trait.workspace = true event-listener.workspace = true -pal_async.workspace = true tracing.workspace = true [lints] diff --git a/vm/devices/virtio/virtio_p9/src/lib.rs b/vm/devices/virtio/virtio_p9/src/lib.rs index cbf1b8951a..9c432d8e92 100644 --- a/vm/devices/virtio/virtio_p9/src/lib.rs +++ b/vm/devices/virtio/virtio_p9/src/lib.rs @@ -10,9 +10,11 @@ pub mod resolver; use async_trait::async_trait; use guestmem::GuestMemory; use inspect::InspectMut; -use pal_async::task::Spawn; use plan9::Plan9FileSystem; use std::sync::Arc; +use std::task::Context; +use std::task::Poll; +use std::task::ready; use task_control::TaskControl; use virtio::DeviceTraits; use virtio::Resources; @@ -132,16 +134,13 @@ impl VirtioDevice for VirtioPlan9Device { )); } - fn disable(&mut self) { - let Some(mut worker) = self.worker.take() else { - return; - }; + fn poll_disable(&mut self, cx: &mut Context<'_>) -> Poll<()> { self.exit_event.notify(usize::MAX); - self.driver - .spawn("shutdown-virtio-9p-queue".to_owned(), async move { - worker.stop().await; - }) - .detach(); + if let Some(worker) = &mut self.worker { + ready!(worker.poll_stop(cx)); + } + self.worker = None; + Poll::Ready(()) } } diff --git a/vm/devices/virtio/virtio_pmem/Cargo.toml b/vm/devices/virtio/virtio_pmem/Cargo.toml index c345876c4d..6649e952d9 100644 --- a/vm/devices/virtio/virtio_pmem/Cargo.toml +++ b/vm/devices/virtio/virtio_pmem/Cargo.toml @@ -14,8 +14,6 @@ inspect = { workspace = true, features = ["filepath"] } guestmem.workspace = true vmcore.workspace = true vm_resource.workspace = true - -pal_async.workspace = true sparse_mmap.workspace = true task_control.workspace = true diff --git a/vm/devices/virtio/virtio_pmem/src/lib.rs b/vm/devices/virtio/virtio_pmem/src/lib.rs index 75c7a14d42..5b90a6d8a6 100644 --- a/vm/devices/virtio/virtio_pmem/src/lib.rs +++ b/vm/devices/virtio/virtio_pmem/src/lib.rs @@ -10,9 +10,10 @@ use anyhow::Context; use async_trait::async_trait; use guestmem::GuestMemory; use inspect::InspectMut; -use pal_async::task::Spawn; use std::fs; use std::sync::Arc; +use std::task::Poll; +use std::task::ready; use task_control::TaskControl; use virtio::DeviceTraits; use virtio::DeviceTraitsSharedMemory; @@ -126,15 +127,13 @@ impl VirtioDevice for Device { }; } - fn disable(&mut self) { + fn poll_disable(&mut self, cx: &mut std::task::Context<'_>) -> Poll<()> { self.exit_event.notify(usize::MAX); - if let Some(mut worker) = self.worker.take() { - self.driver - .spawn("shutdown-virtio-pmem-queue".to_owned(), async move { - worker.stop().await; - }) - .detach(); + if let Some(worker) = &mut self.worker { + ready!(worker.poll_stop(cx)); } + self.worker = None; + Poll::Ready(()) } } diff --git a/vm/devices/virtio/virtiofs/Cargo.toml b/vm/devices/virtio/virtiofs/Cargo.toml index e54797a0f5..4d4af8e571 100644 --- a/vm/devices/virtio/virtiofs/Cargo.toml +++ b/vm/devices/virtio/virtiofs/Cargo.toml @@ -19,13 +19,11 @@ inspect.workspace = true lx.workspace = true lxutil.workspace = true pal.workspace = true -pal_async.workspace = true task_control.workspace = true anyhow.workspace = true async-trait.workspace = true event-listener.workspace = true -futures.workspace = true parking_lot.workspace = true tracing.workspace = true zerocopy.workspace = true diff --git a/vm/devices/virtio/virtiofs/src/virtio.rs b/vm/devices/virtio/virtiofs/src/virtio.rs index f9cd5ba4d8..f03fedb41c 100644 --- a/vm/devices/virtio/virtiofs/src/virtio.rs +++ b/vm/devices/virtio/virtiofs/src/virtio.rs @@ -7,10 +7,12 @@ use async_trait::async_trait; use guestmem::GuestMemory; use guestmem::MappedMemoryRegion; use inspect::InspectMut; -use pal_async::task::Spawn; use std::io; use std::io::Write; use std::sync::Arc; +use std::task::Context; +use std::task::Poll; +use std::task::ready; use task_control::TaskControl; use virtio::DeviceTraits; use virtio::DeviceTraitsSharedMemory; @@ -158,17 +160,13 @@ impl VirtioDevice for VirtioFsDevice { .collect(); } - fn disable(&mut self) { + fn poll_disable(&mut self, cx: &mut Context<'_>) -> Poll<()> { self.exit_event.notify(usize::MAX); - let mut workers = self.workers.drain(..).collect::>(); - self.driver - .spawn("shutdown-virtiofs-queues".to_owned(), async move { - futures::future::join_all(workers.iter_mut().map(async |worker| { - worker.stop().await; - })) - .await; - }) - .detach(); + for worker in &mut self.workers { + ready!(worker.poll_stop(cx)); + } + self.workers.clear(); + Poll::Ready(()) } }