diff --git a/README.md b/README.md index 40a1d94..d94e453 100644 --- a/README.md +++ b/README.md @@ -1,11 +1,10 @@ -mongostore -========== +# mongostore -[Gorilla's Session](http://www.gorillatoolkit.org/pkg/sessions) store implementation with MongoDB +[Gorilla's Session](http://www.gorillatoolkit.org/pkg/sessions) store implementation for MongoDB, built on the official MongoDB Go driver ## Requirements -Depends on the [mgo](https://labix.org/v2/mgo) library. +Depends on the official [MongoDB Go driver](https://docs.mongodb.com/ecosystem/drivers/go) (`go.mongodb.org/mongo-driver`). ## Installation @@ -16,17 +15,20 @@ Depends on the [mgo](https://labix.org/v2/mgo) library. Available on [godoc.org](http://www.godoc.org/github.com/kidstuff/mongostore). ### Example + ```go func foo(rw http.ResponseWriter, req *http.Request) { // Fetch new store. - dbsess, err := mgo.Dial("localhost") + ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) + defer cancel() + client, err := mongo.Connect(ctx, options.Client().ApplyURI("mongodb://localhost:27017")) if err != nil { panic(err) } - defer dbsess.Close() + defer client.Disconnect(ctx) - store := mongostore.NewMongoStore(dbsess.DB("test").C("test_session"), 3600, true, + store := mongostore.NewMongoStore(client.Database("test").Collection("test_session"), 3600, true, []byte("secret-key")) // Get a session. session, err := store.Get(req, "session-key") diff --git a/mgostore_test.go b/mgostore_test.go index 4fdf41d..43289d1 100644 --- a/mgostore_test.go +++ b/mgostore_test.go @@ -7,13 +7,15 @@ package mongostore import ( + "context" "encoding/gob" "net/http" "net/http/httptest" "testing" - "github.com/globalsign/mgo" "github.com/gorilla/sessions" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) type FlashMessage struct { @@ -36,13 +38,13 @@ func TestMongoStore(t *testing.T) { // license that can be found in the LICENSE file.
// Round 1 ---------------------------------------------------------------- - dbsess, err := mgo.Dial("localhost") + client, err := mongo.Connect(context.Background(), options.Client().ApplyURI("mongodb://localhost:27017")) if err != nil { panic(err) } - defer dbsess.Close() + defer client.Disconnect(context.Background()) - store := NewMongoStore(dbsess.DB("test").C("test_session"), 3600, true, + store := NewMongoStore(client.Database("test").Collection("test_session"), 3600, true, []byte("secret-key")) req, _ = http.NewRequest("GET", "http://localhost:8080/", nil) diff --git a/mongostore.go b/mongostore.go index 7087bff..3ba6a88 100644 --- a/mongostore.go +++ b/mongostore.go @@ -5,23 +5,28 @@ package mongostore import ( + "context" "errors" + "fmt" "net/http" "time" - "github.com/globalsign/mgo" - "github.com/globalsign/mgo/bson" "github.com/gorilla/securecookie" "github.com/gorilla/sessions" + "go.mongodb.org/mongo-driver/bson" + "go.mongodb.org/mongo-driver/bson/primitive" + "go.mongodb.org/mongo-driver/mongo" + "go.mongodb.org/mongo-driver/mongo/options" ) +// ErrInvalidID is returned when a session ID is not a valid hex-encoded ObjectID. var ( - ErrInvalidId = errors.New("mgostore: invalid session id") + ErrInvalidID = errors.New("store: invalid session id") ) // Session object store in MongoDB type Session struct { - Id bson.ObjectId `bson:"_id,omitempty"` + ID *primitive.ObjectID `bson:"_id,omitempty"` Data string Modified time.Time } @@ -31,32 +36,35 @@ type MongoStore struct { Codecs []securecookie.Codec Options *sessions.Options Token TokenGetSeter - coll *mgo.Collection + coll *mongo.Collection } // NewMongoStore returns a new MongoStore. // Set ensureTTL to true let the database auto-remove expired object by maxAge.
-func NewMongoStore(c *mgo.Collection, maxAge int, ensureTTL bool, +func NewMongoStore(c *mongo.Collection, maxAge int, ensureTTL bool, keyPairs ...[]byte) *MongoStore { store := &MongoStore{ Codecs: securecookie.CodecsFromPairs(keyPairs...), Options: &sessions.Options{ Path: "/", MaxAge: maxAge, }, Token: &CookieToken{}, coll: c, } store.MaxAge(maxAge) if ensureTTL { - c.EnsureIndex(mgo.Index{ - Key: []string{"modified"}, - Background: true, - Sparse: true, - ExpireAfter: time.Duration(maxAge) * time.Second, - }) + opts := options.Index() + opts.SetBackground(true) + opts.SetSparse(true) + opts.SetExpireAfterSeconds(int32(maxAge)) + idx := mongo.IndexModel{Keys: bson.D{{Key: "modified", Value: 1}}, Options: opts} // index keys must be an ordered document; a []string is rejected by the driver + _, err := c.Indexes().CreateOne(context.Background(), idx) + if err != nil { + fmt.Println("Error occurred while creating index", err) + } } return store @@ -108,7 +116,7 @@ func (m *MongoStore) Save(r *http.Request, w http.ResponseWriter, } if session.ID == "" { - session.ID = bson.NewObjectId().Hex() + session.ID = primitive.NewObjectID().Hex() } if err := m.upsert(session); err != nil { @@ -140,13 +148,14 @@ func (m *MongoStore) MaxAge(age int) { } func (m *MongoStore) load(session *sessions.Session) error { - if !bson.IsObjectIdHex(session.ID) { - return ErrInvalidId + idObject, err := primitive.ObjectIDFromHex(session.ID) + if err != nil { + return ErrInvalidID } s := Session{} - err := m.coll.FindId(bson.ObjectIdHex(session.ID)).One(&s) + err = m.coll.FindOne(context.Background(), bson.M{"_id": idObject}).Decode(&s) if err != nil { return err } @@ -159,8 +168,9 @@ func (m *MongoStore) load(session *sessions.Session) error { } func (m *MongoStore) upsert(session *sessions.Session) error { - if !bson.IsObjectIdHex(session.ID) { - return ErrInvalidId + idObject, err := primitive.ObjectIDFromHex(session.ID) + if err != nil { + return ErrInvalidID } var modified time.Time @@ -180,23
+190,32 @@ func (m *MongoStore) upsert(session *sessions.Session) error { } s := Session{ - Id: bson.ObjectIdHex(session.ID), + ID: &idObject, Data: encoded, Modified: modified, } - _, err = m.coll.UpsertId(s.Id, &s) + // Replace the whole document, inserting it when absent: a single atomic + // round-trip equivalent to mgo's UpsertId (no count-then-write race). + replaceOpts := options.Replace().SetUpsert(true) + _, err = m.coll.ReplaceOne(context.Background(), bson.M{"_id": idObject}, &s, replaceOpts) if err != nil { return err } return nil } func (m *MongoStore) delete(session *sessions.Session) error { - if !bson.IsObjectIdHex(session.ID) { - return ErrInvalidId + idObject, err := primitive.ObjectIDFromHex(session.ID) + if err != nil { + return ErrInvalidID + } + + _, delError := m.coll.DeleteOne(context.Background(), bson.M{"_id": idObject}) + if delError != nil { + return delError } - return m.coll.RemoveId(bson.ObjectIdHex(session.ID)) + return nil }