Skip to content

Commit

Permalink
Replace blacklistHeader to excludeHeaders in routing rule conf
Browse files Browse the repository at this point in the history
  • Loading branch information
Chaho12 authored and ebyhr committed Sep 24, 2024
1 parent df9bc20 commit 08979db
Show file tree
Hide file tree
Showing 4 changed files with 35 additions and 43 deletions.
6 changes: 3 additions & 3 deletions docs/routing-rules.md
Original file line number Diff line number Diff line change
Expand Up @@ -31,12 +31,12 @@ routingRules:
rulesConfigPath: "app/config/routing_rules.yml" # replace with actual path to your rules config file
rulesExternalConfiguration:
urlPath: https://router.example.com/gateway-rules # replace with your own API path
blacklistHeaders:
excludeHeaders:
- 'Authorization'
```
* Redirect URLs are not supported
* Optionally add headers to the `BlacklistHeaders` list to exclude requests with corresponding header values
* Optionally add headers to the `excludeHeaders` list to exclude requests with corresponding header values
from being sent in the POST request.

If there is error parsing the routing rules configuration file, an error is logged,
Expand All @@ -48,7 +48,7 @@ You can use an external service for processing your routing by setting the
`rulesType` to `EXTERNAL` and configuring the `rulesExternalConfiguration`.

Trino Gateway then sends all headers as a map in the body of a POST request to the external service.
Headers specified in `blacklistHeaders` are excluded. If `requestAnalyzerConfig.analyzeRequest` is set to `true`,
Headers specified in `excludeHeaders` are excluded. If `requestAnalyzerConfig.analyzeRequest` is set to `true`,
`TrinoRequestUser` and `TrinoQueryProperties` are also included.

Additionally, the following HTTP information is included:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
public class RulesExternalConfiguration
{
private String urlPath;
private List<String> blackListHeaders;
private List<String> excludeHeaders;

public String getUrlPath()
{
Expand All @@ -30,13 +30,13 @@ public void setUrlPath(String urlPath)
this.urlPath = urlPath;
}

public List<String> getBlackListHeaders()
public List<String> getExcludeHeaders()
{
return this.blackListHeaders;
return this.excludeHeaders;
}

public void setBlackListHeaders(List<String> blackListHeaders)
public void setExcludeHeaders(List<String> excludeHeaders)
{
this.blackListHeaders = blackListHeaders;
this.excludeHeaders = excludeHeaders;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ public class ExternalRoutingGroupSelector
implements RoutingGroupSelector
{
private static final Logger log = Logger.get(ExternalRoutingGroupSelector.class);
private final Set<String> blacklistHeaders;
private final Set<String> excludeHeaders;
private final URI uri;
private final HttpClient httpClient;
private final RequestAnalyzerConfig requestAnalyzerConfig;
Expand All @@ -61,10 +61,9 @@ public class ExternalRoutingGroupSelector
@VisibleForTesting
ExternalRoutingGroupSelector(RulesExternalConfiguration rulesExternalConfiguration, RequestAnalyzerConfig requestAnalyzerConfig)
{
Set<String> defaultBlacklistHeaders = ImmutableSet.of("Content-Length");
this.blacklistHeaders = ImmutableSet.<String>builder()
.addAll(defaultBlacklistHeaders)
.addAll(rulesExternalConfiguration.getBlackListHeaders())
this.excludeHeaders = ImmutableSet.<String>builder()
.add("Content-Length")
.addAll(rulesExternalConfiguration.getExcludeHeaders())
.build();

this.requestAnalyzerConfig = requestAnalyzerConfig;
Expand Down Expand Up @@ -142,8 +141,7 @@ private Multimap<String, String> getValidHeaders(HttpServletRequest servletReque
Multimap<String, String> headers = ArrayListMultimap.create();
for (String name : list(servletRequest.getHeaderNames())) {
for (String value : list(servletRequest.getHeaders(name))) {
// Add all headers to ListMultimap except those in blacklist
if (!blacklistHeaders.contains(name)) {
if (!excludeHeaders.contains(name)) {
headers.put(name, value);
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,6 @@
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
import org.junit.jupiter.api.extension.ExtendWith;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.MethodSource;
import org.mockito.ArgumentCaptor;
import org.mockito.Mockito;
import org.mockito.junit.jupiter.MockitoExtension;
Expand All @@ -39,13 +37,11 @@
import java.lang.reflect.Method;
import java.net.URI;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.Enumeration;
import java.util.List;
import java.util.Optional;
import java.util.stream.Stream;

import static io.airlift.http.client.JsonResponseHandler.createJsonResponseHandler;
import static io.airlift.http.client.Request.Builder.preparePost;
Expand Down Expand Up @@ -80,19 +76,19 @@ void initialize()
httpClient = Mockito.mock(HttpClient.class);
}

static Stream<RulesExternalConfiguration> provideRoutingRuleExternalConfig()
static RulesExternalConfiguration provideRoutingRuleExternalConfig()
{
RulesExternalConfiguration restConfig = new RulesExternalConfiguration();
restConfig.setUrlPath("http://localhost:8080/api/public/gateway_rules");
restConfig.setBlackListHeaders(new ArrayList<>(List.of("Authorization")));
return Stream.of(restConfig);
restConfig.setExcludeHeaders(List.of("Authorization"));
return restConfig;
}

@ParameterizedTest
@MethodSource("provideRoutingRuleExternalConfig")
void testByRoutingRulesExternalEngine(RulesExternalConfiguration rulesExternalConfiguration)
@Test
void testByRoutingRulesExternalEngine()
throws URISyntaxException
{
RulesExternalConfiguration rulesExternalConfiguration = provideRoutingRuleExternalConfig();
HttpServletRequest mockRequest = prepareMockRequest();

// Create a mock response
Expand Down Expand Up @@ -132,10 +128,10 @@ void testByRoutingRulesExternalEngine(RulesExternalConfiguration rulesExternalCo
assertThat(ROUTING_GROUP_REST_API_JSON_RESPONSE_HANDLER).isEqualTo(handlerCaptor.getValue());
}

@ParameterizedTest
@MethodSource("provideRoutingRuleExternalConfig")
void testApiFailure(RulesExternalConfiguration rulesExternalConfiguration)
@Test
void testApiFailure()
{
RulesExternalConfiguration rulesExternalConfiguration = provideRoutingRuleExternalConfig();
RoutingGroupSelector routingGroupSelector =
RoutingGroupSelector.byRoutingExternal(rulesExternalConfiguration, requestAnalyzerConfig);

Expand All @@ -162,38 +158,36 @@ void testApiFailure(RulesExternalConfiguration rulesExternalConfiguration)
@Test
void testNullUri()
{
RulesExternalConfiguration restConfig = new RulesExternalConfiguration();
restConfig.setBlackListHeaders(new ArrayList<>(List.of("Authorization")));
RulesExternalConfiguration rulesExternalConfiguration = provideRoutingRuleExternalConfig();
rulesExternalConfiguration.setUrlPath(null);

// Assert that a RuntimeException is thrown with message
assertThatThrownBy(() -> RoutingGroupSelector.byRoutingExternal(restConfig, requestAnalyzerConfig))
assertThatThrownBy(() -> RoutingGroupSelector.byRoutingExternal(rulesExternalConfiguration, requestAnalyzerConfig))
.isInstanceOf(RuntimeException.class)
.hasMessage("Invalid URL provided, using routing group header as default.");
}

@Test
void testBlackListHeader()
void testExcludeHeader()
throws IllegalAccessException, NoSuchMethodException, InvocationTargetException
{
// set custom RulesExternalConfiguration config
RulesExternalConfiguration restConfig = new RulesExternalConfiguration();
restConfig.setUrlPath("http://localhost:8080/api/public/gateway_rules");
restConfig.setBlackListHeaders(new ArrayList<>(List.of("test-blackList-header")));
RulesExternalConfiguration rulesExternalConfiguration = provideRoutingRuleExternalConfig();
rulesExternalConfiguration.setExcludeHeaders(List.of("test-exclude-header"));

RoutingGroupSelector routingGroupSelector =
RoutingGroupSelector.byRoutingExternal(restConfig, requestAnalyzerConfig);
RoutingGroupSelector.byRoutingExternal(rulesExternalConfiguration, requestAnalyzerConfig);

// Mock headers to be read by mockRequest
HttpServletRequest mockRequest = mock(HttpServletRequest.class);
List<String> customHeaderNames = List.of("test-blackList-header", "not-blacklisted-header");
List<String> customBlackListHeaderValues = List.of("test-blacklist-value");
List<String> customValidHeaderValues = List.of("not-blacklist-value");
List<String> customHeaderNames = List.of("test-exclude-header", "not-excluded-header");
List<String> customExcludeHeaderValues = List.of("test-excludeHeader-value");
List<String> customValidHeaderValues = List.of("not-excludeHeader-value");
Enumeration<String> headerNamesEnumeration = Collections.enumeration(customHeaderNames);
when(mockRequest.getHeaderNames()).thenReturn(headerNamesEnumeration);
when(mockRequest.getHeaders("test-blackList-header")).thenReturn(Collections.enumeration(customBlackListHeaderValues));
when(mockRequest.getHeaders("not-blacklisted-header")).thenReturn(Collections.enumeration(customValidHeaderValues));
when(mockRequest.getHeaders("test-exclude-header")).thenReturn(Collections.enumeration(customExcludeHeaderValues));
when(mockRequest.getHeaders("not-excluded-header")).thenReturn(Collections.enumeration(customValidHeaderValues));

// Use reflection to get valid headers after removing blacklist headers
// Use reflection to get valid headers after removing excludeHeaders headers
Method getValidHeaders = ExternalRoutingGroupSelector.class.getDeclaredMethod("getValidHeaders", HttpServletRequest.class);
getValidHeaders.setAccessible(true);

Expand Down

0 comments on commit 08979db

Please sign in to comment.