1#![cfg_attr(feature = "docs", doc = "\n\nSee the [changelog][changelog] for a full release history.")]
4#![cfg_attr(feature = "docs", doc = "## Feature flags")]
5#![cfg_attr(feature = "docs", doc = document_features::document_features!())]
6#![cfg_attr(all(coverage_nightly, test), feature(coverage_attribute))]
56#![cfg_attr(docsrs, feature(doc_auto_cfg))]
57#![deny(missing_docs)]
58#![deny(unreachable_pub)]
59#![deny(clippy::mod_module_files)]
60#![deny(clippy::undocumented_unsafe_blocks)]
61#![deny(clippy::multiple_unsafe_ops_per_block)]
62
63use std::pin::Pin;
64use std::task::{Context, Poll};
65
66#[cfg(unix)]
67use tokio::signal::unix;
68#[cfg(unix)]
69pub use tokio::signal::unix::SignalKind as UnixSignalKind;
70
71#[cfg(feature = "bootstrap")]
72mod bootstrap;
73
74#[cfg(feature = "bootstrap")]
75pub use bootstrap::{SignalConfig, SignalSvc};
76
/// A platform-agnostic signal kind that a [`SignalHandler`] can listen for.
///
/// `Interrupt` and `Terminate` are portable aliases: on Unix they compare equal to
/// `UnixSignalKind::interrupt()` / `UnixSignalKind::terminate()`, and on Windows to
/// `WindowsSignalKind::CtrlC` / `WindowsSignalKind::CtrlClose`.
#[derive(Debug, Clone, Copy, Eq)]
pub enum SignalKind {
    /// Interrupt request (`SIGINT` on Unix, `WindowsSignalKind::CtrlC` on Windows).
    Interrupt,
    /// Termination request (`SIGTERM` on Unix, `WindowsSignalKind::CtrlClose` on Windows).
    Terminate,
    /// A Windows console control event.
    #[cfg(windows)]
    Windows(WindowsSignalKind),
    /// An arbitrary Unix signal.
    #[cfg(unix)]
    Unix(UnixSignalKind),
}
91
92impl PartialEq for SignalKind {
93 fn eq(&self, other: &Self) -> bool {
94 #[cfg(unix)]
95 const INTERRUPT: UnixSignalKind = UnixSignalKind::interrupt();
96 #[cfg(unix)]
97 const TERMINATE: UnixSignalKind = UnixSignalKind::terminate();
98
99 match (self, other) {
100 #[cfg(windows)]
101 (
102 Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
103 Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC),
104 ) => true,
105 #[cfg(windows)]
106 (
107 Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
108 Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose),
109 ) => true,
110 #[cfg(windows)]
111 (Self::Windows(a), Self::Windows(b)) => a == b,
112 #[cfg(unix)]
113 (Self::Interrupt | Self::Unix(INTERRUPT), Self::Interrupt | Self::Unix(INTERRUPT)) => true,
114 #[cfg(unix)]
115 (Self::Terminate | Self::Unix(TERMINATE), Self::Terminate | Self::Unix(TERMINATE)) => true,
116 #[cfg(unix)]
117 (Self::Unix(a), Self::Unix(b)) => a == b,
118 _ => false,
119 }
120 }
121}
122
123#[cfg(unix)]
124impl From<UnixSignalKind> for SignalKind {
125 fn from(value: UnixSignalKind) -> Self {
126 match value {
127 kind if kind == UnixSignalKind::interrupt() => Self::Interrupt,
128 kind if kind == UnixSignalKind::terminate() => Self::Terminate,
129 kind => Self::Unix(kind),
130 }
131 }
132}
133
134#[cfg(unix)]
135impl PartialEq<UnixSignalKind> for SignalKind {
136 fn eq(&self, other: &UnixSignalKind) -> bool {
137 match self {
138 Self::Interrupt => other == &UnixSignalKind::interrupt(),
139 Self::Terminate => other == &UnixSignalKind::terminate(),
140 Self::Unix(kind) => kind == other,
141 }
142 }
143}
144
#[cfg(windows)]
/// A Windows console control event that a [`SignalHandler`] can listen for.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WindowsSignalKind {
    /// The "ctrl-break" console event (listened for via `tokio::signal::windows::ctrl_break`).
    CtrlBreak,
    /// The "ctrl-c" console event; converts to [`SignalKind::Interrupt`].
    CtrlC,
    /// The "ctrl-close" console event; converts to [`SignalKind::Terminate`].
    CtrlClose,
    /// The "ctrl-logoff" console event (listened for via `tokio::signal::windows::ctrl_logoff`).
    CtrlLogoff,
    /// The "ctrl-shutdown" console event (listened for via `tokio::signal::windows::ctrl_shutdown`).
    CtrlShutdown,
}
160
#[cfg(windows)]
impl From<WindowsSignalKind> for SignalKind {
    /// Wraps a Windows console event.
    ///
    /// `CtrlC` and `CtrlClose` are folded into the portable [`SignalKind::Interrupt`] /
    /// [`SignalKind::Terminate`] variants. The remaining events are kept as
    /// [`SignalKind::Windows`].
    fn from(value: WindowsSignalKind) -> Self {
        match value {
            WindowsSignalKind::CtrlC => Self::Interrupt,
            WindowsSignalKind::CtrlClose => Self::Terminate,
            // Listed explicitly (no wildcard) so a new event kind forces a decision here.
            WindowsSignalKind::CtrlBreak | WindowsSignalKind::CtrlLogoff | WindowsSignalKind::CtrlShutdown => {
                Self::Windows(value)
            }
        }
    }
}
173
#[cfg(windows)]
impl PartialEq<WindowsSignalKind> for SignalKind {
    /// Returns `true` when this kind denotes the given Windows console event.
    fn eq(&self, other: &WindowsSignalKind) -> bool {
        let native = match *self {
            Self::Interrupt => WindowsSignalKind::CtrlC,
            Self::Terminate => WindowsSignalKind::CtrlClose,
            Self::Windows(kind) => kind,
        };

        *other == native
    }
}
184
/// Live Windows listener backing a [`SignalKind`] (this is the `Signal` type on Windows).
///
/// There is one variant per tokio console-event listener. In test builds an extra
/// `Mock` variant holds the kind being waited for and a broadcast stream fed by the
/// test-only `SignalMocker`, so tests can raise events without a real console.
#[cfg(windows)]
#[derive(Debug)]
enum WindowsSignalValue {
    CtrlBreak(tokio::signal::windows::CtrlBreak),
    CtrlC(tokio::signal::windows::CtrlC),
    CtrlClose(tokio::signal::windows::CtrlClose),
    CtrlLogoff(tokio::signal::windows::CtrlLogoff),
    CtrlShutdown(tokio::signal::windows::CtrlShutdown),
    #[cfg(test)]
    Mock(SignalKind, Pin<Box<tokio_stream::wrappers::BroadcastStream<SignalKind>>>),
}
196
#[cfg(windows)]
impl WindowsSignalValue {
    /// Polls the wrapped listener for its next event.
    ///
    /// For the real tokio listeners the result of their `poll_recv` is forwarded
    /// unchanged. The test-only mock resolves with `Ready(Some(()))` only when the
    /// broadcast carries the kind it was created for.
    fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<Option<()>> {
        // Brings `poll_next` into scope for the mock's `BroadcastStream`.
        #[cfg(test)]
        use futures::Stream;

        match self {
            Self::CtrlBreak(signal) => signal.poll_recv(cx),
            Self::CtrlC(signal) => signal.poll_recv(cx),
            Self::CtrlClose(signal) => signal.poll_recv(cx),
            Self::CtrlLogoff(signal) => signal.poll_recv(cx),
            Self::CtrlShutdown(signal) => signal.poll_recv(cx),
            #[cfg(test)]
            Self::Mock(kind, receiver) => match receiver.as_mut().poll_next(cx) {
                // Matching uses `SignalKind`'s `PartialEq`, so e.g. `Windows(CtrlC)` matches `Interrupt`.
                Poll::Ready(Some(Ok(recv))) if recv == *kind => Poll::Ready(Some(())),
                // A different kind was broadcast: request another poll so any further
                // queued messages get inspected.
                Poll::Ready(Some(Ok(_))) => {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
                // Stream end (`None`) or a receive error (`Some(Err(_))`) is treated as a test bug.
                Poll::Ready(v) => unreachable!("receiver should always have a value: {:?}", v),
                // NOTE(review): waking unconditionally makes the task re-poll continuously
                // while idle (busy-polling). It is harmless in tests; presumably the stream
                // registers the waker itself, so this wake may be unnecessary. Confirm before changing.
                Poll::Pending => {
                    cx.waker().wake_by_ref();
                    Poll::Pending
                }
            },
        }
    }
}
225
/// Per-signal listener stored by [`SignalHandler`] on Unix: tokio's `unix::Signal`.
#[cfg(unix)]
type Signal = unix::Signal;

/// Per-signal listener stored by [`SignalHandler`] on Windows: the event-kind wrapper.
#[cfg(windows)]
type Signal = WindowsSignalValue;
231
232impl SignalKind {
233 #[cfg(unix)]
234 fn listen(&self) -> Result<Signal, std::io::Error> {
235 match self {
236 Self::Interrupt => tokio::signal::unix::signal(UnixSignalKind::interrupt()),
237 Self::Terminate => tokio::signal::unix::signal(UnixSignalKind::terminate()),
238 Self::Unix(kind) => tokio::signal::unix::signal(*kind),
239 }
240 }
241
242 #[cfg(windows)]
243 fn listen(&self) -> Result<Signal, std::io::Error> {
244 #[cfg(test)]
245 if cfg!(test) {
246 return Ok(WindowsSignalValue::Mock(
247 *self,
248 Box::pin(tokio_stream::wrappers::BroadcastStream::new(tests::SignalMocker::subscribe())),
249 ));
250 }
251
252 match self {
253 Self::Interrupt | Self::Windows(WindowsSignalKind::CtrlC) => {
255 Ok(WindowsSignalValue::CtrlC(tokio::signal::windows::ctrl_c()?))
256 }
257 Self::Terminate | Self::Windows(WindowsSignalKind::CtrlClose) => {
259 Ok(WindowsSignalValue::CtrlClose(tokio::signal::windows::ctrl_close()?))
260 }
261 Self::Windows(WindowsSignalKind::CtrlBreak) => {
262 Ok(WindowsSignalValue::CtrlBreak(tokio::signal::windows::ctrl_break()?))
263 }
264 Self::Windows(WindowsSignalKind::CtrlLogoff) => {
265 Ok(WindowsSignalValue::CtrlLogoff(tokio::signal::windows::ctrl_logoff()?))
266 }
267 Self::Windows(WindowsSignalKind::CtrlShutdown) => {
268 Ok(WindowsSignalValue::CtrlShutdown(tokio::signal::windows::ctrl_shutdown()?))
269 }
270 }
271 }
272}
273
/// Listens for a set of OS signals and resolves with the [`SignalKind`] that fired.
///
/// A handler starts empty. Signals are registered with [`SignalHandler::with_signal`],
/// [`SignalHandler::with_signals`] or [`SignalHandler::add_signal`]. Registering a kind
/// that compares equal to one already registered is a no-op.
///
/// There are three ways to wait for a signal:
/// - await the handler directly (it implements [`Future`](std::future::Future));
/// - call [`SignalHandler::recv`];
/// - poll manually with [`SignalHandler::poll_recv`].
///
/// A handler with no registered signals never resolves.
#[derive(Debug)]
#[must_use = "signal handlers must be used to wait for signals"]
pub struct SignalHandler {
    // Registered kinds paired with their live platform listener, in registration order.
    signals: Vec<(SignalKind, Signal)>,
}
326
327impl Default for SignalHandler {
328 fn default() -> Self {
329 Self::new()
330 }
331}
332
333impl SignalHandler {
334 pub const fn new() -> Self {
336 Self { signals: Vec::new() }
337 }
338
339 pub fn with_signals<T: Into<SignalKind>>(signals: impl IntoIterator<Item = T>) -> Self {
341 let mut handler = Self::new();
342
343 for signal in signals {
344 handler = handler.with_signal(signal.into());
345 }
346
347 handler
348 }
349
350 pub fn with_signal(mut self, kind: impl Into<SignalKind>) -> Self {
354 self.add_signal(kind);
355 self
356 }
357
358 pub fn add_signal(&mut self, kind: impl Into<SignalKind>) -> &mut Self {
362 let kind = kind.into();
363 if self.signals.iter().any(|(k, _)| k == &kind) {
364 return self;
365 }
366
367 let signal = kind.listen().expect("failed to create signal");
368
369 self.signals.push((kind, signal));
370
371 self
372 }
373
374 pub async fn recv(&mut self) -> SignalKind {
378 self.await
379 }
380
381 pub fn poll_recv(&mut self, cx: &mut Context<'_>) -> Poll<SignalKind> {
384 for (kind, signal) in self.signals.iter_mut() {
385 if signal.poll_recv(cx).is_ready() {
386 return Poll::Ready(*kind);
387 }
388 }
389
390 Poll::Pending
391 }
392}
393
394impl std::future::Future for SignalHandler {
395 type Output = SignalKind;
396
397 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
398 self.poll_recv(cx)
399 }
400}
401
#[cfg(feature = "docs")]
/// The crate's full release history, generated by the `scuffle_changelog::changelog` attribute.
#[scuffle_changelog::changelog]
pub mod changelog {}
406
#[cfg(test)]
#[cfg_attr(coverage_nightly, coverage(off))]
mod tests {
    use std::time::Duration;

