From 724d428df7764c668e185b534ad2a4ace499b906 Mon Sep 17 00:00:00 2001 From: Deo Date: Tue, 28 Feb 2017 12:50:41 +0800 Subject: [PATCH] support $arrayElemAt in aggregate $project; also support falsy values (not only 0) in $project field selection, e.g. {'$project': {'_id': False}} --- mongomock/collection.py | 16 +++++++++++----- tests/test__collection_api.py | 15 +++++++++++++++ 2 files changed, 26 insertions(+), 5 deletions(-) diff --git a/mongomock/collection.py b/mongomock/collection.py index cd87701d61..e35746334c 100644 --- a/mongomock/collection.py +++ b/mongomock/collection.py @@ -1267,7 +1267,9 @@ def aggregate(self, pipeline, **kwargs): '$avg', '$sum', '$stdDevPop', - '$stdDevSamp'] + '$stdDevSamp', + '$arrayElemAt' + ] boolean_operators = ['$and', '$or', '$not'] # noqa set_operators = [ # noqa '$setEquals', @@ -1431,6 +1433,11 @@ def _handle_project_operator(operator, values, doc_dict): " implemented in Mongomock" % len(values)) return min(_parse_expression(values[0], doc_dict), _parse_expression(values[1], doc_dict)) + elif operator == '$arrayElemAt': + key, index = values + array = _parse_basic_expression(key, doc_dict) + v = array[index] + return v else: raise NotImplementedError("Although '%s' is a valid project operator for the " "aggregation pipeline, it is currently not implemented " @@ -1588,10 +1595,9 @@ def _extend_collection(out_collection, field, expression): elif k == '$project': filter_list = ['_id'] for field, value in iteritems(v): - if field == '_id': - if value == 0: - filter_list.remove('_id') - if value != 0: + if field == '_id' and not value: + filter_list.remove('_id') + elif value: filter_list.append(field) out_collection = _extend_collection(out_collection, field, value) out_collection = [{k: v for (k, v) in x.items() if k in filter_list} diff --git a/tests/test__collection_api.py b/tests/test__collection_api.py index d6f1d4c784..9c59caa1ca 100644 --- a/tests/test__collection_api.py +++ b/tests/test__collection_api.py @@ -998,3 
+998,18 @@ def test__bulk_write_delete_many(self): 'nModified': 0, 'nUpserted': 0, 'nMatched': 0, 'writeErrors': [], 'upserted': [], 'writeConcernErrors': [], 'nRemoved': 2, 'nInserted': 0}) + + def test__aggregate_project_array_element_at(self): + self.db.collection.insert_one({'_id': 1, 'arr': [2, 3]}) + actual = self.db.collection.aggregate([ + {'$match': {'_id': 1}}, + { + '$project': { + '_id': False, + 'a': { + '$arrayElemAt': ['$arr', 1] + } + } + } + ]) + self.assertEqual([{'a': 3}], list(actual))