Skip to content

Commit 9dce90a

Browse files
committed
Polishing.
Merge List values in Update. Add tests, reformat code. See: #4918 Original pull request: #4921
1 parent 64271c0 commit 9dce90a

File tree

4 files changed

+75
-29
lines changed

4 files changed

+75
-29
lines changed

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/BasicUpdate.java

+32-9
Original file line numberDiff line numberDiff line change
@@ -15,16 +15,21 @@
1515
*/
1616
package org.springframework.data.mongodb.core.query;
1717

18+
import java.util.ArrayList;
1819
import java.util.Collections;
1920
import java.util.LinkedHashMap;
2021
import java.util.List;
2122
import java.util.Map;
23+
import java.util.function.BiFunction;
2224

2325
import org.bson.Document;
26+
2427
import org.springframework.lang.Nullable;
2528
import org.springframework.util.ClassUtils;
2629

2730
/**
31+
* {@link Document}-based {@link Update} variant.
32+
*
2833
* @author Thomas Risberg
2934
* @author John Brisbin
3035
* @author Oliver Gierke
@@ -36,12 +41,10 @@ public class BasicUpdate extends Update {
3641
private final Document updateObject;
3742

3843
public BasicUpdate(String updateString) {
39-
super();
40-
this.updateObject = Document.parse(updateString);
44+
this(Document.parse(updateString));
4145
}
4246

4347
public BasicUpdate(Document updateObject) {
44-
super();
4548
this.updateObject = updateObject;
4649
}
4750

@@ -89,7 +92,17 @@ public Update pull(String key, @Nullable Object value) {
8992

9093
@Override
9194
public Update pullAll(String key, Object[] values) {
92-
setOperationValue("$pullAll", key, List.of(values));
95+
setOperationValue("$pullAll", key, List.of(values), (o, o2) -> {
96+
97+
if (o instanceof List<?> prev && o2 instanceof List<?> currentValue) {
98+
List<Object> merged = new ArrayList<>(prev.size() + currentValue.size());
99+
merged.addAll(prev);
100+
merged.addAll(currentValue);
101+
return merged;
102+
}
103+
104+
return o2;
105+
});
93106
return this;
94107
}
95108

@@ -109,21 +122,31 @@ public Document getUpdateObject() {
109122
return updateObject;
110123
}
111124

112-
void setOperationValue(String operator, String key, Object value) {
125+
void setOperationValue(String operator, String key, @Nullable Object value) {
126+
setOperationValue(operator, key, value, (o, o2) -> o2);
127+
}
128+
129+
void setOperationValue(String operator, String key, @Nullable Object value,
130+
BiFunction<Object, Object, Object> mergeFunction) {
113131

114132
if (!updateObject.containsKey(operator)) {
115133
updateObject.put(operator, Collections.singletonMap(key, value));
116134
} else {
117-
Object existingValue = updateObject.get(operator);
118-
if (existingValue instanceof Map<?, ?> existing) {
135+
Object o = updateObject.get(operator);
136+
if (o instanceof Map<?, ?> existing) {
119137
Map<Object, Object> target = new LinkedHashMap<>(existing);
120-
target.put(key, value);
138+
139+
if (target.containsKey(key)) {
140+
target.put(key, mergeFunction.apply(target.get(key), value));
141+
} else {
142+
target.put(key, value);
143+
}
121144
updateObject.put(operator, target);
122145
} else {
123146
throw new IllegalStateException(
124147
"Cannot add ['%s' : { '%s' : ... }]. Operator already exists with value of type [%s] which is not suitable for appending"
125148
.formatted(operator, key,
126-
existingValue != null ? ClassUtils.getShortName(existingValue.getClass()) : "null"));
149+
o != null ? ClassUtils.getShortName(o.getClass()) : "null"));
127150
}
128151
}
129152
}

spring-data-mongodb/src/main/java/org/springframework/data/mongodb/core/query/Update.java

+4-6
Original file line numberDiff line numberDiff line change
@@ -447,13 +447,11 @@ protected void addMultiFieldOperation(String operator, String key, @Nullable Obj
447447
if (existingValue == null) {
448448
keyValueMap = new Document();
449449
this.modifierOps.put(operator, keyValueMap);
450+
} else if (existingValue instanceof Document document) {
451+
keyValueMap = document;
450452
} else {
451-
if (existingValue instanceof Document document) {
452-
keyValueMap = document;
453-
} else {
454-
throw new InvalidDataAccessApiUsageException(
455-
"Modifier Operations should be a LinkedHashMap but was " + existingValue.getClass());
456-
}
453+
throw new InvalidDataAccessApiUsageException(
454+
"Modifier Operations should be a LinkedHashMap but was " + existingValue.getClass());
457455
}
458456

459457
keyValueMap.put(key, value);

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/core/query/BasicUpdateUnitTests.java

+21-4
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
* you may not use this file except in compliance with the License.
66
* You may obtain a copy of the License at
77
*
8-
* http://www.apache.org/licenses/LICENSE-2.0
8+
* https://www.apache.org/licenses/LICENSE-2.0
99
*
1010
* Unless required by applicable law or agreed to in writing, software
1111
* distributed under the License is distributed on an "AS IS" BASIS,
@@ -15,9 +15,12 @@
1515
*/
1616
package org.springframework.data.mongodb.core.query;
1717