    use scuffle_future_ext::FutureExt;

    use crate::{SignalHandler, SignalKind};

    /// Test-only signal source for Windows.
    ///
    /// It wraps a broadcast channel that mocked listeners (created by
    /// `SignalKind::listen` in test builds) subscribe to.
    #[cfg(windows)]
    pub(crate) struct SignalMocker(tokio::sync::broadcast::Sender<SignalKind>);

    #[cfg(windows)]
    impl SignalMocker {
        /// Creates the broadcast channel (capacity 100); the initial receiver is dropped.
        fn new() -> Self {
            println!("new");
            let (sender, _) = tokio::sync::broadcast::channel(100);
            Self(sender)
        }

        /// Broadcasts `kind` to every mock listener subscribed on this thread.
        ///
        /// The `unwrap` panics if `send` fails. tokio's broadcast `send` errors when
        /// there are no receivers.
        fn raise(kind: SignalKind) {
            println!("raising");
            SIGNAL_MOCKER.with(|local| local.0.send(kind).unwrap());
        }

        /// Returns a new receiver on this thread's mocker.
        pub(crate) fn subscribe() -> tokio::sync::broadcast::Receiver<SignalKind> {
            println!("subscribing");
            SIGNAL_MOCKER.with(|local| local.0.subscribe())
        }
    }

