Skip to content

Commit b6327ac

Browse files
committed
Add SpEL based selector to DefaultSubscriptionRegistry
Issue: SPR-12884
1 parent 86733a9 commit b6327ac

File tree

2 files changed

+264
-43
lines changed

2 files changed

+264
-43
lines changed

spring-messaging/src/main/java/org/springframework/messaging/simp/broker/DefaultSubscriptionRegistry.java

Lines changed: 217 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@
1616

1717
package org.springframework.messaging.simp.broker;
1818

19+
import static org.springframework.messaging.support.MessageHeaderAccessor.getAccessor;
20+
1921
import java.util.Collection;
2022
import java.util.HashSet;
2123
import java.util.LinkedHashMap;
@@ -24,8 +26,20 @@
2426
import java.util.Set;
2527
import java.util.concurrent.ConcurrentHashMap;
2628
import java.util.concurrent.ConcurrentMap;
27-
29+
import java.util.concurrent.CopyOnWriteArraySet;
30+
31+
import org.springframework.expression.AccessException;
32+
import org.springframework.expression.EvaluationContext;
33+
import org.springframework.expression.Expression;
34+
import org.springframework.expression.ExpressionParser;
35+
import org.springframework.expression.PropertyAccessor;
36+
import org.springframework.expression.TypedValue;
37+
import org.springframework.expression.spel.SpelEvaluationException;
38+
import org.springframework.expression.spel.standard.SpelExpressionParser;
39+
import org.springframework.expression.spel.support.StandardEvaluationContext;
2840
import org.springframework.messaging.Message;
41+
import org.springframework.messaging.MessageHeaders;
42+
import org.springframework.messaging.simp.SimpMessageHeaderAccessor;
2943
import org.springframework.util.AntPathMatcher;
3044
import org.springframework.util.Assert;
3145
import org.springframework.util.LinkedMultiValueMap;
@@ -34,7 +48,13 @@
3448

3549

3650
/**
37-
* A default, simple in-memory implementation of {@link SubscriptionRegistry}.
51+
* Implementation of {@link SubscriptionRegistry} that stores subscriptions
52+
* in memory and uses a {@link org.springframework.util.PathMatcher PathMatcher}
53+
* for matching destinations.
54+
*
55+
* <p>As of 4.2 this class supports a {@link #setSelectorHeaderName selector}
56+
* header on subscription messages with Spring EL expressions evaluated against
57+
* the headers to filter out messages in addition to destination matching.
3858
*
3959
* @author Rossen Stoyanchev
4060
* @author Sebastien Deleuze
@@ -51,6 +71,10 @@ public class DefaultSubscriptionRegistry extends AbstractSubscriptionRegistry {
5171

5272
private PathMatcher pathMatcher = new AntPathMatcher();
5373

74+
private String selectorHeaderName = "selector";
75+
76+
private ExpressionParser expressionParser = new SpelExpressionParser();
77+
5478
private final DestinationCache destinationCache = new DestinationCache();
5579

5680
private final SessionSubscriptionRegistry subscriptionRegistry = new SessionSubscriptionRegistry();
@@ -85,10 +109,52 @@ public PathMatcher getPathMatcher() {
85109
return this.pathMatcher;
86110
}
87111

112+
/**
113+
* Configure the name of a selector header that a subscription message can
114+
* have in order to filter messages based on their headers. The value of the
115+
* header can use Spring EL expressions against message headers.
116+
* <p>For example the following expression expects a header called "foo" to
117+
* have the value "bar":
118+
* <pre>
119+
* headers.foo == 'bar'
120+
* </pre>
121+
* <p>By default this is set to "selector".
122+
* @since 4.2
123+
*/
124+
public void setSelectorHeaderName(String selectorHeaderName) {
125+
Assert.notNull(selectorHeaderName);
126+
this.selectorHeaderName = selectorHeaderName;
127+
}
128+
129+
/**
130+
* Return the name for the selector header.
131+
*/
132+
public String getSelectorHeaderName() {
133+
return this.selectorHeaderName;
134+
}
135+
88136

