diff --git a/.changelog/unreleased/bug-fixes/2010-p2p-pex-shutdown.md b/.changelog/unreleased/bug-fixes/2010-p2p-pex-shutdown.md new file mode 100644 index 00000000000..e913d7b2371 --- /dev/null +++ b/.changelog/unreleased/bug-fixes/2010-p2p-pex-shutdown.md @@ -0,0 +1 @@ +- `[p2p/pex]` gracefully shut down Reactor ([\#2010](https://github.com/cometbft/cometbft/pull/2010)) diff --git a/p2p/pex/addrbook.go b/p2p/pex/addrbook.go index c48d6cc4a91..6ece794f3b5 100644 --- a/p2p/pex/addrbook.go +++ b/p2p/pex/addrbook.go @@ -154,26 +154,23 @@ func (a *addrBook) init() { // OnStart implements Service. func (a *addrBook) OnStart() error { - if err := a.BaseService.OnStart(); err != nil { - return err - } a.loadFromFile(a.filePath) - // wg.Add to ensure that any invocation of .Wait() - // later on will wait for saveRoutine to terminate. a.wg.Add(1) go a.saveRoutine() return nil } -// OnStop implements Service. -func (a *addrBook) OnStop() { - a.BaseService.OnStop() -} - -func (a *addrBook) Wait() { +// Stop overrides Service.Stop(). +func (a *addrBook) Stop() error { + // Closes the Service.Quit() channel. + // This enables a.saveRoutine() to quit. 
+ if err := a.BaseService.Stop(); err != nil { + return err + } a.wg.Wait() + return nil } func (a *addrBook) FilePath() string { @@ -491,17 +488,16 @@ func (a *addrBook) saveRoutine() { defer a.wg.Done() saveFileTicker := time.NewTicker(dumpAddressInterval) -out: + defer saveFileTicker.Stop() for { select { case <-saveFileTicker.C: - a.saveToFile(a.filePath) + a.Save() case <-a.Quit(): - break out + a.Save() + return } } - saveFileTicker.Stop() - a.saveToFile(a.filePath) } //---------------------------------------------------------- diff --git a/p2p/pex/pex_reactor.go b/p2p/pex/pex_reactor.go index aa8e7cdba5b..57eb02827f2 100644 --- a/p2p/pex/pex_reactor.go +++ b/p2p/pex/pex_reactor.go @@ -84,6 +84,7 @@ type Reactor struct { book AddrBook config *ReactorConfig ensurePeersPeriod time.Duration // TODO: should go in the config + peersRoutineWg sync.WaitGroup // maps to prevent abuse requestsSent *cmap.CMap // ID->struct{}: unanswered send requests @@ -156,6 +157,7 @@ func (r *Reactor) OnStart() error { r.seedAddrs = seedAddrs + r.peersRoutineWg.Add(1) // Check if this node should run // in seed/crawler mode if r.config.SeedMode { @@ -166,11 +168,16 @@ func (r *Reactor) OnStart() error { return nil } -// OnStop implements BaseService. -func (r *Reactor) OnStop() { +// Stop overrides `Service.Stop()`. +func (r *Reactor) Stop() error { + if err := r.BaseReactor.Stop(); err != nil { + return err + } if err := r.book.Stop(); err != nil { - r.Logger.Error("Error stopping address book", "err", err) + return fmt.Errorf("can't stop address book: %w", err) } + r.peersRoutineWg.Wait() + return nil } // GetChannels implements Reactor. @@ -414,6 +421,8 @@ func (r *Reactor) SetEnsurePeersPeriod(d time.Duration) { // Ensures that sufficient peers are connected. (continuous). 
func (r *Reactor) ensurePeersRoutine() { + defer r.peersRoutineWg.Done() + var ( seed = cmtrand.NewRand() jitter = seed.Int63n(r.ensurePeersPeriod.Nanoseconds()) @@ -432,12 +441,14 @@ func (r *Reactor) ensurePeersRoutine() { // fire periodically ticker := time.NewTicker(r.ensurePeersPeriod) + defer ticker.Stop() for { select { case <-ticker.C: r.ensurePeers() + case <-r.book.Quit(): + return case <-r.Quit(): - ticker.Stop() return } } @@ -475,6 +486,10 @@ func (r *Reactor) ensurePeers() { maxAttempts := numToDial * 3 for i := 0; i < maxAttempts && len(toDial) < numToDial; i++ { + if !r.IsRunning() || !r.book.IsRunning() { + return + } + try := r.book.PickAddress(newBias) if try == nil { continue @@ -650,6 +665,8 @@ func (r *Reactor) AttemptsToDial(addr *p2p.NetAddress) int { // Seed/Crawler Mode causes this node to quickly disconnect // from peers, except other seed nodes. func (r *Reactor) crawlPeersRoutine() { + defer r.peersRoutineWg.Done() + // If we have any seed nodes, consult them first if len(r.seedAddrs) > 0 { r.dialSeeds() @@ -660,13 +677,15 @@ func (r *Reactor) crawlPeersRoutine() { // Fire periodically ticker := time.NewTicker(crawlPeerPeriod) - + defer ticker.Stop() for { select { case <-ticker.C: r.attemptDisconnects() r.crawlPeers(r.book.GetSelection()) r.cleanupCrawlPeerInfos() + case <-r.book.Quit(): + return case <-r.Quit(): return }