diff --git a/drivers/store/memory/cache.go b/drivers/store/memory/cache.go index e6e7bfc..361471f 100644 --- a/drivers/store/memory/cache.go +++ b/drivers/store/memory/cache.go @@ -147,3 +147,13 @@ func (cache *Cache) Clean() { } cache.mutex.Unlock() } + +// Reset deletes the counter stored under key and returns a zero value with an expiration of now plus duration. +func (cache *Cache) Reset(key string, duration time.Duration) (int64, time.Time) { + cache.mutex.Lock() + delete(cache.counters, key) + cache.mutex.Unlock() + + expiration := time.Now().Add(duration).UnixNano() + return 0, time.Unix(0, expiration) +} diff --git a/drivers/store/memory/cache_test.go b/drivers/store/memory/cache_test.go index 9e7c79f..5973a80 100644 --- a/drivers/store/memory/cache_test.go +++ b/drivers/store/memory/cache_test.go @@ -94,3 +94,38 @@ func TestCacheGet(t *testing.T) { is.InEpsilon(deleted, expire.UnixNano(), epsilon) } + +func TestCacheReset(t *testing.T) { + is := require.New(t) + + key := "foobar" + cache := memory.NewCache(10 * time.Nanosecond) + duration := 50 * time.Millisecond + deleted := time.Now().Add(duration).UnixNano() + epsilon := 0.001 + + x, expire := cache.Get(key, duration) + is.Equal(int64(0), x) + is.InEpsilon(deleted, expire.UnixNano(), epsilon) + + x, expire = cache.Increment(key, 1, duration) + is.Equal(int64(1), x) + is.InEpsilon(deleted, expire.UnixNano(), epsilon) + + x, expire = cache.Increment(key, 1, duration) + is.Equal(int64(2), x) + is.InEpsilon(deleted, expire.UnixNano(), epsilon) + + x, expire = cache.Reset(key, duration) + is.Equal(int64(0), x) + is.InEpsilon(deleted, expire.UnixNano(), epsilon) + + x, expire = cache.Increment(key, 1, duration) + is.Equal(int64(1), x) + is.InEpsilon(deleted, expire.UnixNano(), epsilon) + + x, expire = cache.Increment(key, 1, duration) + is.Equal(int64(2), x) + is.InEpsilon(deleted, expire.UnixNano(), epsilon) + +} diff --git a/drivers/store/memory/store.go b/drivers/store/memory/store.go index 5b8f50b..db36ce1 100644 --- a/drivers/store/memory/store.go +++ 
b/drivers/store/memory/store.go @@ -54,3 +54,14 @@ func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (li lctx := common.GetContextFromState(now, rate, expiration, count) return lctx, nil } + +// Reset resets the counter for given identifier and returns the resulting limit. +func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + key = fmt.Sprintf("%s:%s", store.Prefix, key) + now := time.Now() + + count, expiration := store.cache.Reset(key, rate.Period) + + lctx := common.GetContextFromState(now, rate, expiration, count) + return lctx, nil +} diff --git a/drivers/store/redis/store.go b/drivers/store/redis/store.go index 7d7160f..31915e0 100644 --- a/drivers/store/redis/store.go +++ b/drivers/store/redis/store.go @@ -134,6 +134,35 @@ func (store *Store) Peek(ctx context.Context, key string, rate limiter.Rate) (li return lctx, nil } +// Reset resets the counter for given identifier and returns the limit computed for a zero count. +func (store *Store) Reset(ctx context.Context, key string, rate limiter.Rate) (limiter.Context, error) { + key = fmt.Sprintf("%s:%s", store.Prefix, key) + now := time.Now() + + lctx := limiter.Context{} + onWatch := func(rtx *libredis.Tx) error { + + err := store.doResetValue(rtx, key) + if err != nil { + return err + } + + count := int64(0) + expiration := now.Add(rate.Period) + + lctx = common.GetContextFromState(now, rate, expiration, count) + return nil + } + + err := store.client.Watch(onWatch, key) + if err != nil { + err = errors.Wrapf(err, "limiter: cannot reset value for %s", key) + return limiter.Context{}, err + } + + return lctx, nil +} + // doPeekValue will execute peekValue with a retry mecanism (optimistic locking) until store.MaxRetry is reached. 
func (store *Store) doPeekValue(rtx *libredis.Tx, key string) (int64, time.Duration, error) { for i := 0; i < store.MaxRetry; i++ { @@ -251,6 +280,30 @@ func updateValue(rtx *libredis.Tx, key string, expiration time.Duration) (int64, } +// doResetValue will execute resetValue with a retry mechanism (optimistic locking) until store.MaxRetry is reached. +func (store *Store) doResetValue(rtx *libredis.Tx, key string) error { + for i := 0; i < store.MaxRetry; i++ { + err := resetValue(rtx, key) + if err == nil { + return nil + } + } + return errors.New("retry limit exceeded") +} + +// resetValue deletes the counter identified by given key; a key that does not exist counts as already reset. +func resetValue(rtx *libredis.Tx, key string) error { + deletion := rtx.Del(key) + + _, err := deletion.Result() + if err != nil { + return err + } + + return nil + +} + // ping checks if redis is alive. func (store *Store) ping() (bool, error) { cmd := store.client.Ping() diff --git a/drivers/store/tests/tests.go b/drivers/store/tests/tests.go index 4686395..a8c7c9b 100644 --- a/drivers/store/tests/tests.go +++ b/drivers/store/tests/tests.go @@ -2,7 +2,6 @@ package tests import ( "context" - "math" "sync" "testing" "time" @@ -22,39 +21,72 @@ func TestStoreSequentialAccess(t *testing.T, store limiter.Store) { Period: time.Minute, }) - for i := 1; i <= 6; i++ { + // Check counter increment. 
+ { + for i := 1; i <= 6; i++ { - if i <= 3 { + if i <= 3 { - lctx, err := limiter.Peek(ctx, "foo") + lctx, err := limiter.Peek(ctx, "foo") + is.NoError(err) + is.NotZero(lctx) + is.Equal(int64(3-(i-1)), lctx.Remaining) + is.False(lctx.Reached) + + } + + lctx, err := limiter.Get(ctx, "foo") is.NoError(err) is.NotZero(lctx) - is.Equal(int64(3-(i-1)), lctx.Remaining) + if i <= 3 { + + is.Equal(int64(3), lctx.Limit) + is.Equal(int64(3-i), lctx.Remaining) + is.True((lctx.Reset - time.Now().Unix()) <= 60) + is.False(lctx.Reached) + + lctx, err = limiter.Peek(ctx, "foo") + is.NoError(err) + is.Equal(int64(3-i), lctx.Remaining) + is.False(lctx.Reached) + + } else { + + is.Equal(int64(3), lctx.Limit) + is.Equal(int64(0), lctx.Remaining) + is.True((lctx.Reset - time.Now().Unix()) <= 60) + is.True(lctx.Reached) + + } } + } - lctx, err := limiter.Get(ctx, "foo") + // Check counter reset. + { + lctx, err := limiter.Peek(ctx, "foo") is.NoError(err) is.NotZero(lctx) - if i <= 3 { - - is.Equal(int64(3), lctx.Limit) - is.Equal(int64(3-i), lctx.Remaining) - is.True(math.Ceil(time.Since(time.Unix(lctx.Reset, 0)).Seconds()) <= 60) + is.Equal(int64(3), lctx.Limit) + is.Equal(int64(0), lctx.Remaining) + is.True((lctx.Reset - time.Now().Unix()) <= 60) + is.True(lctx.Reached) - lctx, err = limiter.Peek(ctx, "foo") - is.NoError(err) - is.Equal(int64(3-i), lctx.Remaining) - - } else { + lctx, err = limiter.Reset(ctx, "foo") + is.NoError(err) + is.NotZero(lctx) - is.Equal(int64(3), lctx.Limit) - is.True(lctx.Remaining == 0) - is.True(math.Ceil(time.Since(time.Unix(lctx.Reset, 0)).Seconds()) <= 60) + lctx, err = limiter.Peek(ctx, "foo") + is.NoError(err) + is.NotZero(lctx) - } + is.Equal(int64(3), lctx.Limit) + is.Equal(int64(3), lctx.Remaining) + is.True((lctx.Reset - time.Now().Unix()) <= 60) + is.False(lctx.Reached) } + } // TestStoreConcurrentAccess verify that store works as expected with a concurrent access. 
diff --git a/limiter.go b/limiter.go index 5d372ff..753ed87 100644 --- a/limiter.go +++ b/limiter.go @@ -53,3 +53,8 @@ func (limiter *Limiter) Get(ctx context.Context, key string) (Context, error) { func (limiter *Limiter) Peek(ctx context.Context, key string) (Context, error) { return limiter.Store.Peek(ctx, key, limiter.Rate) } + +// Reset resets the counter for given identifier by delegating to the store and returns the resulting limit. +func (limiter *Limiter) Reset(ctx context.Context, key string) (Context, error) { + return limiter.Store.Reset(ctx, key, limiter.Rate) +} diff --git a/store.go b/store.go index 890e84d..a9799d7 100644 --- a/store.go +++ b/store.go @@ -11,6 +11,8 @@ type Store interface { Get(ctx context.Context, key string, rate Rate) (Context, error) // Peek returns the limit for given identifier, without modification on current values. Peek(ctx context.Context, key string, rate Rate) (Context, error) + // Reset resets the counter for given identifier and returns the resulting limit. + Reset(ctx context.Context, key string, rate Rate) (Context, error) } // StoreOptions are options for store.