89137
@Override
90-
protected void addSubscriptionInternal(String sessionId, String subsId, String destination, Message<?> message) {
91-
this.subscriptionRegistry.addSubscription(sessionId, subsId, destination);
138+
protected void addSubscriptionInternal(String sessionId, String subsId, String destination,
139+
Message<?> message) {
140+
141+
Expression expression = null;
142+
MessageHeaders headers = message.getHeaders();
143+
String selector = SimpMessageHeaderAccessor.getFirstNativeHeader(getSelectorHeaderName(), headers);
144+
if (selector != null) {
145+
try {
146+
expression = this.expressionParser.parseExpression(selector);
147+
if (logger.isTraceEnabled()) {
148+
logger.trace("Subscription selector: [" + selector + "]");
149+
}
150+
}
151+
catch (Throwable ex) {
152+
if (logger.isDebugEnabled()) {
153+
logger.debug("Failed to parse selector: " + selector, ex);
154+
}
155+
}
156+
}
157+
this.subscriptionRegistry.addSubscription(sessionId, subsId, destination, expression);
92158
this.destinationCache.updateAfterNewSubscription(destination, sessionId, subsId);
93159
}
94160

@@ -112,24 +178,64 @@ public void unregisterAllSubscriptions(String sessionId) {
112178
}
113179

114180
@Override
115-
protected MultiValueMap<String, String> findSubscriptionsInternal(String destination, Message<?> message) {
181+
protected MultiValueMap<String, String> findSubscriptionsInternal(String destination,
182+
Message<?> message) {
183+
116184
MultiValueMap<String, String> result = this.destinationCache.getSubscriptions(destination);
117185
if (result != null) {
118-
return result;
186+
return filterSubscriptions(result, message);
119187
}
120188
result = new LinkedMultiValueMap<String, String>();
121189
for (SessionSubscriptionInfo info : this.subscriptionRegistry.getAllSubscriptions()) {
122190
for (String destinationPattern : info.getDestinations()) {
123191
if (this.pathMatcher.match(destinationPattern, destination)) {
124-
for (String subscriptionId : info.getSubscriptions(destinationPattern)) {
125-
result.add(info.sessionId, subscriptionId);
192+
for (Subscription subscription : info.getSubscriptions(destinationPattern)) {
193+
result.add(info.sessionId, subscription.getId());
126194
}
127195
}
128196
}
129197
}
130198
if (!result.isEmpty()) {
131199
this.destinationCache.addSubscriptions(destination, result);
132200
}
201+
return filterSubscriptions(result, message);
202+
}
203+
204+
private MultiValueMap<String, String> filterSubscriptions(MultiValueMap<String, String> allMatches,
205+
Message<?> message) {
206+
207+
EvaluationContext context = null;
208+
MultiValueMap<String, String> result = new LinkedMultiValueMap<String, String>(allMatches.size());
209+
for (String sessionId : allMatches.keySet()) {
210+
for (String subId : allMatches.get(sessionId)) {
211+
SessionSubscriptionInfo info = this.subscriptionRegistry.getSubscriptions(sessionId);
212+
Subscription sub = info.getSubscription(subId);
213+
Expression expression = sub.getSelectorExpression();
214+
if (expression == null) {
215+
result.add(sessionId, subId);
216+
continue;
217+
}
218+
if (context == null) {
219+
context = new StandardEvaluationContext(message);
220+
context.getPropertyAccessors().add(new SimpMessageHeaderPropertyAccessor());
221+
}
222+
try {
223+
if (expression.getValue(context, boolean.class)) {
224+
result.add(sessionId, subId);
225+
}
226+
}
227+
catch (SpelEvaluationException ex) {
228+
if (logger.isDebugEnabled()) {
229+
logger.debug("Failed to evaluate selector: " + ex.getMessage());
230+
}
231+
}
232+
catch (Throwable ex) {
233+
if (logger.isDebugEnabled()) {
234+
logger.debug("Failed to evaluate selector.", ex);
235+
}
236+
}
237+
}
238+
}
133239
return result;
134240
}
135241

@@ -257,7 +363,9 @@ public Collection<SessionSubscriptionInfo> getAllSubscriptions() {
257363
return this.sessions.values();
258364
}
259365

260-
public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId, String destination) {
366+
public SessionSubscriptionInfo addSubscription(String sessionId, String subscriptionId,
367+
String destination, Expression selectorExpression) {
368+
261369
SessionSubscriptionInfo info = this.sessions.get(sessionId);
262370
if (info == null) {
263371
info = new SessionSubscriptionInfo(sessionId);
@@ -266,7 +374,7 @@ public SessionSubscriptionInfo addSubscription(String sessionId, String subscrip
266374
info = value;
267375
}
268376
}
269-
info.addSubscription(destination, subscriptionId);
377+
info.addSubscription(destination, subscriptionId, selectorExpression);
270378
return info;
271379
}
272380

@@ -287,8 +395,9 @@ private static class SessionSubscriptionInfo {
287395

288396
private final String sessionId;
289397

290-
// destination -> subscriptionIds
291-
private final Map<String, Set<String>> subscriptions = new ConcurrentHashMap<String, Set<String>>(4);
398+
// destination -> subscriptions
399+
private final Map<String, Set<Subscription>> destinationLookup =
400+
new ConcurrentHashMap<String, Set<Subscription>>(4);
292401

293402
private final Object monitor = new Object();
294403

@@ -303,45 +412,124 @@ public String getSessionId() {
303412
}
304413

305414
public Set<String> getDestinations() {
306-
return this.subscriptions.keySet();
415+
return this.destinationLookup.keySet();
307416
}
308417

309-
public Set<String> getSubscriptions(String destination) {
310-
return this.subscriptions.get(destination);
418+
public Set<Subscription> getSubscriptions(String destination) {
419+
return this.destinationLookup.get(destination);
420+
}
421+
422+
public Subscription getSubscription(String subscriptionId) {
423+
for (String destination : this.destinationLookup.keySet()) {
424+
for (Subscription sub : this.destinationLookup.get(destination)) {
425+
if (sub.getId().equalsIgnoreCase(subscriptionId)) {
426+
return sub;
427+
}
428+
}
429+
}
430+
return null;
311431
}
312432

313-
public void addSubscription(String destination, String subscriptionId) {
314-
Set<String> subs = this.subscriptions.get(destination);
433+
public void addSubscription(String destination, String subscriptionId, Expression selectorExpression) {
434+
Set<Subscription> subs = this.destinationLookup.get(destination);
315435
if (subs == null) {
316436
synchronized (this.monitor) {
317-
subs = this.subscriptions.get(destination);
437+
subs = this.destinationLookup.get(destination);
318438
if (subs == null) {
319-
subs = new HashSet<String>(4);
320-
this.subscriptions.put(destination, subs);
439+
subs = new CopyOnWriteArraySet<Subscription>();
440+
this.destinationLookup.put(destination, subs);
321441
}
322442
}
323443
}
324-
subs.add(subscriptionId);
444+
subs.add(new Subscription(subscriptionId, selectorExpression));
325445
}
326446

327447
public String removeSubscription(String subscriptionId) {
328-
for (String destination : this.subscriptions.keySet()) {
329-
Set<String> subscriptionIds = this.subscriptions.get(destination);
330-
if (subscriptionIds.remove(subscriptionId)) {
331-
synchronized (this.monitor) {
332-
if (subscriptionIds.isEmpty()) {
333-
this.subscriptions.remove(destination);
448+
for (String destination : this.destinationLookup.keySet()) {
449+
Set<Subscription> subscriptions = this.destinationLookup.get(destination);
450+
for (Subscription sub : subscriptions) {
451+
if (sub.getId().equals(subscriptionId) && subscriptions.remove(sub)) {
452+
synchronized (this.monitor) {
453+
if (subscriptions.isEmpty()) {
454+
this.destinationLookup.remove(destination);
455+
}
334456
}
457+
return destination;
335458
}
336-
return destination;
337459
}
338460
}
339461
return null;
340462
}
341463

342464
@Override
343465
public String toString() {
344-
return "[sessionId=" + this.sessionId + ", subscriptions=" + this.subscriptions + "]";
466+
return "[sessionId=" + this.sessionId + ", subscriptions=" + this.destinationLookup + "]";
467+
}
468+
}
469+
470+
private static class Subscription {
471+
472+
private final String id;
473+
474+
private final Expression selectorExpression;
475+
476+
477+
public Subscription(String id, Expression selector) {
478+
this.id = id;
479+
this.selectorExpression = selector;
480+
}
481+
482+
483+
public String getId() {
484+
return this.id;
485+
}
486+
487+
public Expression getSelectorExpression() {
488+
return this.selectorExpression;
489+
}
490+
491+
@Override
492+
public String toString() {
493+
return "Subscription id='" + this.id;
494+
}
495+
}
496+
497+
private static class SimpMessageHeaderPropertyAccessor implements PropertyAccessor {
498+
499+
@Override
500+
public Class<?>[] getSpecificTargetClasses() {
501+
return new Class<?>[] {MessageHeaders.class};
502+
}
503+
504+
@Override
505+
public boolean canRead(EvaluationContext context, Object target, String name) {
506+
return true;
507+
}
508+
509+
@Override
510+
public TypedValue read(EvaluationContext context, Object target, String name) throws AccessException {
511+
MessageHeaders headers = (MessageHeaders) target;
512+
SimpMessageHeaderAccessor accessor = getAccessor(headers, SimpMessageHeaderAccessor.class);
513+
Object value;
514+
if ("destination".equalsIgnoreCase(name)) {
515+
value = accessor.getDestination();
516+
}
517+
else {
518+
value = accessor.getFirstNativeHeader(name);
519+
if (value == null) {
520+
value = headers.get(name);
521+
}
522+
}
523+
return new TypedValue(value);
524+
}
525+
526+
@Override
527+
public boolean canWrite(EvaluationContext context, Object target, String name) {
528+
return false;
529+
}
530+
531+
@Override
532+
public void write(EvaluationContext context, Object target, String name, Object value) {
345533
}
346534
}
347535

0 commit comments

Comments
 (0)