Skip to content

Commit

Permalink
[websocket] Allow registering websocket adapters (#3622)
Browse files Browse the repository at this point in the history
* [WebSocket] Allow register websocket handlers

Signed-off-by: Miguel Álvarez <miguelwork92@gmail.com>
  • Loading branch information
GiviMAD authored Jun 10, 2023
1 parent be74889 commit e3396c9
Show file tree
Hide file tree
Showing 13 changed files with 449 additions and 216 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
*
* SPDX-License-Identifier: EPL-2.0
*/
package org.openhab.core.io.rest.auth.internal;
package org.openhab.core.io.rest.auth;

import java.security.Principal;

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
*
* SPDX-License-Identifier: EPL-2.0
*/
package org.openhab.core.io.rest.auth.internal;
package org.openhab.core.io.rest.auth;

import java.io.IOException;
import java.net.InetAddress;
Expand Down Expand Up @@ -53,6 +53,7 @@
import org.openhab.core.config.core.ConfigurableService;
import org.openhab.core.io.rest.JSONResponse;
import org.openhab.core.io.rest.RESTConstants;
import org.openhab.core.io.rest.auth.internal.*;
import org.osgi.framework.Constants;
import org.osgi.service.component.annotations.Activate;
import org.osgi.service.component.annotations.Component;
Expand All @@ -77,7 +78,8 @@
* @author Miguel Álvarez - Add trusted networks for implicit user role
*/
@PreMatching
@Component(configurationPid = "org.openhab.restauth", property = Constants.SERVICE_PID + "=org.openhab.restauth")
@Component(configurationPid = "org.openhab.restauth", property = Constants.SERVICE_PID
+ "=org.openhab.restauth", service = AuthFilter.class)
@ConfigurableService(category = "system", label = "API Security", description_uri = AuthFilter.CONFIG_URI)
@JaxrsExtension
@JaxrsApplicationSelect("(" + JaxrsWhiteboardConstants.JAX_RS_NAME + "=" + RESTConstants.JAX_RS_NAME + ")")
Expand Down Expand Up @@ -232,56 +234,80 @@ private SecurityContext authenticateBasicAuth(String credentialString) throws Au
public void filter(@Nullable ContainerRequestContext requestContext) throws IOException {
if (requestContext != null) {
try {
String altTokenHeader = requestContext.getHeaderString(ALT_AUTH_HEADER);
if (altTokenHeader != null) {
requestContext.setSecurityContext(authenticateBearerToken(altTokenHeader));
return;
SecurityContext sc = getSecurityContext(servletRequest, false);
if (sc != null) {
requestContext.setSecurityContext(sc);
}
} catch (AuthenticationException e) {
logger.warn("Unauthorized API request from {}: {}", getClientIp(servletRequest), e.getMessage());
requestContext.abortWith(JSONResponse.createErrorResponse(Status.UNAUTHORIZED, "Invalid credentials"));
}
}
}

String authHeader = requestContext.getHeaderString(HttpHeaders.AUTHORIZATION);
if (authHeader != null) {
String[] authParts = authHeader.split(" ");
if (authParts.length == 2) {
String authType = authParts[0];
String authValue = authParts[1];
if ("Bearer".equalsIgnoreCase(authType)) {
requestContext.setSecurityContext(authenticateBearerToken(authValue));
return;
} else if ("Basic".equalsIgnoreCase(authType)) {
String[] decodedCredentials = new String(Base64.getDecoder().decode(authValue), "UTF-8")
.split(":");
if (decodedCredentials.length > 2) {
throw new AuthenticationException("Invalid Basic authentication credential format");
}
switch (decodedCredentials.length) {
case 1:
requestContext.setSecurityContext(authenticateBearerToken(decodedCredentials[0]));
break;
case 2:
if (!allowBasicAuth) {
throw new AuthenticationException(
"Basic authentication with username/password is not allowed");
}
requestContext.setSecurityContext(authenticateBasicAuth(authValue));
}
}
public @Nullable SecurityContext getSecurityContext(HttpServletRequest request, boolean allowQueryToken)
throws AuthenticationException, IOException {
String altTokenHeader = request.getHeader(ALT_AUTH_HEADER);
if (altTokenHeader != null) {
return authenticateBearerToken(altTokenHeader);
}
String authHeader = request.getHeader(HttpHeaders.AUTHORIZATION);
String authType = null;
String authValue = null;
boolean authFromQuery = false;
if (authHeader != null) {
String[] authParts = authHeader.split(" ");
if (authParts.length == 2) {
authType = authParts[0];
authValue = authParts[1];
}
} else if (allowQueryToken) {
Map<String, String[]> parameterMap = request.getParameterMap();
String[] accessToken = parameterMap.get("accessToken");
if (accessToken != null && accessToken.length > 0) {
authValue = accessToken[0];
authFromQuery = true;
}
}
if (authValue != null) {
if (authFromQuery) {
try {
return authenticateBearerToken(authValue);
} catch (AuthenticationException e) {
if (allowBasicAuth) {
return authenticateBasicAuth(authValue);
}
} else if (isImplicitUserRole(requestContext)) {
requestContext.setSecurityContext(new AnonymousUserSecurityContext());
}
} catch (AuthenticationException e) {
logger.warn("Unauthorized API request from {}: {}", getClientIp(requestContext), e.getMessage());
requestContext.abortWith(JSONResponse.createErrorResponse(Status.UNAUTHORIZED, "Invalid credentials"));
} else if ("Bearer".equalsIgnoreCase(authType)) {
return authenticateBearerToken(authValue);
} else if ("Basic".equalsIgnoreCase(authType)) {
String[] decodedCredentials = new String(Base64.getDecoder().decode(authValue), "UTF-8").split(":");
if (decodedCredentials.length > 2) {
throw new AuthenticationException("Invalid Basic authentication credential format");
}
switch (decodedCredentials.length) {
case 1:
return authenticateBearerToken(decodedCredentials[0]);
case 2:
if (!allowBasicAuth) {
throw new AuthenticationException(
"Basic authentication with username/password is not allowed");
}
return authenticateBasicAuth(authValue);
}
}
} else if (isImplicitUserRole(servletRequest)) {
return new AnonymousUserSecurityContext();
}
return null;
}

private boolean isImplicitUserRole(ContainerRequestContext requestContext) {
private boolean isImplicitUserRole(HttpServletRequest request) {
if (implicitUserRole) {
return true;
}
try {
byte[] clientAddress = InetAddress.getByName(getClientIp(requestContext)).getAddress();
byte[] clientAddress = InetAddress.getByName(getClientIp(request)).getAddress();
return trustedNetworks.stream().anyMatch(networkCIDR -> networkCIDR.isInRange(clientAddress));
} catch (IOException e) {
logger.debug("Error validating trusted networks: {}", e.getMessage());
Expand All @@ -303,8 +329,8 @@ private List<CIDR> parseTrustedNetworks(String value) {
return cidrList;
}

private String getClientIp(ContainerRequestContext requestContext) throws UnknownHostException {
String ipForwarded = Objects.requireNonNullElse(requestContext.getHeaderString("x-forwarded-for"), "");
private String getClientIp(HttpServletRequest request) throws UnknownHostException {
String ipForwarded = Objects.requireNonNullElse(request.getHeader("x-forwarded-for"), "");
String clientIp = ipForwarded.split(",")[0];
return clientIp.isBlank() ? servletRequest.getRemoteAddr() : clientIp;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,7 +37,7 @@ public class ExpiringUserSecurityContextCache {

private int calls = 0;

ExpiringUserSecurityContextCache(long expirationTime) {
public ExpiringUserSecurityContextCache(long expirationTime) {
this.keepPeriod = expirationTime;
entryMap = new LinkedHashMap<>() {
private static final long serialVersionUID = -1220310861591070462L;
Expand All @@ -48,7 +48,7 @@ protected boolean removeEldestEntry(Map.@Nullable Entry<String, Entry> eldest) {
};
}

synchronized @Nullable UserSecurityContext get(String key) {
public synchronized @Nullable UserSecurityContext get(String key) {
calls++;
if (calls >= CLEANUP_FREQUENCY) {
new HashSet<>(entryMap.keySet()).forEach(k -> getEntry(k));
Expand All @@ -61,11 +61,11 @@ protected boolean removeEldestEntry(Map.@Nullable Entry<String, Entry> eldest) {
return null;
}

synchronized void put(String key, UserSecurityContext value) {
public synchronized void put(String key, UserSecurityContext value) {
entryMap.put(key, new Entry(System.currentTimeMillis(), value));
}

synchronized void clear() {
public synchronized void clear() {
entryMap.clear();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
*
* SPDX-License-Identifier: EPL-2.0
*/
package org.openhab.core.io.rest.auth.internal;
package org.openhab.core.io.rest.auth;

import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.*;
Expand All @@ -32,6 +32,7 @@
import org.mockito.junit.jupiter.MockitoSettings;
import org.mockito.quality.Strictness;
import org.openhab.core.auth.UserRegistry;
import org.openhab.core.io.rest.auth.internal.JwtHelper;

/**
* The {@link AuthFilterTest} is a
Expand Down Expand Up @@ -79,7 +80,7 @@ public void noImplicitUserRoleDeniesAccess() throws IOException {
public void trustedNetworkAllowsAccessIfForwardedHeaderMatches() throws IOException {
authFilter.activate(Map.of(AuthFilter.CONFIG_IMPLICIT_USER_ROLE, false, AuthFilter.CONFIG_TRUSTED_NETWORKS,
"192.168.1.0/24"));
when(containerRequestContext.getHeaderString("x-forwarded-for")).thenReturn("192.168.1.100");
when(servletRequest.getHeader("x-forwarded-for")).thenReturn("192.168.1.100");
authFilter.filter(containerRequestContext);

verify(containerRequestContext).setSecurityContext(any());
Expand All @@ -89,7 +90,7 @@ public void trustedNetworkAllowsAccessIfForwardedHeaderMatches() throws IOExcept
public void trustedNetworkDeniesAccessIfForwardedHeaderDoesNotMatch() throws IOException {
authFilter.activate(Map.of(AuthFilter.CONFIG_IMPLICIT_USER_ROLE, false, AuthFilter.CONFIG_TRUSTED_NETWORKS,
"192.168.1.0/24"));
when(containerRequestContext.getHeaderString("x-forwarded-for")).thenReturn("192.168.2.100");
when(servletRequest.getHeader("x-forwarded-for")).thenReturn("192.168.2.100");
authFilter.filter(containerRequestContext);

verify(containerRequestContext, never()).setSecurityContext(any());
Expand Down
5 changes: 5 additions & 0 deletions bundles/org.openhab.core.io.websocket/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,11 @@
<artifactId>org.openhab.core</artifactId>
<version>${project.version}</version>
</dependency>
<dependency>
<groupId>org.openhab.core.bundles</groupId>
<artifactId>org.openhab.core.io.rest.auth</artifactId>
<version>${project.version}</version>
</dependency>
</dependencies>

</project>
Original file line number Diff line number Diff line change
@@ -0,0 +1,136 @@
/**
* Copyright (c) 2010-2023 Contributors to the openHAB project
*
* See the NOTICE file(s) distributed with this work for additional
* information.
*
* This program and the accompanying materials are made available under the
* terms of the Eclipse Public License 2.0 which is available at
* http://www.eclipse.org/legal/epl-2.0
*
* SPDX-License-Identifier: EPL-2.0
*/
package org.openhab.core.io.websocket;

import java.io.IOException;
import java.util.HashMap;
import java.util.Map;

import javax.servlet.Servlet;
import javax.servlet.ServletException;

import org.eclipse.jdt.annotation.NonNullByDefault;
import org.eclipse.jdt.annotation.Nullable;
import org.eclipse.jetty.websocket.server.WebSocketServerFactory;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeRequest;
import org.eclipse.jetty.websocket.servlet.ServletUpgradeResponse;
import org.eclipse.jetty.websocket.servlet.WebSocketCreator;
import org.eclipse.jetty.websocket.servlet.WebSocketServlet;
import org.eclipse.jetty.websocket.servlet.WebSocketServletFactory;
import org.openhab.core.auth.AuthenticationException;
import org.openhab.core.auth.Role;
import org.openhab.core.io.rest.auth.AuthFilter;
import org.osgi.service.component.annotations.Activate;
import org.osgi.service.component.annotations.Component;
import org.osgi.service.component.annotations.Reference;
import org.osgi.service.component.annotations.ReferenceCardinality;
import org.osgi.service.component.annotations.ReferencePolicy;
import org.osgi.service.http.NamespaceException;
import org.osgi.service.http.whiteboard.propertytypes.HttpWhiteboardServletName;
import org.osgi.service.http.whiteboard.propertytypes.HttpWhiteboardServletPattern;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
* The {@link CommonWebSocketServlet} provides the servlet for WebSocket connections
*
* @author Jan N. Klug - Initial contribution
* @author Miguel Álvarez Díez - Refactor into a common servlet
*/
@NonNullByDefault
@HttpWhiteboardServletName(CommonWebSocketServlet.SERVLET_PATH)
@HttpWhiteboardServletPattern(CommonWebSocketServlet.SERVLET_PATH + "/*")
@Component(immediate = true, service = { Servlet.class })
public class CommonWebSocketServlet extends WebSocketServlet {
private static final long serialVersionUID = 1L;

public static final String SERVLET_PATH = "/ws";

public static final String DEFAULT_ADAPTER_ID = EventWebSocketAdapter.ADAPTER_ID;

private final Map<String, WebSocketAdapter> connectionHandlers = new HashMap<>();
private final AuthFilter authFilter;

@SuppressWarnings("unused")
private @Nullable WebSocketServerFactory importNeeded;

@Activate
public CommonWebSocketServlet(@Reference AuthFilter authFilter) throws ServletException, NamespaceException {
this.authFilter = authFilter;
}

@Override
public void configure(@NonNullByDefault({}) WebSocketServletFactory webSocketServletFactory) {
webSocketServletFactory.getPolicy().setIdleTimeout(10000);
webSocketServletFactory.setCreator(new CommonWebSocketCreator());
}

@Reference(cardinality = ReferenceCardinality.MULTIPLE, policy = ReferencePolicy.DYNAMIC)
protected void addWebSocketAdapter(WebSocketAdapter wsAdapter) {
this.connectionHandlers.put(wsAdapter.getId(), wsAdapter);
}

protected void removeWebSocketAdapter(WebSocketAdapter wsAdapter) {
this.connectionHandlers.remove(wsAdapter.getId());
}

private class CommonWebSocketCreator implements WebSocketCreator {
private final Logger logger = LoggerFactory.getLogger(CommonWebSocketCreator.class);

@Override
public @Nullable Object createWebSocket(@Nullable ServletUpgradeRequest servletUpgradeRequest,
@Nullable ServletUpgradeResponse servletUpgradeResponse) {
if (servletUpgradeRequest == null || servletUpgradeResponse == null) {
return null;
}
if (isAuthorizedRequest(servletUpgradeRequest)) {
String requestPath = servletUpgradeRequest.getRequestURI().getPath();
String pathPrefix = SERVLET_PATH + "/";
boolean useDefaultAdapter = requestPath.equals(pathPrefix) || !requestPath.startsWith(pathPrefix);
WebSocketAdapter wsAdapter;
if (!useDefaultAdapter) {
String adapterId = requestPath.substring(pathPrefix.length());
wsAdapter = connectionHandlers.get(adapterId);
if (wsAdapter == null) {
logger.warn("Missing WebSocket adapter for path {}", adapterId);
return null;
}
} else {
wsAdapter = connectionHandlers.get(DEFAULT_ADAPTER_ID);
if (wsAdapter == null) {
logger.warn("Default WebSocket adapter is missing");
return null;
}
}
logger.debug("New connection handled by {}", wsAdapter.getId());
return wsAdapter.createWebSocket(servletUpgradeRequest, servletUpgradeResponse);
} else {
logger.warn("Unauthenticated request to create a websocket from {}.",
servletUpgradeRequest.getRemoteAddress());
}
return null;
}

private boolean isAuthorizedRequest(ServletUpgradeRequest servletUpgradeRequest) {
try {
var securityContext = authFilter.getSecurityContext(servletUpgradeRequest.getHttpServletRequest(),
true);
return securityContext != null
&& (securityContext.isUserInRole(Role.USER) || securityContext.isUserInRole(Role.ADMIN));
} catch (AuthenticationException | IOException e) {
logger.warn("Error handling WebSocket authorization", e);
return false;
}
}
}
}
Loading

0 comments on commit e3396c9

Please sign in to comment.