diff --git a/session.go b/session.go index bb3691ba5..dcfc30f44 100644 --- a/session.go +++ b/session.go @@ -3090,6 +3090,11 @@ func (c *Collection) Insert(docs ...interface{}) error { // http://www.mongodb.org/display/DOCS/Atomic+Operations // func (c *Collection) Update(selector interface{}, update interface{}) error { + _, err := c.UpdateWithChangeInfo(selector, update) + return err +} + +func (c *Collection) UpdateWithChangeInfo(selector interface{}, update interface{}) (info *ChangeInfo, err error) { if selector == nil { selector = bson.D{} } @@ -3099,10 +3104,13 @@ func (c *Collection) Update(selector interface{}, update interface{}) error { Update: update, } lerr, err := c.writeOp(&op, true) - if err == nil && lerr != nil && !lerr.UpdatedExisting { - return ErrNotFound + if err == nil && lerr != nil { + if !lerr.UpdatedExisting { + return info, ErrNotFound + } + info = &ChangeInfo{Updated: lerr.modified, Matched: lerr.N} } - return err + return info, err } // UpdateId is a convenience helper equivalent to: @@ -3247,14 +3255,22 @@ func (c *Collection) UpsertId(id interface{}, update interface{}) (info *ChangeI // http://www.mongodb.org/display/DOCS/Removing // func (c *Collection) Remove(selector interface{}) error { + _, err := c.RemoveWithChangeInfo(selector) + return err +} + +func (c *Collection) RemoveWithChangeInfo(selector interface{}) (info *ChangeInfo, err error) { if selector == nil { selector = bson.D{} } lerr, err := c.writeOp(&deleteOp{c.FullName, selector, 1, 1}, true) - if err == nil && lerr != nil && lerr.N == 0 { - return ErrNotFound + if err == nil && lerr != nil { + if lerr.N == 0 { + return info, ErrNotFound + } + info = &ChangeInfo{Removed: lerr.N, Matched: lerr.N} } - return err + return info, err } // RemoveId is a convenience helper equivalent to: diff --git a/session_test.go b/session_test.go index eaa8964f3..79bc64099 100644 --- a/session_test.go +++ b/session_test.go @@ -683,6 +683,46 @@ func (s *S) TestUpdate(c *C) { c.Assert(err, Equals, mgo.ErrNotFound) } +func (s *S) TestUpdateWithChangeInfo(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"k": n, "n": n}) + c.Assert(err, IsNil) + } + + // No changes is a no-op and shouldn't return an error. + info, err := coll.UpdateWithChangeInfo(M{"k": 42}, M{"$set": M{"n": 42}}) + c.Assert(err, IsNil) + c.Assert(info.Matched, Equals, 1) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.Removed, Equals, 0) + c.Assert(info.UpsertedId, IsNil) + + info, err = coll.UpdateWithChangeInfo(M{"k": 42}, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + c.Assert(info.Matched, Equals, 1) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.Removed, Equals, 0) + c.Assert(info.UpsertedId, IsNil) + + result := make(M) + err = coll.Find(M{"k": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 43) + + info, err = coll.UpdateWithChangeInfo(M{"k": 47}, M{"k": 47, "n": 47}) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.Find(M{"k": 47}).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) +} + func (s *S) TestUpdateId(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil) @@ -986,6 +1026,39 @@ func (s *S) TestRemove(c *C) { c.Assert(result.N, Equals, 44) } +func (s *S) TestRemoveWithChangeInfo(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + info, err := coll.RemoveWithChangeInfo(M{"n": M{"$gt": 42}}) + c.Assert(err, IsNil) + c.Assert(info.Removed, Equals, 1) + c.Assert(info.Matched, Equals, 1) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.UpsertedId, IsNil) + + result := &struct{ N int }{} + err = coll.Find(M{"n": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 42) + + err = coll.Find(M{"n": 43}).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.Find(M{"n": 44}).One(result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 44) +} + func (s *S) TestRemoveId(c *C) { session, err := mgo.Dial("localhost:40001") c.Assert(err, IsNil)