jwks tweaks, vendor deps by FZambia · Pull Request #415 · centrifugal/centrifugo · GitHub
[go: up one dir, main page]
More Web Proxy on the site http://driver.im/
Skip to content

jwks tweaks, vendor deps #415

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 1 commit into from
Jan 19, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 4 additions & 9 deletions internal/jwks/cache.go
Original file line number Diff line number Diff line change
@@ -1,22 +1,17 @@
package jwks

import (
"context"
"errors"
)

// Sentinel errors reported by the JWK cache layer; compare with errors.Is.
var (
// ErrEmptyKeyID is returned when the input kid is empty.
ErrEmptyKeyID = errors.New("cache: empty kid")
// ErrCacheNotFound raises when cache value not found.
// ErrCacheNotFound returned when cache value not found.
ErrCacheNotFound = errors.New("cache: value not found")
// ErrInvalidValue is returned when a cached value fails type conversion to *JWK.
ErrInvalidValue = errors.New("cache: invalid value")
)

// Cache is the storage abstraction for JWK keys: Add stores a key, Get looks
// a key up by its key ID (kid), and Len reports how many entries are held.
type Cache interface {
Add(ctx context.Context, key *JWK) error
Get(ctx context.Context, kid string) (*JWK, error)
Len(ctx context.Context) (int, error)
Add(key *JWK) error
Get(kid string) (*JWK, error)
Len() (int, error)
}
49 changes: 17 additions & 32 deletions internal/jwks/cache_ttl.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package jwks