    // One mocker per thread. This relies on the test body and the handler it drives
    // running on the same thread. That is presumably true for `#[tokio::test]`'s default
    // current-thread runtime; confirm if the runtime flavor changes.
    #[cfg(windows)]
    thread_local! {
        static SIGNAL_MOCKER: SignalMocker = SignalMocker::new();
    }

    /// Windows: "raises" a signal by broadcasting it through the thread-local mocker.
    #[cfg(windows)]
    pub(crate) async fn raise_signal(kind: SignalKind) {
        SignalMocker::raise(kind);
    }

    /// Unix: raises a real signal on the current thread via `libc::raise`.
    #[cfg(unix)]
    pub(crate) async fn raise_signal(kind: SignalKind) {
        // SAFETY: `libc::raise` takes only a signal number and has no memory-safety
        // preconditions. Every test that calls this first registers a `SignalHandler`
        // for the signal, so tokio's handler runs instead of the default action.
        unsafe {
            libc::raise(match kind {
                SignalKind::Interrupt => libc::SIGINT,
                SignalKind::Terminate => libc::SIGTERM,
                SignalKind::Unix(kind) => kind.as_raw_value(),
            });
        }
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn signal_handler() {
        use crate::WindowsSignalKind;

        let mut handler = SignalHandler::with_signals([WindowsSignalKind::CtrlC, WindowsSignalKind::CtrlBreak]);

        // NOTE(review): this delay is not obviously required for the broadcast mock; confirm.
        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");

        // The CtrlC event was consumed above, so nothing else should be pending.
        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;
        assert!(recv.is_err(), "expected timeout");

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
    }

    #[cfg(windows)]
    #[tokio::test]
    async fn add_signal() {
        use crate::WindowsSignalKind;

        let mut handler = SignalHandler::new();

        // The second CtrlC registration is a duplicate and should be ignored.
        handler
            .add_signal(WindowsSignalKind::CtrlC)
            .add_signal(WindowsSignalKind::CtrlBreak)
            .add_signal(WindowsSignalKind::CtrlC);

        tokio::time::sleep(tokio::time::Duration::from_secs(1)).await;

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlC)).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlC, "expected CtrlC");

        raise_signal(SignalKind::Windows(WindowsSignalKind::CtrlBreak)).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, WindowsSignalKind::CtrlBreak, "expected CtrlBreak");
    }

    // Skipped when built with `--cfg valgrind`. NOTE(review): presumably because real
    // signal delivery or the 500ms timeouts are unreliable under valgrind; confirm.
    #[cfg(all(not(valgrind), unix))]
    #[tokio::test]
    async fn signal_handler() {
        use crate::UnixSignalKind;

        // SIGUSR1 is registered twice; the duplicate is ignored.
        let mut handler = SignalHandler::with_signals([UnixSignalKind::user_defined1()])
            .with_signal(UnixSignalKind::user_defined2())
            .with_signal(UnixSignalKind::user_defined1());

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined1())).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, SignalKind::Unix(UnixSignalKind::user_defined1()), "expected SIGUSR1");

        // SIGUSR1 was consumed above, so nothing else should be pending.
        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await;

        assert!(recv.is_err(), "expected timeout");

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined2())).await;

        let recv = (&mut handler).with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined2(), "expected SIGUSR2");
    }

    #[cfg(all(not(valgrind), unix))]
    #[tokio::test]
    async fn add_signal() {
        use crate::UnixSignalKind;

        let mut handler = SignalHandler::new();

        // SIGUSR2 is registered twice; the duplicate is ignored.
        handler
            .add_signal(UnixSignalKind::user_defined1())
            .add_signal(UnixSignalKind::user_defined2())
            .add_signal(UnixSignalKind::user_defined2());

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined1())).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined1(), "expected SIGUSR1");

        raise_signal(SignalKind::Unix(UnixSignalKind::user_defined2())).await;

        let recv = handler.recv().with_timeout(Duration::from_millis(500)).await.unwrap();

        assert_eq!(recv, UnixSignalKind::user_defined2(), "expected SIGUSR2");
    }

    // A handler with no registered signals must never resolve.
    #[cfg(not(valgrind))]
    #[tokio::test]
    async fn no_signals() {
        let mut handler = SignalHandler::default();

        assert!(handler.recv().with_timeout(Duration::from_millis(500)).await.is_err());
    }

    // The portable aliases must compare equal to their native counterparts, in both directions.
    #[cfg(windows)]
    #[test]
    fn signal_kind_eq() {
        use crate::WindowsSignalKind;

        assert_eq!(SignalKind::Interrupt, SignalKind::Windows(WindowsSignalKind::CtrlC));
        assert_eq!(SignalKind::Terminate, SignalKind::Windows(WindowsSignalKind::CtrlClose));
        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlC), SignalKind::Interrupt);
        assert_eq!(SignalKind::Windows(WindowsSignalKind::CtrlClose), SignalKind::Terminate);
        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
        assert_eq!(
            SignalKind::Windows(WindowsSignalKind::CtrlBreak),
            SignalKind::Windows(WindowsSignalKind::CtrlBreak)
        );
    }

    #[cfg(unix)]
    #[test]
    fn signal_kind_eq() {
        use crate::UnixSignalKind;

        assert_eq!(SignalKind::Interrupt, SignalKind::Unix(UnixSignalKind::interrupt()));
        assert_eq!(SignalKind::Terminate, SignalKind::Unix(UnixSignalKind::terminate()));
        assert_eq!(SignalKind::Unix(UnixSignalKind::interrupt()), SignalKind::Interrupt);
        assert_eq!(SignalKind::Unix(UnixSignalKind::terminate()), SignalKind::Terminate);
        assert_ne!(SignalKind::Interrupt, SignalKind::Terminate);
        assert_eq!(
            SignalKind::Unix(UnixSignalKind::user_defined1()),
            SignalKind::Unix(UnixSignalKind::user_defined1())
        );
    }
}