Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
24 changes: 20 additions & 4 deletions internal/notifier/notifier.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package notifier

import (
"cmp"
"context"
"errors"
"fmt"
Expand Down Expand Up @@ -78,7 +79,8 @@ type Notifier struct {
baseservice.BaseService
startstop.BaseStartStop

disableSleep bool // for tests only; disable sleep on exponential backoff
disableSleep bool // for tests only; disable sleep on exponential backoff
testPingInterval time.Duration // for tests only; override the 5s ping interval
listener riverdriver.Listener
notificationBuf chan *riverdriver.Notification
testSignals notifierTestSignals
Expand Down Expand Up @@ -345,6 +347,12 @@ func (n *Notifier) waitOnce(ctx context.Context) error {
n.waitCancel()
})

// Save a reference to the parent context before creating the inner
// cancellable context. The inner context is cancelled by drainErrChan to
// interrupt WaitForNotification, but we still need a live context for the
// Ping health check afterward.
pingCtx := ctx

ctx, cancel := context.WithCancel(ctx)
defer cancel()

Expand Down Expand Up @@ -382,7 +390,8 @@ func (n *Notifier) waitOnce(ctx context.Context) error {
return nil
}

needPingCtx, needPingCancel := context.WithTimeout(ctx, 5*time.Second)
pingInterval := cmp.Or(n.testPingInterval, 5*time.Second)
needPingCtx, needPingCancel := context.WithTimeout(ctx, pingInterval)
defer needPingCancel()

// * Wait for notifications
Expand All @@ -397,8 +406,15 @@ func (n *Notifier) waitOnce(ctx context.Context) error {
if err := drainErrChan(); err != nil {
return err
}
// Ping the conn to see if it's still alive
if err := n.listener.Ping(ctx); err != nil {
// Ping the conn to see if it's still alive. Use pingCtx (the parent
// context) because the inner ctx was cancelled by drainErrChan above
// to interrupt WaitForNotification.
//
// Note: Previously this used the (already cancelled) inner ctx, making
// the ping a no-op that always returned context.Canceled. With the fix,
// dead or flaky connections are now actively detected, which may trigger
// reconnections that were previously silently swallowed.
if err := n.listener.Ping(pingCtx); err != nil {
return err
}

Expand Down
51 changes: 51 additions & 0 deletions internal/notifier/notifier_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -531,6 +531,51 @@ func TestNotifier(t *testing.T) {
require.EqualError(t, notifier.testSignals.BackoffError.WaitOrTimeout(), "error during wait")
})

// PingUsesNonCancelledContext checks that the context handed to the
// listener's Ping has not already been cancelled at the moment Ping is
// invoked. WaitForNotification is mocked to block until its own context is
// cancelled, and Ping is mocked to record the state of the context it
// receives.
t.Run("PingUsesNonCancelledContext", func(t *testing.T) {
t.Parallel()

notifier, _ := setup(t, nil)

// Use a very short ping interval so the test doesn't take 5 seconds.
notifier.testPingInterval = 50 * time.Millisecond

// pingCtxCancelled is written exactly once inside pingOnce.Do, before
// pingCalled is closed, and is read by the test only after receiving from
// pingCalled. The channel close therefore orders the write before the read
// and no additional locking is used.
var (
pingCtxCancelled bool
pingCalled = make(chan struct{})
pingOnce sync.Once
)

listenerMock := NewListenerMock(notifier.listener)
listenerMock.waitForNotificationFunc = func(ctx context.Context) (*riverdriver.Notification, error) {
// Block until the context is cancelled (which happens when
// drainErrChan runs after the ping interval elapses).
<-ctx.Done()
return nil, ctx.Err()
}
listenerMock.pingFunc = func(ctx context.Context) error {
// Only the first Ping call is recorded; pingOnce also guarantees that
// pingCalled is closed at most once even if Ping is called repeatedly.
pingOnce.Do(func() {
pingCtxCancelled = ctx.Err() != nil
close(pingCalled)
})
return nil
}
// Swap in the mock before starting so the notifier uses the mocked
// WaitForNotification and Ping.
notifier.listener = listenerMock

start(t, notifier)

notifier.testSignals.ListeningBegin.WaitOrTimeout()

// Bound the wait so that a notifier which never calls Ping fails the test
// rather than hanging it.
select {
case <-pingCalled:
case <-time.After(5 * time.Second):
require.FailNow(t, "Timed out waiting for Ping to be called")
}

require.False(t, pingCtxCancelled,
"Ping should receive a non-cancelled context; the inner context is "+
"cancelled to interrupt WaitForNotification, but Ping needs a live context")
})

t.Run("StillFunctionalAfterMainLoopFailure", func(t *testing.T) {
t.Parallel()

Expand Down Expand Up @@ -584,6 +629,7 @@ type ListenerMock struct {

connectFunc func(ctx context.Context) error
listenFunc func(ctx context.Context, topic string) error
pingFunc func(ctx context.Context) error
waitForNotificationFunc func(ctx context.Context) (*riverdriver.Notification, error)
}

Expand All @@ -593,6 +639,7 @@ func NewListenerMock(listener riverdriver.Listener) *ListenerMock {

connectFunc: listener.Connect,
listenFunc: listener.Listen,
pingFunc: listener.Ping,
waitForNotificationFunc: listener.WaitForNotification,
}
}
Expand All @@ -605,6 +652,10 @@ func (l *ListenerMock) Listen(ctx context.Context, topic string) error {
return l.listenFunc(ctx, topic)
}

// Ping forwards the call to l.pingFunc and returns its error unchanged.
// NewListenerMock sets pingFunc to the wrapped listener's Ping; tests may
// reassign pingFunc to observe or replace the call.
func (l *ListenerMock) Ping(ctx context.Context) error {
return l.pingFunc(ctx)
}

// WaitForNotification forwards the call to l.waitForNotificationFunc and
// returns its results unchanged. NewListenerMock sets that field to the
// wrapped listener's WaitForNotification; tests may reassign it to control
// how the wait behaves.
func (l *ListenerMock) WaitForNotification(ctx context.Context) (*riverdriver.Notification, error) {
return l.waitForNotificationFunc(ctx)
}
Expand Down
Loading