import (
"context"
"sync"
"time"
)
Expand Down Expand Up @@ -31,10 +30,11 @@ func (i *item) expired() bool {

// TTLCache is a TTL-based in-memory cache of JWK keys indexed by kid.
// Entries are stamped with ttl when added, and Get treats an expired entry
// the same as a missing one. The items map is guarded by mu.
type TTLCache struct {
mu sync.RWMutex
ttl time.Duration
stop chan struct{}
items map[string]*item
mu sync.RWMutex
ttl time.Duration
stop chan struct{}
// stopOnce ensures stop is closed at most once, making Stop idempotent.
stopOnce sync.Once
items map[string]*item
}

// NewTTLCache returns a new instance of ttl cache.
Expand Down Expand Up @@ -78,7 +78,7 @@ func (tc *TTLCache) run() {
}

// Add item into cache.
func (tc *TTLCache) Add(_ context.Context, key *JWK) error {
func (tc *TTLCache) Add(key *JWK) error {
tc.mu.Lock()
item := &item{data: key}
item.touch(tc.ttl)
Expand All @@ -88,7 +88,7 @@ func (tc *TTLCache) Add(_ context.Context, key *JWK) error {
}

// Get item by key.
func (tc *TTLCache) Get(_ context.Context, kid string) (*JWK, error) {
func (tc *TTLCache) Get(kid string) (*JWK, error) {
tc.mu.RLock()
item, ok := tc.items[kid]
if !ok || item.expired() {
Expand All @@ -100,40 +100,25 @@ func (tc *TTLCache) Get(_ context.Context, kid string) (*JWK, error) {
return item.data, nil
}

// Remove item by key.
func (tc *TTLCache) Remove(_ context.Context, kid string) error {
// Stop signals shutdown by closing tc.stop. It is safe to call more than
// once: sync.Once guarantees the channel is closed only once, so repeated
// calls neither panic nor block. It always returns nil.
// NOTE(review): presumably run() selects on tc.stop to end the cleanup
// loop; its body is not visible here, so confirm.
func (tc *TTLCache) Stop() error {
tc.stopOnce.Do(func() {
close(tc.stop)
})
return nil
}

// remove deletes the entry for kid under the write lock. Deleting a kid
// that is not present is a no-op. It always returns nil.
func (tc *TTLCache) remove(kid string) error {
tc.mu.Lock()
delete(tc.items, kid)
tc.mu.Unlock()
return nil
}

// Contains reports whether an entry for kid is stored. Unlike Get it does
// not check expiry, so an expired entry that has not been deleted yet is
// still reported as present. The error is always nil.
func (tc *TTLCache) Contains(_ context.Context, kid string) (bool, error) {
tc.mu.RLock()
_, ok := tc.items[kid]
tc.mu.RUnlock()
return ok, nil
}

// Len returns the number of entries currently stored, taken under the read
// lock. Entries that have expired but have not been deleted yet are still
// counted. The error is always nil.
func (tc *TTLCache) Len(_ context.Context) (int, error) {
func (tc *TTLCache) Len() (int, error) {
tc.mu.RLock()
n := len(tc.items)
tc.mu.RUnlock()
return n, nil
}

// Purge deletes all items by replacing the map with a fresh empty one under
// the write lock. It always returns nil.
func (tc *TTLCache) Purge(_ context.Context) error {
tc.mu.Lock()
tc.items = map[string]*item{}
tc.mu.Unlock()
return nil
}

// Stop signals the cleanup process to stop by sending a value on tc.stop.
// NOTE(review): the send blocks until a receiver takes it. If tc.stop is
// unbuffered (its construction is not shown here) and nothing is receiving
// anymore, for example on a second Stop call, this call blocks forever.
// The close+sync.Once variant avoids that.
func (tc *TTLCache) Stop(_ context.Context) error {
tc.stop <- struct{}{}
return nil
}
171 changes: 22 additions & 149 deletions internal/jwks/cache_ttl_test.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package jwks

import (
"context"
"fmt"
"testing"
"time"
Expand Down Expand Up @@ -29,13 +28,11 @@ func TestTTLCacheAdd(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(tc.TTL)
require.NotNil(t, cache)

for i := 0; i < tc.Ops; i++ {
require.NoError(t, cache.Add(ctx, &JWK{
require.NoError(t, cache.Add(&JWK{
Kid: fmt.Sprintf("key-%d", i+1),
Kty: "RSA",
Alg: "RS256",
Expand Down Expand Up @@ -78,13 +75,11 @@ func TestTTLCacheGet(t *testing.T) {

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(5 * time.Minute)
require.NotNil(t, cache)
require.NoError(t, cache.Add(ctx, tc.Key))
require.NoError(t, cache.Add(tc.Key))

key, err := cache.Get(ctx, tc.Kid)
key, err := cache.Get(tc.Kid)
if tc.Error != nil {
require.Error(t, err)
require.ErrorIs(t, err, tc.Error)
Expand All @@ -98,178 +93,56 @@ func TestTTLCacheGet(t *testing.T) {

// TestTTLCacheRemove adds a number of distinct keys, deletes the first keys
// in insertion order, and checks the remaining Len. The RemoveUntilEmpty
// case deletes more keys than were added, so deleting a missing kid must
// succeed as a no-op.
func TestTTLCacheRemove(t *testing.T) {
testCases := []struct {
Name string
Adds int
Dels int
Len int
Name string
NumAdd int
NumDelete int
Len int
}{
{
Name: "OK",
Adds: 75,
Dels: 50,
Len: 25,
Name: "OK",
NumAdd: 75,
NumDelete: 50,
Len: 25,
},
{
Name: "RemoveUntilEmpty",
Adds: 75,
Dels: 100,
Len: 0,
Name: "RemoveUntilEmpty",
NumAdd: 75,
NumDelete: 100,
Len: 0,
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(5 * time.Minute)
require.NotNil(t, cache)

for i := 0; i < tc.Adds; i++ {
require.NoError(t, cache.Add(ctx, &JWK{
for i := 0; i < tc.NumAdd; i++ {
require.NoError(t, cache.Add(&JWK{
Kid: fmt.Sprintf("key-%d", i+1),
Kty: "RSA",
Alg: "RS256",
Use: "sig",
}))
}

for i := 0; i < tc.Dels; i++ {
for i := 0; i < tc.NumDelete; i++ {
kid := fmt.Sprintf("key-%d", i+1)
require.NoError(t, cache.Remove(ctx, kid))
require.NoError(t, cache.remove(kid))
}

n, err := cache.Len(ctx)
n, err := cache.Len()
require.NoError(t, err)
require.Equal(t, tc.Len, n)
})
}
}

// TestTTLCacheContains stores one key and checks that Contains reports true
// for its kid and false for a different kid.
func TestTTLCacheContains(t *testing.T) {
testCases := []struct {
Name string
Key *JWK
Kid string
Found bool
}{
{
Name: "OK",
Key: &JWK{
Kid: "202101",
Kty: "RSA",
Alg: "RS256",
Use: "sig",
},
Kid: "202101",
Found: true,
},
{
Name: "NotFound",
Key: &JWK{
Kid: "202101",
Kty: "RSA",
Alg: "RS256",
Use: "sig",
},
Kid: "202102",
Found: false,
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(5 * time.Minute)
require.NotNil(t, cache)
require.NoError(t, cache.Add(ctx, tc.Key))

found, err := cache.Contains(ctx, tc.Kid)
require.NoError(t, err)

require.Equal(t, tc.Found, found)
})
}
}

// TestTTLCacheLen adds Ops distinct keys and checks that Len reports the
// expected count.
func TestTTLCacheLen(t *testing.T) {
testCases := []struct {
Name string
Ops int
Len int
}{
{
Name: "OK",
Ops: 50,
Len: 50,
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(5 * time.Second)
require.NotNil(t, cache)

for i := 0; i < tc.Ops; i++ {
require.NoError(t, cache.Add(ctx, &JWK{
Kid: fmt.Sprintf("key-%d", i+1),
Kty: "RSA",
Alg: "RS256",
Use: "sig",
}))
}

n, err := cache.Len(ctx)
require.NoError(t, err)
require.Equal(t, tc.Len, n)
})
}
}

// TestTTLCachePurge fills the cache with Ops keys, calls Purge, and checks
// that Len drops to zero.
func TestTTLCachePurge(t *testing.T) {
testCases := []struct {
Name string
Ops int
}{
{
Name: "OK",
Ops: 50,
},
}

for _, tc := range testCases {
t.Run(tc.Name, func(t *testing.T) {
ctx := context.Background()

cache := NewTTLCache(5 * time.Second)
require.NotNil(t, cache)

for i := 0; i < tc.Ops; i++ {
require.NoError(t, cache.Add(ctx, &JWK{
Kid: fmt.Sprintf("key-%d", i+1),
Kty: "RSA",
Alg: "RS256",
Use: "sig",
}))
}

require.NoError(t, cache.Purge(ctx))

n, err := cache.Len(ctx)
require.NoError(t, err)
require.Equal(t, 0, n)
})
}
}

func TestTTLCacheCleanup(t *testing.T) {
ctx := context.Background()
cache := NewTTLCache(1 * time.Millisecond)

for i := 0; i < 10; i++ {
require.NoError(t, cache.Add(ctx, &JWK{
require.NoError(t, cache.Add(&JWK{
Kid: fmt.Sprintf("key-%d", i+1),
Kty: "RSA",
Alg: "RS256",
Expand All @@ -279,7 +152,7 @@ func TestTTLCacheCleanup(t *testing.T) {

time.Sleep(2 * time.Second)

n, err := cache.Len(ctx)
n, err := cache.Len()
require.NoError(t, err)
require.Equal(t, 0, n)
}
Loading
0