feat: Config updater for the base provider by davidterpay · Pull Request #117 · skip-mev/connect · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
This repository was archived by the owner on Mar 24, 2025. It is now read-only.

feat: Config updater for the base provider #117

Merged
merged 7 commits into from
Feb 13, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 0 additions & 2 deletions go.sum
Original file line number Diff line number Diff line change
Expand Up @@ -409,8 +409,6 @@ github.com/cosmos/cosmos-db v1.0.0 h1:EVcQZ+qYag7W6uorBKFPvX6gRjw6Uq2hIh4hCWjuQ0
github.com/cosmos/cosmos-db v1.0.0/go.mod h1:iBvi1TtqaedwLdcrZVYRSSCb6eSy61NLj4UNmdIgs0U=
github.com/cosmos/cosmos-proto v1.0.0-beta.3 h1:VitvZ1lPORTVxkmF2fAp3IiA61xVwArQYKXTdEcpW6o=
github.com/cosmos/cosmos-proto v1.0.0-beta.3/go.mod h1:t8IASdLaAq+bbHbjq4p960BvcTqtwuAxid3b/2rOD6I=
github.com/cosmos/cosmos-sdk v0.50.3 h1:zP0AXm54ws2t2qVWvcQhEYVafhOAREU2QL0gnbwjvXw=
github.com/cosmos/cosmos-sdk v0.50.3/go.mod h1:tlrkY1sntOt1q0OX/rqF0zRJtmXNoffAS6VFTcky+w8=
github.com/cosmos/cosmos-sdk v0.50.4-0.20240125183858-0abf94a334e3 h1:C5F34X1mAFNgybALt589GOi8OfySRA6zu6dfxvhyg2I=
github.com/cosmos/cosmos-sdk v0.50.4-0.20240125183858-0abf94a334e3/go.mod h1:0D9mrUy1eAUMQuvYzf2xvhEPk2ta9w7XH1zcYvyFiuM=
github.com/cosmos/go-bip39 v1.0.0 h1:pcomnQdrdH22njcAatO0yWojsUnCO3y2tNoV1cb6hHY=
Expand Down
3 changes: 3 additions & 0 deletions oracle/oracle_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,6 +127,7 @@ func (s *OracleTestSuite) TestStopWithContextCancel() {
provider2 := testutils.CreateWebSocketProviderWithGetResponses[oracletypes.CurrencyPair, *big.Int](
s.T(),
time.Second,
s.currencyPairs,
providerCfg2,
s.logger,
nil,
Expand Down Expand Up @@ -226,6 +227,7 @@ func (s *OracleTestSuite) TestStopWithContextDeadline() {
provider2 := testutils.CreateWebSocketProvide 8000 rWithGetResponses[oracletypes.CurrencyPair, *big.Int](
s.T(),
time.Second,
s.currencyPairs,
providerCfg2,
s.logger,
nil,
Expand Down Expand Up @@ -313,6 +315,7 @@ func (s *OracleTestSuite) TestStop() {
provider2 := testutils.CreateWebSocketProviderWithGetResponses[oracletypes.CurrencyPair, *big.Int](
s.T(),
time.Second,
s.currencyPairs,
providerCfg2,
s.logger,
nil,
Expand Down
1 change: 1 addition & 0 deletions oracle/providers_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -108,6 +108,7 @@ func (s *OracleTestSuite) TestProviders() {
provider2 := testutils.CreateWebSocketProviderWithGetResponses[oracletypes.CurrencyPair, *big.Int](
s.T(),
time.Second*2,
s.currencyPairs,
providerCfg2,
s.logger,
responses2,
Expand Down
30 changes: 30 additions & 0 deletions providers/base/config.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,30 @@
package base

import (
"context"

"go.uber.org/zap"
)

// listenOnConfigUpdater listens for updates from the config updater and applies them to the
// provider. Each set of IDs received from the updater replaces the provider's current IDs
// (via SetIDs), after which a restart is signalled on p.restartCh.
//
// It returns immediately when no config updater is configured. Otherwise it blocks until ctx
// is cancelled, or until the updater's ID channel is closed (no further updates can arrive).
func (p *Provider[K, V]) listenOnConfigUpdater(ctx context.Context) {
	if p.updater == nil {
		return
	}

	for {
		select {
		case <-ctx.Done():
			p.logger.Info("stopping config client listener")
			return
		case ids, ok := <-p.updater.GetIDs():
			if !ok {
				// A closed channel would otherwise yield zero values on every iteration,
				// spinning this loop and repeatedly wiping the provider's IDs.
				p.logger.Info("config updater channel closed; stopping config client listener")
				return
			}

			p.logger.Debug("received new ids", zap.Any("ids", ids))
			p.SetIDs(ids)

			// Signal the provider to restart. The send also watches ctx: while the provider
			// is shutting down there may be no receiver on restartCh, and an unconditional
			// send would block this goroutine forever.
			select {
			case p.restartCh <- struct{}{}:
			case <-ctx.Done():
				p.logger.Info("stopping config client listener")
				return
			}
		}
	}
}
118 changes: 118 additions & 0 deletions providers/base/config_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
package base_test

import (
"context"
"math/big"
"testing"
"time"

"github.com/skip-mev/slinky/providers/base"
"github.com/skip-mev/slinky/providers/base/testutils"
oracletypes "github.com/skip-mev/slinky/x/oracle/types"
"github.com/stretchr/testify/require"
)

// btcusd is the BITCOIN/USD pair; TestConfigUpdater starts its providers with only this ID.
var btcusd = oracletypes.NewCurrencyPair("BITCOIN", "USD")

// ethusd is the ETHEREUM/USD pair, added through the config updater in TestConfigUpdater.
var ethusd = oracletypes.NewCurrencyPair("ETHEREUM", "USD")

// solusd is the SOLANA/USD pair, added through the config updater in TestConfigUpdater.
var solusd = oracletypes.NewCurrencyPair("SOLANA", "USD")

// TestConfigUpdater checks that pushing new IDs through a ConfigUpdater replaces the IDs of
// a running provider, for both an API-based and a websocket-based provider.
func TestConfigUpdater(t *testing.T) {
	t.Run("restart on IDs update with an API provider", func(t *testing.T) {
		pairs := []oracletypes.CurrencyPair{btcusd}
		updater := base.NewConfigUpdater[oracletypes.CurrencyPair]()
		apiHandler := testutils.CreateAPIQueryHandlerWithGetResponses[oracletypes.CurrencyPair, *big.Int](
			t,
			logger,
			nil,
		)

		provider, err := base.NewProvider[oracletypes.CurrencyPair, *big.Int](
			base.WithName[oracletypes.CurrencyPair, *big.Int](apiCfg.Name),
			base.WithAPIQueryHandler[oracletypes.CurrencyPair, *big.Int](apiHandler),
			base.WithAPIConfig[oracletypes.CurrencyPair, *big.Int](apiCfg),
			base.WithLogger[oracletypes.CurrencyPair, *big.Int](logger),
			base.WithIDs[oracletypes.CurrencyPair, *big.Int](pairs),
			base.WithConfigUpdater[oracletypes.CurrencyPair, *big.Int](updater),
		)
		require.NoError(t, err)

		// Start the provider and let it run until the context deadline.
		ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second)
		defer cancel()

		// Buffered so the provider goroutine can always deliver its result and exit, even
		// if an assertion below fails and nothing ever reads from errCh.
		errCh := make(chan error, 1)
		go func() {
			errCh <- provider.Start(ctx)
		}()

		// The initial IDs should be the ones the provider was constructed with.
		ids := provider.GetIDs()
		require.Equal(t, pairs, ids)

		// Wait for a few seconds and update the IDs.
		time.Sleep(2 * time.Second)
		updated := []oracletypes.CurrencyPair{ethusd, solusd, btcusd}
		updater.UpdateIDs(updated)

		// Wait for the provider to restart.
		time.Sleep(2 * time.Second)

		// The IDs should be updated.
		ids = provider.GetIDs()
		require.Equal(t, updated, ids)

		// The provider should only stop once the context deadline is reached.
		require.ErrorIs(t, <-errCh, context.DeadlineExceeded)
	})

	t.Run("restart on IDs update with a websocket provider", func(t *testing.T) {
		pairs := []oracletypes.CurrencyPair{btcusd}
		updater := base.NewConfigUpdater[oracletypes.CurrencyPair]()
		wsHandler := testutils.CreateWebSocketQueryHandlerWithGetResponses[oracletypes.CurrencyPair, *big.Int](
			t,
			time.Second,
			logger,
			nil,
		)

		provider, err := base.NewProvider[oracletypes.CurrencyPair, *big.Int](
			base.WithName[oracletypes.CurrencyPair, *big.Int](wsCfg.Name),
			base.WithWebSocketQueryHandler[oracletypes.CurrencyPair, *big.Int](wsHandler),
			base.WithWebSocketConfig[oracletypes.CurrencyPair, *big.Int](wsCfg),
			base.WithLogger[oracletypes.CurrencyPair, *big.Int](logger),
			base.WithIDs[oracletypes.CurrencyPair, *big.Int](pairs),
			base.WithConfigUpdater[oracletypes.CurrencyPair, *big.Int](updater),
		)
		require.NoError(t, err)

		// Start the provider and let it run until the context deadline.
		ctx, cancel := context.WithTimeout(context.Background(), 8*time.Second)
		defer cancel()

		// Buffered so the provider goroutine can always deliver its result and exit, even
		// if an assertion below fails and nothing ever reads from errCh.
		errCh := make(chan error, 1)
		go func() {
			errCh <- provider.Start(ctx)
		}()

		// The initial IDs should be the ones the provider was constructed with.
		ids := provider.GetIDs()
		require.Equal(t, pairs, ids)

		// Wait for a few seconds and update the IDs.
		time.Sleep(2 * time.Second)
		updated := []oracletypes.CurrencyPair{ethusd, solusd, btcusd}
		updater.UpdateIDs(updated)

		// Wait for the provider to restart.
		time.Sleep(2 * time.Second)

		// The IDs should be updated.
		ids = provider.GetIDs()
		require.Equal(t, updated, ids)

		// The provider should only stop once the context deadline is reached.
		require.ErrorIs(t, <-errCh, context.DeadlineExceeded)
	})
}
29 changes: 16 additions & 13 deletions providers/base/fetch.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,14 @@ import (
// fetch is the main blocker for the provider. It is responsible for fetching data from the
// data provider and updating the data.
func (p *Provider[K, V]) fetch(ctx context.Context) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

// responseCh is used to receive the response(s) from the query handler.
var responseCh chan providertypes.GetResponse[K, V]
switch {
case p.api != nil:
// The buffer size is set to the minimum of the number of IDs and the max number of queries.
// This is to ensure that the response channel does not block the query handler and that the
// query handler does not exceed the rate limit parameters of the provider.
responseCh = make(chan providertypes.GetResponse[K, V], math.Min(len(p.ids), p.apiCfg.MaxQueries))
responseCh = make(chan providertypes.GetResponse[K, V], math.Min(len(p.GetIDs()), p.apiCfg.MaxQueries))
case p.ws != nil:
// Otherwise, the buffer size is set to the max buffer size configured for the websocket.
responseCh = make(chan providertypes.GetResponse[K, V], p.wsCfg.MaxBufferSize)
Expand Down Expand Up @@ -68,7 +65,6 @@ func (p *Provider[K, V]) startAPI(ctx context.Context, responseCh chan<- provide
case <-ticker.C:
p.logger.Debug(
"attempting to fetch new data",
zap.Int("num_ids", len(p.ids)),
zap.Int("buffer_size", len(responseCh)),
)

Expand All @@ -80,7 +76,8 @@ func (p *Provider[K, V]) startAPI(ctx context.Context, responseCh chan<- provide
// attemptAPIDataUpdate tries to update data by fetching and parsing API data.
// It logs any errors encountered during the process.
func (p *Provider[K, V]) attemptAPIDataUpdate(ctx context.Context, responseCh chan<- providertypes.GetResponse[K, V]) {
if len(p.ids) == 0 {
ids := p.GetIDs()
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Unclear what p.GetIDs() is?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

IDs are effectively currency pairs, the base provider just calls them IDs to have some generality.

if len(ids) == 0 {
p.logger.Debug("no ids to fetch")
return
}
Expand All @@ -96,8 +93,8 @@ func (p *Provider[K, V]) attemptAPIDataUpdate(ctx context.Context, responseCh ch
}()

// Start the query handler. The handler must respect the context timeout.
p.logger.Debug("starting query handler")
p.api.Query(ctx, p.ids, responseCh)
p.logger.Debug("starting query handler", zap.Int("num_ids", len(ids)))
p.api.Query(ctx, ids, responseCh)
}()
}

Expand All @@ -112,12 +109,18 @@ func (p *Provider[K, V]) startMultiplexWebsocket(ctx context.Context, responseCh
wg = errgroup.Group{}
)

ids := p.GetIDs()
if len(ids) == 0 {
p.logger.Debug("no ids to fetch")
return nil
}

// create sub handlers
// if len(ids) == 30 and MaxSubscriptionsPerConnection == 45
// 30 / 45 = 0 -> need one sub handler
if maxSubsPerConn > 0 {
// case where we will split ID's across sub handlers
numSubHandlers := (len(p.ids) / maxSubsPerConn) + 1
numSubHandlers := (len(ids) / maxSubsPerConn) + 1
wg.SetLimit(numSubHandlers)

// split ids
Expand All @@ -126,16 +129,16 @@ func (p *Provider[K, V]) startMultiplexWebsocket(ctx context.Context, responseCh
start := i
end := maxSubsPerConn * (i + 1)
if i+1 == numSubHandlers {
subIDs = p.ids[start:]
subIDs = ids[start:]
} else {
subIDs = p.ids[start:end]
subIDs = ids[start:end]
}

subTasks = append(subTasks, subIDs)
}
} else {
// case where there is 1 sub handler
subTasks = append(subTasks, p.ids)
subTasks = append(subTasks, ids)
wg.SetLimit(1)
}

Expand All @@ -158,7 +161,7 @@ func (p *Provider[K, V]) startWebSocket(ctx context.Context, subIDs []K, respons
p.logger.Info("provider stopped via context")
return ctx.Err()
default:
p.logger.Debug("starting websocket query handler")
p.logger.Debug("starting websocket query handler", zap.Int("num_ids", len(subIDs)))
if err := p.ws.Start(ctx, subIDs, responseCh); err != nil {
p.logger.Error("websocket query handler returned error", zap.Error(err))
}
Expand Down
12 changes: 12 additions & 0 deletions providers/base/options.go
Original file line number Diff line number Diff line change
Expand Up @@ -110,3 +110,15 @@ func WithMetrics[K providertypes.ResponseKey, V providertypes.ResponseValue](met
p.metrics = metrics
}
}

// WithConfigUpdater returns a ProviderOption that attaches updater to the provider, allowing
// the provider's configuration to be changed asynchronously after construction. Whenever the
// updater publishes a new configuration, the provider restarts with it.
//
// The returned option panics when it is applied if updater is nil.
func WithConfigUpdater[K providertypes.ResponseKey, V providertypes.ResponseValue](updater ConfigUpdater[K]) ProviderOption[K, V] {
	return func(provider *Provider[K, V]) {
		if updater != nil {
			provider.updater = updater
			return
		}

		panic("cannot set nil config updater")
	}
}
Loading
0