18-
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
18+
import static org.assertj.core.api.Assertions.*;
19+
import static org.springframework.data.mongodb.test.util.Assertions.*;
1920
import static org.springframework.data.mongodb.test.util.Assertions.assertThat;
2021

22+
import java.util.Arrays;
23+
import java.util.List;
2124
import java.util.function.Function;
2225
import java.util.stream.Stream;
2326

@@ -27,12 +30,16 @@
2730
import org.junit.jupiter.params.provider.Arguments;
2831
import org.junit.jupiter.params.provider.CsvSource;
2932
import org.junit.jupiter.params.provider.MethodSource;
33+
3034
import org.springframework.data.mongodb.core.query.Update.Position;
3135

3236
/**
37+
* Unit tests for {@link BasicUpdate}.
38+
*
3339
* @author Christoph Strobl
40+
* @author Mark Paluch
3441
*/
35-
public class BasicUpdateUnitTests {
42+
class BasicUpdateUnitTests {
3643

3744
@Test // GH-4918
3845
void setOperationValueShouldAppendsOpsCorrectly() {
@@ -80,8 +87,18 @@ void updateOpsShouldNotOverrideExistingValues(String operator, Function<BasicUpd
8087
.containsKey("%s.key-2".formatted(operator));
8188
}
8289

83-
static Stream<Arguments> updateOpArgs() {
90+
@Test // GH-4918
91+
void shouldNotOverridePullAll() {
8492

93+
Document source = Document.parse("{ '$pullAll' : { 'key-1' : ['value-1'] } }");
94+
Update update = new BasicUpdate(source).pullAll("key-1", new String[] { "value-2" }).pullAll("key-2",
95+
new String[] { "value-3" });
96+
97+
assertThat(update.getUpdateObject()).containsEntry("$pullAll.key-1", Arrays.asList("value-1", "value-2"))
98+
.containsEntry("$pullAll.key-2", List.of("value-3"));
99+
}
100+
101+
static Stream<Arguments> updateOpArgs() {
85102
return Stream.of( //
86103
Arguments.of("$set", (Function<BasicUpdate, Update>) update -> update.set("key-2", "value-2")),
87104
Arguments.of("$unset", (Function<BasicUpdate, Update>) update -> update.unset("key-2")),

spring-data-mongodb/src/test/java/org/springframework/data/mongodb/repository/VersionedPersonRepositoryIntegrationTests.java

+18-10
Original file line numberDiff line numberDiff line change
@@ -15,13 +15,14 @@
1515
*/
1616
package org.springframework.data.mongodb.repository;
1717

18-
import static org.assertj.core.api.Assertions.assertThat;
18+
import static org.assertj.core.api.Assertions.*;
1919

2020
import org.bson.Document;
2121
import org.bson.types.ObjectId;
2222
import org.junit.jupiter.api.BeforeEach;
2323
import org.junit.jupiter.api.Test;
2424
import org.junit.jupiter.api.extension.ExtendWith;
25+
2526
import org.springframework.beans.factory.annotation.Autowired;
2627
import org.springframework.context.annotation.ComponentScan.Filter;
2728
import org.springframework.context.annotation.Configuration;
@@ -40,12 +41,13 @@
4041
import com.mongodb.client.MongoClient;
4142

4243
/**
44+
* Integration tests for Repositories using optimistic locking.
45+
*
4346
* @author Christoph Strobl
44-
* @since 2025/03
4547
*/
4648
@ExtendWith({ MongoClientExtension.class, SpringExtension.class })
4749
@ContextConfiguration
48-
public class VersionedPersonRepositoryIntegrationTests {
50+
class VersionedPersonRepositoryIntegrationTests {
4951

5052
static @Client MongoClient mongoClient;
5153

@@ -70,14 +72,15 @@ public MongoClient mongoClient() {
7072

7173
@BeforeEach
7274
void beforeEach() {
73-
MongoTestUtils.flushCollection("versioned-person-tests", template.getCollectionName(VersionedPersonWithCounter.class),
74-
mongoClient);
75+
MongoTestUtils.flushCollection("versioned-person-tests",
76+
template.getCollectionName(VersionedPersonWithCounter.class), mongoClient);
7577
}
7678

7779
@Test // GH-4918
7880
void updatesVersionedTypeCorrectly() {
7981

80-
VersionedPerson person = template.insert(VersionedPersonWithCounter.class).one(new VersionedPersonWithCounter("Donald", "Duckling"));
82+
VersionedPerson person = template.insert(VersionedPersonWithCounter.class)
83+
.one(new VersionedPersonWithCounter("Donald", "Duckling"));
8184

8285
int updateCount = versionedPersonRepository.findAndSetFirstnameToLastnameByLastname(person.getLastname());
8386

@@ -93,7 +96,8 @@ void updatesVersionedTypeCorrectly() {
9396
@Test // GH-4918
9497
void updatesVersionedTypeCorrectlyWhenUpdateIsUsingInc() {
9598

96-
VersionedPerson person = template.insert(VersionedPersonWithCounter.class).one(new VersionedPersonWithCounter("Donald", "Duckling"));
99+
VersionedPerson person = template.insert(VersionedPersonWithCounter.class)
100+
.one(new VersionedPersonWithCounter("Donald", "Duckling"));
97101

98102
int updateCount = versionedPersonRepository.findAndIncCounterByLastname(person.getLastname());
99103

@@ -103,13 +107,15 @@ void updatesVersionedTypeCorrectlyWhenUpdateIsUsingInc() {
103107
return collection.find(new Document("_id", new ObjectId(person.getId()))).first();
104108
});
105109

106-
assertThat(document).containsEntry("lastname", "Duckling").containsEntry("version", 1L).containsEntry("counter", 42);
110+
assertThat(document).containsEntry("lastname", "Duckling").containsEntry("version", 1L).containsEntry("counter",
111+
42);
107112
}
108113

109114
@Test // GH-4918
110115
void updatesVersionedTypeCorrectlyWhenUpdateCoversVersionBump() {
111116

112-
VersionedPerson person = template.insert(VersionedPersonWithCounter.class).one(new VersionedPersonWithCounter("Donald", "Duckling"));
117+
VersionedPerson person = template.insert(VersionedPersonWithCounter.class)
118+
.one(new VersionedPersonWithCounter("Donald", "Duckling"));
113119

114120
int updateCount = versionedPersonRepository.findAndSetFirstnameToLastnameIncVersionByLastname(person.getLastname(),
115121
10);
@@ -123,7 +129,7 @@ void updatesVersionedTypeCorrectlyWhenUpdateCoversVersionBump() {
123129
assertThat(document).containsEntry("firstname", "Duckling").containsEntry("version", 10L);
124130
}
125131

126-
public interface VersionedPersonRepository extends CrudRepository<VersionedPersonWithCounter, String> {
132+
interface VersionedPersonRepository extends CrudRepository<VersionedPersonWithCounter, String> {
127133

128134
@Update("{ '$set': { 'firstname' : ?0 } }")
129135
int findAndSetFirstnameToLastnameByLastname(String lastname);
@@ -156,5 +162,7 @@ public int getCounter() {
156162
public void setCounter(int counter) {
157163
this.counter = counter;
158164
}
165+
159166
}
167+
160168
}

0 commit comments

Comments
 (0)