diff --git a/internal/notifier/notifier.go b/internal/notifier/notifier.go index 42d3a63a..b3e881b4 100644 --- a/internal/notifier/notifier.go +++ b/internal/notifier/notifier.go @@ -1,6 +1,7 @@ package notifier import ( + "cmp" "context" "errors" "fmt" @@ -78,7 +79,8 @@ type Notifier struct { baseservice.BaseService startstop.BaseStartStop - disableSleep bool // for tests only; disable sleep on exponential backoff + disableSleep bool // for tests only; disable sleep on exponential backoff + testPingInterval time.Duration // for tests only; override the 5s ping interval listener riverdriver.Listener notificationBuf chan *riverdriver.Notification testSignals notifierTestSignals @@ -345,6 +347,12 @@ func (n *Notifier) waitOnce(ctx context.Context) error { n.waitCancel() }) + // Save a reference to the parent context before creating the inner + // cancellable context. The inner context is cancelled by drainErrChan to + // interrupt WaitForNotification, but we still need a live context for the + // Ping health check afterward. + pingCtx := ctx + ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -382,7 +390,8 @@ func (n *Notifier) waitOnce(ctx context.Context) error { return nil } - needPingCtx, needPingCancel := context.WithTimeout(ctx, 5*time.Second) + pingInterval := cmp.Or(n.testPingInterval, 5*time.Second) + needPingCtx, needPingCancel := context.WithTimeout(ctx, pingInterval) defer needPingCancel() // * Wait for notifications @@ -397,8 +406,15 @@ func (n *Notifier) waitOnce(ctx context.Context) error { if err := drainErrChan(); err != nil { return err } - // Ping the conn to see if it's still alive - if err := n.listener.Ping(ctx); err != nil { + // Ping the conn to see if it's still alive. Use pingCtx (the parent + // context) because the inner ctx was cancelled by drainErrChan above + // to interrupt WaitForNotification. 
+ // + // Note: Previously this used the (already cancelled) inner ctx, making + // the ping a no-op that always returned context.Canceled. With the fix, + // dead or flaky connections are now actively detected, so failures that + // were previously silently swallowed may now trigger reconnections. + if err := n.listener.Ping(pingCtx); err != nil { return err } diff --git a/internal/notifier/notifier_test.go b/internal/notifier/notifier_test.go index 2c4bcccf..f9b44daa 100644 --- a/internal/notifier/notifier_test.go +++ b/internal/notifier/notifier_test.go @@ -531,6 +531,51 @@ func TestNotifier(t *testing.T) { require.EqualError(t, notifier.testSignals.BackoffError.WaitOrTimeout(), "error during wait") }) + t.Run("PingUsesNonCancelledContext", func(t *testing.T) { + t.Parallel() + + notifier, _ := setup(t, nil) + + // Use a very short ping interval so the test doesn't take 5 seconds. + notifier.testPingInterval = 50 * time.Millisecond + + var ( + pingCtxCancelled bool + pingCalled = make(chan struct{}) + pingOnce sync.Once + ) + + listenerMock := NewListenerMock(notifier.listener) + listenerMock.waitForNotificationFunc = func(ctx context.Context) (*riverdriver.Notification, error) { + // Block until the context is cancelled (which happens when + // drainErrChan runs after the ping interval elapses). 
+ <-ctx.Done() + return nil, ctx.Err() + } + listenerMock.pingFunc = func(ctx context.Context) error { + pingOnce.Do(func() { + pingCtxCancelled = ctx.Err() != nil + close(pingCalled) + }) + return nil + } + notifier.listener = listenerMock + + start(t, notifier) + + notifier.testSignals.ListeningBegin.WaitOrTimeout() + + select { + case <-pingCalled: + case <-time.After(5 * time.Second): + require.FailNow(t, "Timed out waiting for Ping to be called") + } + + require.False(t, pingCtxCancelled, + "Ping should receive a non-cancelled context; the inner context is "+ + "cancelled to interrupt WaitForNotification, but Ping needs a live context") + }) + t.Run("StillFunctionalAfterMainLoopFailure", func(t *testing.T) { t.Parallel() @@ -584,6 +629,7 @@ type ListenerMock struct { connectFunc func(ctx context.Context) error listenFunc func(ctx context.Context, topic string) error + pingFunc func(ctx context.Context) error waitForNotificationFunc func(ctx context.Context) (*riverdriver.Notification, error) } @@ -593,6 +639,7 @@ func NewListenerMock(listener riverdriver.Listener) *ListenerMock { connectFunc: listener.Connect, listenFunc: listener.Listen, + pingFunc: listener.Ping, waitForNotificationFunc: listener.WaitForNotification, } } @@ -605,6 +652,10 @@ func (l *ListenerMock) Listen(ctx context.Context, topic string) error { return l.listenFunc(ctx, topic) } +func (l *ListenerMock) Ping(ctx context.Context) error { + return l.pingFunc(ctx) +} + func (l *ListenerMock) WaitForNotification(ctx context.Context) (*riverdriver.Notification, error) { return l.waitForNotificationFunc(ctx) }