Add RateLimit middleware using TokenBucket algorithm by LaPetiteSouris · Pull Request #557 · flyteorg/flyteadmin · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content
This repository was archived by the owner on Oct 9, 2023. It is now read-only.

Add RateLimit middleware using TokenBucket algorithm #557

Open
wants to merge 8 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 10 additions & 1 deletion pkg/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,14 +80,23 @@ type ServerSecurityOptions struct {
// These are the Access-Control-Request-Headers that the server will respond to.
// By default, the server will allow Accept, Accept-Language, Content-Language, and Content-Type.
// Use this setting to add any additional headers which are needed
AllowedHeaders []string `json:"allowedHeaders"`
AllowedHeaders []string `json:"allowedHeaders"`
RateLimit RateLimitOptions `json:"rateLimit"`
}

type SslOptions struct {
CertificateFile string `json:"certificateFile"`
KeyFile string `json:"keyFile"`
}

// RateLimitOptions is a type to hold rate limit configuration options.
// It is the type of ServerSecurityOptions.RateLimit (JSON key "rateLimit").
type RateLimitOptions struct {
// Enabled switches rate limiting on or off.
Enabled bool `json:"enabled" pflag:",Controls whether rate limiting is enabled. If enabled, the rate limit is applied to all requests using the TokenBucket algorithm."`
// RequestsPerSecond is the number of requests allowed per second.
RequestsPerSecond int `json:"requestsPerSecond" pflag:",The number of requests allowed per second."`
// BurstSize is the number of requests allowed to burst; per the flag help,
// 0 means the bucket cannot hold any tokens.
BurstSize int `json:"burstSize" pflag:",The number of requests allowed to burst. 0 implies the TokenBucket algorithm cannot hold any tokens."`
// CleanupInterval is how often the rate limiter cleans up entries that
// have not been used for a while.
CleanupInterval config.Duration `json:"cleanupInterval" pflag:",The interval at which the rate limiter cleans up entries that have not been used for a certain period of time."`
}

var defaultServerConfig = &ServerConfig{
HTTPPort: 8088,
Security: ServerSecurityOptions{
Expand Down
4 changes: 4 additions & 0 deletions pkg/config/serverconfig_flags.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

56 changes: 56 additions & 0 deletions pkg/config/serverconfig_flags_test.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

9 changes: 9 additions & 0 deletions pkg/server/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,17 @@ func newGRPCServer(ctx context.Context, pluginRegistry *plugins.Registry, cfg *c
auth.AuthenticationLoggingInterceptor,
middlewareInterceptors,
)
if cfg.Security.RateLimit.Enabled {
rateLimiter := plugins.NewRateLimiter(cfg.Security.RateLimit.RequestsPerSecond, cfg.Security.RateLimit.BurstSize, cfg.Security.RateLimit.CleanupInterval.Duration)
rateLimitInterceptors := plugins.RateLimiteInterceptor(*rateLimiter)
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(chainedUnaryInterceptors, rateLimitInterceptors)
}
} else {
logger.Infof(ctx, "Creating gRPC server without authentication")
chainedUnaryInterceptors = grpcmiddleware.ChainUnaryServer(grpcprometheus.UnaryServerInterceptor)
if cfg.Security.RateLimit.Enabled {
logger.Warningf(ctx, "Rate limit is enabled but auth is not")
}
}

serverOpts := []grpc.ServerOption{
Expand Down Expand Up @@ -257,6 +265,7 @@ func serveGatewayInsecure(ctx context.Context, pluginRegistry *plugins.Registry,
}

oauth2ResourceServer = oauth2Provider

} else {
oauth2ResourceServer, err = authzserver.NewOAuth2ResourceServer(ctx, authCfg.AppAuth.ExternalAuthServer, authCfg.UserAuth.OpenID.BaseURL)
if err != nil {
Expand Down
116 changes: 116 additions & 0 deletions plugins/rate_limit.go
8000
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
package plugins

import (
"context"
"errors"
"fmt"
"sync"
"time"

auth "github.com/flyteorg/flyteadmin/auth"
"golang.org/x/time/rate"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
)

// RateLimitExceeded is the error type LimiterStore.Allow returns when a
// user's limiter rejects a request. It is a named type whose underlying type
// is the built-in error interface, so any error value converts to it.
type RateLimitExceeded error

// accessRecords stores the rate limiter and the last access time for one key
// of LimiterStore.accessPerUser.
type accessRecords struct {
// limiter is consulted on every Allow call for this key.
limiter *rate.Limiter
// lastAccess is refreshed on every Allow call and compared against the
// store's cleanupInterval by clean.
lastAccess time.Time
// mutex is held by Allow and clean while they read or write this record.
mutex *sync.Mutex
}

// LimiterStore stores the access records for each user
type LimiterStore struct {
// accessPerUser is a synchronized map of userID to accessRecords
accessPerUser *sync.Map
// requestPerSec is converted to a rate.Limit when a new per-user limiter
// is created.
requestPerSec int
// burstSize is the burst argument given to each new per-user limiter.
burstSize int
// cleanupInterval is both the pause between background clean runs and the
// idle time after which clean removes a record.
cleanupInterval time.Duration
}

// Allow takes a userID and returns an error if the user has exceeded the rate limit
func (l *LimiterStore) Allow(userID string) error {
accessRecord, _ := l.accessPerUser.LoadOrStore(userID, &accessRecords{
limiter: rate.NewLimiter(rate.Limit(l.requestPerSec), l.burstSize),
lastAccess: time.Now(),
mutex: &sync.Mutex{},
})
accessRecord.(*accessRecords).mutex.Lock()
defer accessRecord.(*accessRecords).mutex.Unlock()

accessRecord.(*accessRecords).lastAccess = time.Now()
l.accessPerUser.Store(userID, accessRecord)

if !accessRecord.(*accessRecords).limiter.Allow() {
return RateLimitExceeded(fmt.Errorf("rate limit exceeded"))
}
Comment on lines +48 to +50
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this happen before we modify the map?


return nil
}

// clean walks every stored record and deletes those whose lastAccess is older
// than cleanupInterval. Each record's mutex is held while it is inspected.
func (l *LimiterStore) clean() {
	evictIfIdle := func(userID, entry interface{}) bool {
		record := entry.(*accessRecords)
		record.mutex.Lock()
		defer record.mutex.Unlock()

		idleFor := time.Since(record.lastAccess)
		if idleFor > l.cleanupInterval {
			l.accessPerUser.Delete(userID)
		}
		// Always continue the iteration; every record must be visited.
		return true
	}
	l.accessPerUser.Range(evictIfIdle)
}

// newRateLimitStore builds an empty LimiterStore with the given per-user rate
// and burst, and starts a background goroutine that calls clean once every
// cleanupInterval.
//
// A non-positive cleanupInterval disables the background cleanup. time.Sleep
// returns immediately for such durations, so looping on it would spin a CPU
// core and evict records as fast as they are created.
//
// The cleanup goroutine has no stop signal and runs for the remaining
// lifetime of the process.
func newRateLimitStore(requestPerSec int, burstSize int, cleanupInterval time.Duration) *LimiterStore {
	l := &LimiterStore{
		accessPerUser:   &sync.Map{},
		requestPerSec:   requestPerSec,
		burstSize:       burstSize,
		cleanupInterval: cleanupInterval,
	}

	if cleanupInterval <= 0 {
		// Nothing sensible to schedule; records are kept until clean is
		// invoked explicitly.
		return l
	}

	go func() {
		for {
			time.Sleep(l.cleanupInterval)
			l.clean()
		}
	}()

	return l
}

// RateLimiter is a struct that implements the RateLimiter interface from grpc middleware
// through its Limit(ctx) method. It resolves the caller's user ID from the
// context and delegates the decision to the wrapped LimiterStore.
type RateLimiter struct {
// limiter holds the per-user records consulted by Limit.
limiter *LimiterStore
}

// Limit looks up the identity attached to ctx and asks the store whether that
// user may proceed. It returns an error when ctx carries no identity or when
// the store rejects the user; nil means the request is admitted.
func (r *RateLimiter) Limit(ctx context.Context) error {
	identity := auth.IdentityContextFromContext(ctx)
	if identity.IsEmpty() {
		return errors.New("no identity context found")
	}

	// Allow already yields nil on success, so its result is returned as is.
	return r.limiter.Allow(identity.UserID())
}

// NewRateLimiter returns a RateLimiter backed by a freshly created
// LimiterStore configured with the given rate, burst and cleanup interval.
func NewRateLimiter(requestPerSec int, burstSize int, cleanupInterval time.Duration) *RateLimiter {
	return &RateLimiter{
		limiter: newRateLimitStore(requestPerSec, burstSize, cleanupInterval),
	}
}

// RateLimiteInterceptor wraps limiter in a gRPC unary server interceptor.
// The handler is invoked only when limiter.Limit returns nil. Any error from
// Limit, including a missing identity, is reported to the client as
// codes.ResourceExhausted with the message "rate limit exceeded".
func RateLimiteInterceptor(limiter RateLimiter) grpc.UnaryServerInterceptor {
	return func(ctx context.Context, req interface{}, info *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (interface{}, error) {
		if limiter.Limit(ctx) == nil {
			return handler(ctx, req)
		}
		return nil, status.Errorf(codes.ResourceExhausted, "rate limit exceeded")
	}
}
126 changes: 126 additions & 0 deletions plugins/rate_limit_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
package plugins

import (
"context"
"testing"
"time"

auth "github.com/flyteorg/flyteadmin/auth"
"github.com/stretchr/testify/assert"
)

// TestNewRateLimiter checks that the store constructor hands back a non-nil store.
func TestNewRateLimiter(t *testing.T) {
	store := newRateLimitStore(1, 1, time.Second)

	assert.NotNil(t, store)
}

// TestLimiterAllow checks that, with one request per second and a burst of
// one, a second immediate request is rejected and a request one second later
// is admitted again.
func TestLimiterAllow(t *testing.T) {
	const user = "hello"
	store := newRateLimitStore(1, 1, 10*time.Second)

	firstErr := store.Allow(user)
	assert.NoError(t, firstErr)

	secondErr := store.Allow(user)
	assert.Error(t, secondErr)

	// Wait for the bucket to refill.
	time.Sleep(time.Second)

	thirdErr := store.Allow(user)
	assert.NoError(t, thirdErr)
}

// TestLimiterAllowBurst checks that a burst of two admits two back-to-back
// requests from one user, rejects the third, and does not affect another user.
func TestLimiterAllowBurst(t *testing.T) {
	store := newRateLimitStore(1, 2, time.Second)

	// The first two requests fit in the burst.
	for i := 0; i < 2; i++ {
		assert.NoError(t, store.Allow("hello"))
	}
	// The third exceeds it.
	assert.Error(t, store.Allow("hello"))

	// A different user has an independent bucket.
	assert.NoError(t, store.Allow("world"))
}

// TestLimiterClean checks that clean removes a record that has been idle for
// longer than the cleanup interval.
//
// Asserting only that a later Allow succeeds would not prove anything: at one
// request per second the bucket refills during the sleep, so Allow succeeds
// whether or not the record was removed. The test therefore inspects the map
// directly.
func TestLimiterClean(t *testing.T) {
	const user = "hello"
	cleanupInterval := time.Second
	rlStore := newRateLimitStore(1, 1, cleanupInterval)

	assert.NoError(t, rlStore.Allow(user))
	assert.Error(t, rlStore.Allow(user))

	_, found := rlStore.accessPerUser.Load(user)
	assert.True(t, found, "record should exist before cleanup")

	// Sleep strictly longer than the interval so the record is unambiguously
	// stale when clean compares its idle time with cleanupInterval.
	time.Sleep(cleanupInterval + 100*time.Millisecond)
	rlStore.clean()

	_, found = rlStore.accessPerUser.Load(user)
	assert.False(t, found, "stale record should have been removed by clean")

	assert.NoError(t, rlStore.Allow(user))
}

// TestLimiterAllowOnMultipleRequests checks that users "a", "b" and "c" are
// limited independently, both before and after their buckets refill.
func TestLimiterAllowOnMultipleRequests(t *testing.T) {
	store := newRateLimitStore(1, 1, time.Second)

	// Each user gets one request through.
	for _, user := range []string{"a", "b", "c"} {
		assert.NoError(t, store.Allow(user))
	}
	// Immediate repeats are rejected.
	for _, user := range []string{"a", "b"} {
		assert.Error(t, store.Allow(user))
	}

	time.Sleep(time.Second)

	// After the refill, each user again gets exactly one request.
	for _, user := range []string{"a", "b"} {
		assert.NoError(t, store.Allow(user))
		assert.Error(t, store.Allow(user))
	}
	assert.NoError(t, store.Allow("c"))
}

// TestRateLimiterLimitPass checks that the first request of an identified
// user is admitted by RateLimiter.Limit.
func TestRateLimiterLimitPass(t *testing.T) {
	limiter := NewRateLimiter(1, 1, time.Second)
	assert.NotNil(t, limiter)

	identity, identityErr := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil)
	assert.NoError(t, identityErr)

	ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identity)
	assert.NoError(t, limiter.Limit(ctx))
}

// TestRateLimiterLimitStop checks that a second immediate request from the
// same identified user is rejected by RateLimiter.Limit.
func TestRateLimiterLimitStop(t *testing.T) {
	limiter := NewRateLimiter(1, 1, time.Second)
	assert.NotNil(t, limiter)

	identity, identityErr := auth.NewIdentityContext("audience", "user1", "app1", time.Now(), nil, nil, nil)
	assert.NoError(t, identityErr)
	ctx := context.WithValue(context.TODO(), auth.ContextKeyIdentityContext, identity)

	// The first call consumes the only token.
	assert.NoError(t, limiter.Limit(ctx))
	// The second call finds the bucket empty.
	assert.Error(t, limiter.Limit(ctx))
}

// TestRateLimiterLimitWithoutUserIdentity checks that Limit returns an error
// when the context carries no identity.
func TestRateLimiterLimitWithoutUserIdentity(t *testing.T) {
	limiter := NewRateLimiter(1, 1, time.Second)
	assert.NotNil(t, limiter)

	limitErr := limiter.Limit(context.TODO())
	assert.Error(t, limitErr)
}

// TestRateLimiterUpdateLastAccessTime checks that every Allow call moves the
// record's lastAccess forward, including a call that is rejected.
func TestRateLimiterUpdateLastAccessTime(t *testing.T) {
	const user = "hello"
	store := newRateLimitStore(2, 2, time.Second)

	// readLastAccess snapshots the record's timestamp under its mutex.
	readLastAccess := func() time.Time {
		entry, _ := store.accessPerUser.Load(user)
		record := entry.(*accessRecords)
		record.mutex.Lock()
		defer record.mutex.Unlock()
		return record.lastAccess
	}

	assert.NoError(t, store.Allow(user))
	first := readLastAccess()

	assert.NoError(t, store.Allow(user))
	second := readLastAccess()
	assert.True(t, second.After(first))

	// Verify that the last access time is updated even when user is rate limited
	assert.Error(t, store.Allow(user))
	third := readLastAccess()
	assert.True(t, third.After(second))
}
0