From 0d745ee1207df4990ab92082ec33c75f5777757e Mon Sep 17 00:00:00 2001 From: Parag Jain Date: Thu, 28 Apr 2016 18:50:28 -0500 Subject: [PATCH] Basic authorization support in Druid (#2424) - Introduce `AuthorizationInfo` interface, specific implementations of which would be provided by extensions - If the `druid.auth.enabled` is set to `true` then the `isAuthorized` method of `AuthorizationInfo` will be called to perform authorization checks - `AuthorizationInfo` object will be created in the servlet filters of specific extension and will be passed as a request attribute with attribute name as `AuthConfig.DRUID_AUTH_TOKEN` - As per the scope of this PR, all resources that needs to be secured are divided into 3 types - `DATASOURCE`, `CONFIG` and `STATE`. For any type of resource, possible actions are - `READ` or `WRITE` - Specific ResourceFilters are used to perform auth checks for all endpoints that corresponds to a specific resource type. This prevents duplication of logic and need to inject HttpServletRequest inside each endpoint. For example - `DatasourceResourceFilter` is used for endpoints where the datasource information is present after "datasources" segment in the request Path such as `/druid/coordinator/v1/datasources/`, `/druid/coordinator/v1/metadata/datasources/`, `/druid/v2/datasources/` - `RulesResourceFilter` is used where the datasource information is present after "rules" segment in the request Path such as `/druid/coordinator/v1/rules/` - `TaskResourceFilter` is used for endpoints is used where the datasource information is present after "task" segment in the request Path such as `druid/indexer/v1/task` - `ConfigResourceFilter` is used for endpoints like `/druid/coordinator/v1/config`, `/druid/indexer/v1/worker`, `/druid/worker/v1` etc - `StateResourceFilter` is used for endpoints like `/druid/broker/v1/loadstatus`, `/druid/coordinator/v1/leader`, `/druid/coordinator/v1/loadqueue`, `/druid/coordinator/v1/rules` etc - For endpoints where a list of resources is returned like `/druid/coordinator/v1/datasources`, `/druid/indexer/v1/completeTasks` etc. the list is filtered to return only the resources to which the requested user has access. In these cases, `HttpServletRequest` instance needs to be injected in the endpoint method. Note - JAX-RS specification provides an interface called `SecurityContext`. However, we did not use this but provided our own interface `AuthorizationInfo` mainly because it provides more flexibility. For example, `SecurityContext` has a method called `isUserInRole(String role)` which would be used for auth checks and if used then the mapping of what roles can access what resource needs to be modeled inside Druid either using some convention or some other means which is not very flexible as Druid has dynamic resources like datasources. Fixes #2355 with PR #2424 --- .../overlord/http/OverlordResource.java | 195 +++++++- .../http/security/TaskResourceFilter.java | 123 +++++ .../indexing/worker/http/WorkerResource.java | 9 + .../overlord/http/OverlordResourceTest.java | 465 ++++++------------ .../indexing/overlord/http/OverlordTest.java | 413 ++++++++++++++++ .../security/SecurityResourceFilterTest.java | 146 ++++++ .../druid/guice/security/DruidAuthModule.java | 44 ++ .../druid/initialization/Initialization.java | 3 + .../EventReceiverFirehoseFactory.java | 2 +- .../io/druid/server/ClientInfoResource.java | 52 +- .../java/io/druid/server/QueryManager.java | 22 +- .../java/io/druid/server/QueryResource.java | 58 ++- .../java/io/druid/server/StatusResource.java | 3 + .../io/druid/server/http/BrokerResource.java | 3 + .../CoordinatorDynamicConfigsResource.java | 7 +- .../server/http/CoordinatorResource.java | 3 + .../server/http/DatasourcesResource.java | 34 +- .../druid/server/http/HistoricalResource.java | 3 + .../druid/server/http/IntervalsResource.java | 53 +- .../druid/server/http/InventoryViewUtils.java | 48 +- .../druid/server/http/MetadataResource.java | 103 +++- .../io/druid/server/http/RulesResource.java | 11 +- .../io/druid/server/http/ServersResource.java | 3 + .../io/druid/server/http/TiersResource.java | 3 + .../http/security/AbstractResourceFilter.java | 89 ++++ .../http/security/ConfigResourceFilter.java | 85 ++++ .../security/DatasourceResourceFilter.java | 110 +++++ .../http/security/RulesResourceFilter.java | 106 ++++ .../http/security/StateResourceFilter.java | 97 ++++ .../metrics/EventReceiverFirehoseMonitor.java | 2 - .../java/io/druid/server/security/Access.java | 51 ++ .../java/io/druid/server/security/Action.java | 26 + .../io/druid/server/security/AuthConfig.java | 85 ++++ .../server/security/AuthorizationInfo.java | 44 ++ .../io/druid/server/security/Resource.java | 69 +++ .../druid/server/security/ResourceType.java | 27 + .../druid/server/ClientInfoResourceTest.java | 3 +- .../io/druid/server/QueryResourceTest.java | 288 ++++++++++- .../server/http/DatasourcesResourceTest.java | 89 +++- .../server/http/IntervalsResourceTest.java | 30 +- .../druid/server/http/RulesResourceTest.java | 4 +- .../security/ResourceFilterTestHelper.java | 245 +++++++++ .../security/SecurityResourceFilterTest.java | 134 +++++ 43 files changed, 2980 insertions(+), 410 deletions(-) create mode 100644 indexing-service/src/main/java/io/druid/indexing/overlord/http/security/TaskResourceFilter.java create mode 100644 indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordTest.java create mode 100644 indexing-service/src/test/java/io/druid/indexing/overlord/http/security/SecurityResourceFilterTest.java create mode 100644 server/src/main/java/io/druid/guice/security/DruidAuthModule.java create mode 100644 server/src/main/java/io/druid/server/http/security/AbstractResourceFilter.java create mode 100644 server/src/main/java/io/druid/server/http/security/ConfigResourceFilter.java create mode 100644 server/src/main/java/io/druid/server/http/security/DatasourceResourceFilter.java create mode 100644 server/src/main/java/io/druid/server/http/security/RulesResourceFilter.java create mode 100644 server/src/main/java/io/druid/server/http/security/StateResourceFilter.java create mode 100644 server/src/main/java/io/druid/server/security/Access.java create mode 100644 server/src/main/java/io/druid/server/security/Action.java create mode 100644 server/src/main/java/io/druid/server/security/AuthConfig.java create mode 100644 server/src/main/java/io/druid/server/security/AuthorizationInfo.java create mode 100644 server/src/main/java/io/druid/server/security/Resource.java create mode 100644 server/src/main/java/io/druid/server/security/ResourceType.java create mode 100644 server/src/test/java/io/druid/server/http/security/ResourceFilterTestHelper.java create mode 100644 server/src/test/java/io/druid/server/http/security/SecurityResourceFilterTest.java diff --git a/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java b/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java index 706036e5e6f2..4ef7d5246db5 100644 --- a/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java +++ b/indexing-service/src/main/java/io/druid/indexing/overlord/http/OverlordResource.java @@ -22,6 +22,10 @@ import com.fasterxml.jackson.annotation.JsonValue; import com.google.common.base.Function; import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; @@ -30,7 +34,9 @@ import com.google.common.io.ByteSource; import com.google.common.util.concurrent.SettableFuture; import com.google.inject.Inject; +import com.metamx.common.Pair; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.common.config.JacksonConfigManager; @@ -46,8 +52,17 @@ import io.druid.indexing.overlord.TaskStorageQueryAdapter; import io.druid.indexing.overlord.WorkerTaskRunner; import io.druid.indexing.overlord.autoscaling.ScalingStats; +import io.druid.indexing.overlord.http.security.TaskResourceFilter; import io.druid.indexing.overlord.setup.WorkerBehaviorConfig; import io.druid.metadata.EntryExistsException; +import io.druid.server.http.security.ConfigResourceFilter; +import io.druid.server.http.security.StateResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import io.druid.tasklogs.TaskLogStreamer; import io.druid.timeline.DataSegment; import org.joda.time.DateTime; @@ -63,11 +78,13 @@ import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.WebApplicationException; import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.io.IOException; import java.util.Collection; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -85,6 +102,7 @@ public class OverlordResource private final TaskLogStreamer taskLogStreamer; private final JacksonConfigManager configManager; private final AuditManager auditManager; + private final AuthConfig authConfig; private AtomicReference workerConfigRef = null; @@ -94,7 +112,8 @@ public OverlordResource( TaskStorageQueryAdapter taskStorageQueryAdapter, TaskLogStreamer taskLogStreamer, JacksonConfigManager configManager, - AuditManager auditManager + AuditManager auditManager, + AuthConfig authConfig ) throws Exception { this.taskMaster = taskMaster; @@ -102,14 +121,35 @@ public OverlordResource( this.taskLogStreamer = taskLogStreamer; this.configManager = configManager; this.auditManager = auditManager; + this.authConfig = authConfig; } @POST @Path("/task") @Consumes(MediaType.APPLICATION_JSON) @Produces(MediaType.APPLICATION_JSON) - public Response taskPost(final Task task) + public Response taskPost( + final Task task, + @Context final HttpServletRequest req + ) { + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String dataSource = task.getDataSource(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSource, ResourceType.DATASOURCE), + Action.WRITE + ); + if (!authResult.isAllowed()) { + return Response.status(Response.Status.FORBIDDEN).header("Access-Check-Result", authResult).build(); + } + } + return asLeaderWith( taskMaster.getTaskQueue(), new Function() @@ -133,6 +173,7 @@ public Response apply(TaskQueue taskQueue) @GET @Path("/leader") + @ResourceFilters(StateResourceFilter.class) @Produces(MediaType.APPLICATION_JSON) public Response getLeader() { @@ -142,6 +183,7 @@ public Response getLeader() @GET @Path("/task/{taskid}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response getTaskPayload(@PathParam("taskid") String taskid) { return optionalTaskResponse(taskid, "payload", taskStorageQueryAdapter.getTask(taskid)); @@ -150,6 +192,7 @@ public Response getTaskPayload(@PathParam("taskid") String taskid) @GET @Path("/task/{taskid}/status") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response getTaskStatus(@PathParam("taskid") String taskid) { return optionalTaskResponse(taskid, "status", taskStorageQueryAdapter.getStatus(taskid)); @@ -158,6 +201,7 @@ public Response getTaskStatus(@PathParam("taskid") String taskid) @GET @Path("/task/{taskid}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response getTaskSegments(@PathParam("taskid") String taskid) { final Set segments = taskStorageQueryAdapter.getInsertedSegments(taskid); @@ -167,6 +211,7 @@ public Response getTaskSegments(@PathParam("taskid") String taskid) @POST @Path("/task/{taskid}/shutdown") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(TaskResourceFilter.class) public Response doShutdown(@PathParam("taskid") final String taskid) { return asLeaderWith( @@ -186,6 +231,7 @@ public Response apply(TaskQueue taskQueue) @GET @Path("/worker") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response getWorkerConfig() { if (workerConfigRef == null) { @@ -199,11 +245,12 @@ public Response getWorkerConfig() @POST @Path("/worker") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response setWorkerConfig( final WorkerBehaviorConfig workerBehaviorConfig, @HeaderParam(AuditManager.X_DRUID_AUTHOR) @DefaultValue("") final String author, @HeaderParam(AuditManager.X_DRUID_COMMENT) @DefaultValue("") final String comment, - @Context HttpServletRequest req + @Context final HttpServletRequest req ) { if (!configManager.set( @@ -222,6 +269,7 @@ public Response setWorkerConfig( @GET @Path("/worker/history") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response getWorkerConfigHistory( @QueryParam("interval") final String interval, @QueryParam("count") final Integer count @@ -258,6 +306,7 @@ public Response getWorkerConfigHistory( @POST @Path("/action") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response doAction(final TaskActionHolder holder) { return asLeaderWith( @@ -292,7 +341,7 @@ public Response apply(TaskActionClient taskActionClient) @GET @Path("/waitingTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getWaitingTasks() + public Response getWaitingTasks(@Context final HttpServletRequest req) { return workItemsResponse( new Function>() @@ -302,7 +351,38 @@ public Collection apply(TaskRunner taskRunner) { // A bit roundabout, but works as a way of figuring out what tasks haven't been handed // off to the runner yet: - final List activeTasks = taskStorageQueryAdapter.getActiveTasks(); + final List allActiveTasks = taskStorageQueryAdapter.getActiveTasks(); + final List activeTasks; + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + activeTasks = ImmutableList.copyOf( + Iterables.filter( + allActiveTasks, + new Predicate() + { + @Override + public boolean apply(Task input) + { + Resource resource = new Resource(input.getDataSource(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ) + ); + } else { + activeTasks = allActiveTasks; + } final Set runnersKnownTasks = Sets.newHashSet( Iterables.transform( taskRunner.getKnownTasks(), @@ -346,7 +426,7 @@ public TaskLocation getLocation() @GET @Path("/pendingTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getPendingTasks() + public Response getPendingTasks(@Context final HttpServletRequest req) { return workItemsResponse( new Function>() @@ -354,7 +434,13 @@ public Response getPendingTasks() @Override public Collection apply(TaskRunner taskRunner) { - return taskRunner.getPendingTasks(); + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + return securedTaskRunnerWorkItem(taskRunner.getPendingTasks(), req); + } else { + return taskRunner.getPendingTasks(); + } + } } ); @@ -363,7 +449,7 @@ public Collection apply(TaskRunner taskRunner) @GET @Path("/runningTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getRunningTasks() + public Response getRunningTasks(@Context final HttpServletRequest req) { return workItemsResponse( new Function>() @@ -371,7 +457,12 @@ public Response getRunningTasks() @Override public Collection apply(TaskRunner taskRunner) { - return taskRunner.getRunningTasks(); + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + return securedTaskRunnerWorkItem(taskRunner.getRunningTasks(), req); + } else { + return taskRunner.getRunningTasks(); + } } } ); @@ -380,10 +471,50 @@ public Collection apply(TaskRunner taskRunner) @GET @Path("/completeTasks") @Produces(MediaType.APPLICATION_JSON) - public Response getCompleteTasks() + public Response getCompleteTasks(@Context final HttpServletRequest req) { + final List recentlyFinishedTasks; + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + recentlyFinishedTasks = ImmutableList.copyOf( + Iterables.filter( + taskStorageQueryAdapter.getRecentlyFinishedTaskStatuses(), + new Predicate() + { + @Override + public boolean apply(TaskStatus input) + { + final String taskId = input.getId(); + final Optional optionalTask = taskStorageQueryAdapter.getTask(taskId); + if (!optionalTask.isPresent()) { + throw new WebApplicationException( + Response.serverError().entity( + String.format("No task information found for task with id: [%s]", taskId) + ).build() + ); + } + Resource resource = new Resource(optionalTask.get().getDataSource(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ) + ); + } else { + recentlyFinishedTasks = taskStorageQueryAdapter.getRecentlyFinishedTaskStatuses(); + } + final List completeTasks = Lists.transform( - taskStorageQueryAdapter.getRecentlyFinishedTaskStatuses(), + recentlyFinishedTasks, new Function() { @Override @@ -406,6 +537,7 @@ public TaskResponseObject apply(TaskStatus taskStatus) @GET @Path("/workers") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getWorkers() { return asLeaderWith( @@ -435,6 +567,7 @@ public Response apply(TaskRunner taskRunner) @GET @Path("/scaling") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getScalingState() { // Don't use asLeaderWith, since we want to return 200 instead of 503 when missing an autoscaler. @@ -449,6 +582,7 @@ public Response getScalingState() @GET @Path("/task/{taskid}/log") @Produces("text/plain") + @ResourceFilters(TaskResourceFilter.class) public Response doGetLog( @PathParam("taskid") final String taskid, @QueryParam("offset") @DefaultValue("0") final long offset @@ -528,6 +662,45 @@ private Response asLeaderWith(Optional x, Function f) } } + private Collection securedTaskRunnerWorkItem( + Collection collectionToFilter, + HttpServletRequest req + ) + { + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + return Collections2.filter( + collectionToFilter, + new Predicate() + { + @Override + public boolean apply(TaskRunnerWorkItem input) + { + final String taskId = input.getTaskId(); + final Optional optionalTask = taskStorageQueryAdapter.getTask(taskId); + if (!optionalTask.isPresent()) { + throw new WebApplicationException( + Response.serverError().entity( + String.format("No task information found for task with id: [%s]", taskId) + ).build() + ); + } + Resource resource = new Resource(optionalTask.get().getDataSource(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ); + } + static class TaskResponseObject { private final String id; diff --git a/indexing-service/src/main/java/io/druid/indexing/overlord/http/security/TaskResourceFilter.java b/indexing-service/src/main/java/io/druid/indexing/overlord/http/security/TaskResourceFilter.java new file mode 100644 index 000000000000..0866658c08a7 --- /dev/null +++ b/indexing-service/src/main/java/io/druid/indexing/overlord/http/security/TaskResourceFilter.java @@ -0,0 +1,123 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.indexing.overlord.http.security; + +import com.google.common.base.Optional; +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.indexing.common.task.Task; +import io.druid.indexing.overlord.TaskStorageQueryAdapter; +import io.druid.server.http.security.AbstractResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.PathSegment; +import javax.ws.rs.core.Response; +import java.util.List; + +/** + * Use this ResourceFilter when the datasource information is present after "task" segment in the request Path + * Here are some example paths where this filter is used - + * - druid/indexer/v1/task/{taskid}/... + * Note - DO NOT use this filter at MiddleManager resources as TaskStorageQueryAdapter cannot be injected there + */ +public class TaskResourceFilter extends AbstractResourceFilter +{ + private final TaskStorageQueryAdapter taskStorageQueryAdapter; + + @Inject + public TaskResourceFilter(TaskStorageQueryAdapter taskStorageQueryAdapter, AuthConfig authConfig) { + super(authConfig); + this.taskStorageQueryAdapter = taskStorageQueryAdapter; + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String taskId = Preconditions.checkNotNull( + request.getPathSegments() + .get( + Iterables.indexOf( + request.getPathSegments(), + new Predicate() + { + @Override + public boolean apply(PathSegment input) + { + return input.getPath().equals("task"); + } + } + ) + 1 + ).getPath() + ); + + Optional taskOptional = taskStorageQueryAdapter.getTask(taskId); + if (!taskOptional.isPresent()) { + throw new WebApplicationException( + Response.status(Response.Status.BAD_REQUEST) + .entity(String.format("Cannot find any task with id: [%s]", taskId)) + .build() + ); + } + final String dataSourceName = Preconditions.checkNotNull(taskOptional.get().getDataSource()); + + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + final Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSourceName, ResourceType.DATASOURCE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException(Response.status(Response.Status.FORBIDDEN) + .entity( + String.format("Access-Check-Result: %s", authResult.toString()) + ) + .build()); + } + } + + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + List applicablePaths = ImmutableList.of("druid/indexer/v1/task/"); + for (String path : applicablePaths) { + if(requestPath.startsWith(path) && !requestPath.equals(path)) { + return true; + } + } + return false; + } +} diff --git a/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java b/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java index 9bb3bdc44b67..49641462e912 100644 --- a/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java +++ b/indexing-service/src/main/java/io/druid/indexing/worker/http/WorkerResource.java @@ -27,10 +27,13 @@ import com.google.common.io.ByteSource; import com.google.inject.Inject; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.indexing.overlord.TaskRunner; import io.druid.indexing.overlord.TaskRunnerWorkItem; import io.druid.indexing.worker.Worker; import io.druid.indexing.worker.WorkerCuratorCoordinator; +import io.druid.server.http.security.ConfigResourceFilter; +import io.druid.server.http.security.StateResourceFilter; import io.druid.tasklogs.TaskLogStreamer; import javax.ws.rs.DefaultValue; @@ -73,6 +76,7 @@ public WorkerResource( @POST @Path("/disable") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response doDisable() { try { @@ -93,6 +97,7 @@ public Response doDisable() @POST @Path("/enable") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(ConfigResourceFilter.class) public Response doEnable() { try { @@ -107,6 +112,7 @@ public Response doEnable() @GET @Path("/enabled") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response isEnabled() { try { @@ -122,6 +128,7 @@ public Response isEnabled() @GET @Path("/tasks") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getTasks() { try { @@ -149,6 +156,7 @@ public String apply(TaskRunnerWorkItem input) @POST @Path("/task/{taskid}/shutdown") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response doShutdown(@PathParam("taskid") String taskid) { try { @@ -164,6 +172,7 @@ public Response doShutdown(@PathParam("taskid") String taskid) @GET @Path("/task/{taskid}/log") @Produces("text/plain") + @ResourceFilters(StateResourceFilter.class) public Response doGetLog( @PathParam("taskid") String taskid, @QueryParam("offset") @DefaultValue("0") long offset diff --git a/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java index 5ef4fd3c8c03..173bd905c37a 100644 --- a/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java +++ b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordResourceTest.java @@ -22,379 +22,226 @@ import com.google.common.base.Function; import com.google.common.base.Optional; import com.google.common.collect.ImmutableList; -import com.google.common.collect.ImmutableMap; import com.google.common.collect.Lists; import com.google.common.util.concurrent.ListenableFuture; -import com.google.common.util.concurrent.MoreExecutors; -import com.metamx.common.Pair; -import com.metamx.common.guava.CloseQuietly; -import com.metamx.emitter.EmittingLogger; -import com.metamx.emitter.service.ServiceEmitter; -import io.druid.concurrent.Execs; -import io.druid.curator.PotentiallyGzippedCompressionProvider; -import io.druid.curator.discovery.NoopServiceAnnouncer; import io.druid.indexing.common.TaskLocation; import io.druid.indexing.common.TaskStatus; -import io.druid.indexing.common.actions.TaskActionClientFactory; -import io.druid.indexing.common.config.TaskStorageConfig; +import io.druid.indexing.common.TaskToolbox; +import io.druid.indexing.common.actions.TaskActionClient; +import io.druid.indexing.common.task.AbstractTask; import io.druid.indexing.common.task.NoopTask; import io.druid.indexing.common.task.Task; -import io.druid.indexing.overlord.HeapMemoryTaskStorage; -import io.druid.indexing.overlord.TaskLockbox; import io.druid.indexing.overlord.TaskMaster; import io.druid.indexing.overlord.TaskRunner; -import io.druid.indexing.overlord.TaskRunnerFactory; -import io.druid.indexing.overlord.TaskRunnerListener; import io.druid.indexing.overlord.TaskRunnerWorkItem; -import io.druid.indexing.overlord.TaskStorage; import io.druid.indexing.overlord.TaskStorageQueryAdapter; -import io.druid.indexing.overlord.autoscaling.ScalingStats; -import io.druid.indexing.overlord.config.TaskQueueConfig; -import io.druid.server.DruidNode; -import io.druid.server.initialization.IndexerZkConfig; -import io.druid.server.initialization.ZkPathsConfig; -import io.druid.server.metrics.NoopServiceEmitter; -import org.apache.curator.framework.CuratorFramework; -import org.apache.curator.framework.CuratorFrameworkFactory; -import org.apache.curator.retry.RetryOneTime; -import org.apache.curator.test.TestingServer; -import org.apache.curator.test.Timing; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; import org.easymock.EasyMock; -import org.joda.time.Period; import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; -import javax.annotation.Nullable; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; -import java.util.ArrayList; import java.util.Collection; import java.util.List; -import java.util.Map; -import java.util.concurrent.Callable; -import java.util.concurrent.ConcurrentHashMap; -import java.util.concurrent.CountDownLatch; -import java.util.concurrent.Executor; -import java.util.concurrent.atomic.AtomicBoolean; public class OverlordResourceTest { - private static final TaskLocation TASK_LOCATION = new TaskLocation("dummy", 1000); - - private TestingServer server; - private Timing timing; - private CuratorFramework curator; - private TaskMaster taskMaster; - private TaskLockbox taskLockbox; - private TaskStorage taskStorage; - private TaskActionClientFactory taskActionClientFactory; - private CountDownLatch announcementLatch; - private DruidNode druidNode; private OverlordResource overlordResource; - private CountDownLatch[] taskCompletionCountDownLatches; - private CountDownLatch[] runTaskCountDownLatches; - - private void setupServerAndCurator() throws Exception - { - server = new TestingServer(); - timing = new Timing(); - curator = CuratorFrameworkFactory - .builder() - .connectString(server.getConnectString()) - .sessionTimeoutMs(timing.session()) - .connectionTimeoutMs(timing.connection()) - .retryPolicy(new RetryOneTime(1)) - .compressionProvider(new PotentiallyGzippedCompressionProvider(true)) - .build(); - } - - private void tearDownServerAndCurator() - { - CloseQuietly.close(curator); - CloseQuietly.close(server); - } + private TaskMaster taskMaster; + private TaskStorageQueryAdapter tsqa; + private HttpServletRequest req; + private TaskRunner taskRunner; @Before public void setUp() throws Exception { - taskLockbox = EasyMock.createStrictMock(TaskLockbox.class); - taskLockbox.syncFromStorage(); - EasyMock.expectLastCall().atLeastOnce(); - taskLockbox.add(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); - taskLockbox.remove(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); - - // for second Noop Task directly added to deep storage. - taskLockbox.add(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); - taskLockbox.remove(EasyMock.anyObject()); - EasyMock.expectLastCall().atLeastOnce(); - - taskActionClientFactory = EasyMock.createStrictMock(TaskActionClientFactory.class); - EasyMock.expect(taskActionClientFactory.create(EasyMock.anyObject())) - .andReturn(null).anyTimes(); - EasyMock.replay(taskLockbox, taskActionClientFactory); + taskRunner = EasyMock.createMock(TaskRunner.class); + taskMaster = EasyMock.createStrictMock(TaskMaster.class); + tsqa = EasyMock.createStrictMock(TaskStorageQueryAdapter.class); + req = EasyMock.createStrictMock(HttpServletRequest.class); + + EasyMock.expect(taskMaster.getTaskRunner()).andReturn( + Optional.of(taskRunner) + ).anyTimes(); + + overlordResource = new OverlordResource( + taskMaster, + tsqa, + null, + null, + null, + new AuthConfig(true) + ); - taskStorage = new HeapMemoryTaskStorage(new TaskStorageConfig(null)); - runTaskCountDownLatches = new CountDownLatch[2]; - runTaskCountDownLatches[0] = new CountDownLatch(1); - runTaskCountDownLatches[1] = new CountDownLatch(1); - taskCompletionCountDownLatches = new CountDownLatch[2]; - taskCompletionCountDownLatches[0] = new CountDownLatch(1); - taskCompletionCountDownLatches[1] = new CountDownLatch(1); - announcementLatch = new CountDownLatch(1); - IndexerZkConfig indexerZkConfig = new IndexerZkConfig(new ZkPathsConfig(), null, null, null, null, null); - setupServerAndCurator(); - curator.start(); - curator.blockUntilConnected(); - curator.create().creatingParentsIfNeeded().forPath(indexerZkConfig.getLeaderLatchPath()); - druidNode = new DruidNode("hey", "what", 1234); - ServiceEmitter serviceEmitter = new NoopServiceEmitter(); - taskMaster = new TaskMaster( - new TaskQueueConfig(null, new Period(1), null, new Period(10)), - taskLockbox, - taskStorage, - taskActionClientFactory, - druidNode, - indexerZkConfig, - new TaskRunnerFactory() - { - @Override - public MockTaskRunner build() - { - return new MockTaskRunner(runTaskCountDownLatches, taskCompletionCountDownLatches); - } - }, - curator, - new NoopServiceAnnouncer() + EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN)).andReturn( + new AuthorizationInfo() { @Override - public void announce(DruidNode node) + public Access isAuthorized( + Resource resource, Action action + ) { - announcementLatch.countDown(); + if (resource.getName().equals("allow")) { + return new Access(true); + } else { + return new Access(false); + } } - }, - serviceEmitter + } ); - EmittingLogger.registerEmitter(serviceEmitter); } - @Test(timeout = 2000L) - public void testOverlordResource() throws Exception + @Test + public void testSecuredGetWaitingTask() throws Exception { - // basic task master lifecycle test - taskMaster.start(); - announcementLatch.await(); - while (!taskMaster.isLeading()) { - // I believe the control will never reach here and thread will never sleep but just to be on safe side - Thread.sleep(10); - } - Assert.assertEquals(taskMaster.getLeader(), druidNode.getHostAndPort()); - // Test Overlord resource stuff - overlordResource = new OverlordResource(taskMaster, new TaskStorageQueryAdapter(taskStorage), null, null, null); - Response response = overlordResource.getLeader(); - Assert.assertEquals(druidNode.getHostAndPort(), response.getEntity()); - - final String taskId_0 = "0"; - NoopTask task_0 = new NoopTask(taskId_0, 0, 0, null, null, null); - response = overlordResource.taskPost(task_0); - Assert.assertEquals(200, response.getStatus()); - Assert.assertEquals(ImmutableMap.of("task", taskId_0), response.getEntity()); + EasyMock.expect(tsqa.getActiveTasks()).andReturn( + ImmutableList.of( + getTaskWithIdAndDatasource("id_1", "allow"), + getTaskWithIdAndDatasource("id_2", "allow"), + getTaskWithIdAndDatasource("id_3", "deny"), + getTaskWithIdAndDatasource("id_4", "deny") + ) + ).once(); + + EasyMock.>expect(taskRunner.getKnownTasks()).andReturn( + ImmutableList.of( + new MockTaskRunnerWorkItem("id_1", null), + new MockTaskRunnerWorkItem("id_4", null) + ) + ); - // Duplicate task - should fail - response = overlordResource.taskPost(task_0); - Assert.assertEquals(400, response.getStatus()); + EasyMock.replay(taskRunner, taskMaster, tsqa, req); - // Task payload for task_0 should be present in taskStorage - response = overlordResource.getTaskPayload(taskId_0); - Assert.assertEquals(task_0, ((Map) response.getEntity()).get("payload")); + List responseObjects = (List) overlordResource.getWaitingTasks(req) + .getEntity(); + Assert.assertEquals(1, responseObjects.size()); + Assert.assertEquals("id_2", responseObjects.get(0).toJson().get("id")); + } - // Task not present in taskStorage - should fail - response = overlordResource.getTaskPayload("whatever"); - Assert.assertEquals(404, response.getStatus()); + @Test + public void testSecuredGetCompleteTasks() + { + List tasksIds = ImmutableList.of("id_1", "id_2", "id_3"); + EasyMock.expect(tsqa.getRecentlyFinishedTaskStatuses()).andReturn( + Lists.transform( + tasksIds, + new Function() + { + @Override + public TaskStatus apply(String input) + { + return TaskStatus.success(input); + } + } + ) + ).once(); + + EasyMock.expect(tsqa.getTask(tasksIds.get(0))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(0), "deny")) + ).once(); + EasyMock.expect(tsqa.getTask(tasksIds.get(1))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(1), "allow")) + ).once(); + EasyMock.expect(tsqa.getTask(tasksIds.get(2))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(2), "allow")) + ).once(); + EasyMock.replay(taskRunner, taskMaster, tsqa, req); + + List responseObjects = (List) overlordResource.getCompleteTasks(req) + .getEntity(); + + Assert.assertEquals(2, responseObjects.size()); + Assert.assertEquals(tasksIds.get(1), responseObjects.get(0).toJson().get("id")); + Assert.assertEquals(tasksIds.get(2), responseObjects.get(1).toJson().get("id")); + } - // Task status of the submitted task should be running - response = overlordResource.getTaskStatus(taskId_0); - Assert.assertEquals(taskId_0, ((Map) response.getEntity()).get("task")); - Assert.assertEquals( - TaskStatus.running(taskId_0).getStatusCode(), - ((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode() + @Test + public void testSecuredGetRunningTasks() + { + List tasksIds = ImmutableList.of("id_1", "id_2"); + EasyMock.>expect(taskRunner.getRunningTasks()).andReturn( + ImmutableList.of( + new MockTaskRunnerWorkItem(tasksIds.get(0), null), + new MockTaskRunnerWorkItem(tasksIds.get(1), null) + ) ); + EasyMock.expect(tsqa.getTask(tasksIds.get(0))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(0), "deny")) + ).once(); + EasyMock.expect(tsqa.getTask(tasksIds.get(1))).andReturn( + Optional.of(getTaskWithIdAndDatasource(tasksIds.get(1), "allow")) + ).once(); - // Simulate completion of task_0 - taskCompletionCountDownLatches[Integer.parseInt(taskId_0)].countDown(); - // Wait for taskQueue to handle success status of task_0 - waitForTaskStatus(taskId_0, TaskStatus.Status.SUCCESS); - - // Manually insert task in taskStorage - // Verifies sync from storage - final String taskId_1 = "1"; - NoopTask task_1 = new NoopTask(taskId_1, 0, 0, null, null, null); - taskStorage.insert(task_1, TaskStatus.running(taskId_1)); - // Wait for task runner to run task_1 - runTaskCountDownLatches[Integer.parseInt(taskId_1)].await(); + EasyMock.replay(taskRunner, taskMaster, tsqa, req); - response = overlordResource.getRunningTasks(); - // 1 task that was manually inserted should be in running state - Assert.assertEquals(1, (((List) response.getEntity()).size())); - final OverlordResource.TaskResponseObject taskResponseObject = ((List) response - .getEntity()).get(0); - Assert.assertEquals(taskId_1, taskResponseObject.toJson().get("id")); - Assert.assertEquals(TASK_LOCATION, taskResponseObject.toJson().get("location")); + List responseObjects = (List) overlordResource.getRunningTasks(req) + .getEntity(); - // Simulate completion of task_1 - taskCompletionCountDownLatches[Integer.parseInt(taskId_1)].countDown(); - // Wait for taskQueue to handle success status of task_1 - waitForTaskStatus(taskId_1, TaskStatus.Status.SUCCESS); - - // should return number of tasks which are not in running state - response = overlordResource.getCompleteTasks(); - Assert.assertEquals(2, (((List) response.getEntity()).size())); - taskMaster.stop(); - Assert.assertFalse(taskMaster.isLeading()); - EasyMock.verify(taskLockbox, taskActionClientFactory); + Assert.assertEquals(1, responseObjects.size()); + Assert.assertEquals(tasksIds.get(1), responseObjects.get(0).toJson().get("id")); } - /* Wait until the task with given taskId has the given Task Status - * These method will not timeout until the condition is met so calling method should ensure timeout - * This method also assumes that the task with given taskId is present - * */ - private void waitForTaskStatus(String taskId, TaskStatus.Status status) throws InterruptedException + @Test + public void testSecuredTaskPost() { - while (true) { - Response response = overlordResource.getTaskStatus(taskId); - if (status.equals(((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode())) { - break; - } - Thread.sleep(10); - } + EasyMock.replay(taskRunner, taskMaster, tsqa, req); + Task task = NoopTask.create(); + Response response = overlordResource.taskPost(task, req); + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); } @After - public void tearDown() throws Exception + public void tearDown() { - tearDownServerAndCurator(); + EasyMock.verify(taskRunner, taskMaster, tsqa, req); } - public static class MockTaskRunner implements TaskRunner + private Task getTaskWithIdAndDatasource(String id, String datasource) { - private CountDownLatch[] completionLatches; - private CountDownLatch[] runLatches; - private ConcurrentHashMap taskRunnerWorkItems; - private List runningTasks; - private final AtomicBoolean started = new AtomicBoolean(false); - - public MockTaskRunner(CountDownLatch[] runLatches, CountDownLatch[] completionLatches) - { - this.runLatches = runLatches; - this.completionLatches = completionLatches; - this.taskRunnerWorkItems = new ConcurrentHashMap<>(); - this.runningTasks = new ArrayList<>(); - } - - @Override - public List>> restore() + return new AbstractTask(id, datasource, null) { - return ImmutableList.of(); - } - - public void registerListener(TaskRunnerListener listener, Executor executor) - { - // Overlord doesn't call this method - throw new UnsupportedOperationException(); - } - - @Override - public synchronized ListenableFuture run(final Task task) - { - final String taskId = task.getId(); - ListenableFuture future = MoreExecutors.listeningDecorator( - Execs.singleThreaded( - "noop_test_task_exec_%s" - ) - ).submit( - new Callable() - { - @Override - public TaskStatus call() throws Exception - { - // adding of task to list of runningTasks should be done before count down as - // getRunningTasks may not include the task for which latch has been counted down - // Count down to let know that task is actually running - // this is equivalent of getting process holder to run task in ForkingTaskRunner - runningTasks.add(taskId); - runLatches[Integer.parseInt(taskId)].countDown(); - // Wait for completion count down - completionLatches[Integer.parseInt(taskId)].await(); - taskRunnerWorkItems.remove(taskId); - runningTasks.remove(taskId); - return TaskStatus.success(taskId); - } - } - ); - TaskRunnerWorkItem taskRunnerWorkItem = new TaskRunnerWorkItem(taskId, future) + @Override + public String getType() { - @Override - public TaskLocation getLocation() - { - return TASK_LOCATION; - } - }; - taskRunnerWorkItems.put(taskId, taskRunnerWorkItem); - return future; - } - - @Override - public void shutdown(String taskid) {} - - @Override - public synchronized Collection getRunningTasks() - { - List runningTaskList = Lists.transform( - runningTasks, - new Function() - { - @Nullable - @Override - public TaskRunnerWorkItem apply(String input) - { - return taskRunnerWorkItems.get(input); - } - } - ); - return runningTaskList; - } - - @Override - public Collection getPendingTasks() - { - return ImmutableList.of(); - } + return null; + } - @Override - public Collection getKnownTasks() - { - return taskRunnerWorkItems.values(); - } + @Override + public boolean isReady(TaskActionClient taskActionClient) throws Exception + { + return false; + } - @Override - public Optional getScalingStats() - { - return Optional.absent(); - } + @Override + public TaskStatus run(TaskToolbox toolbox) throws Exception + { + return null; + } + }; + } - @Override - public void start() + private static class MockTaskRunnerWorkItem extends TaskRunnerWorkItem + { + public MockTaskRunnerWorkItem( + String taskId, + ListenableFuture result + ) { - started.set(true); + super(taskId, result); } @Override - public void stop() + public TaskLocation getLocation() { - started.set(false); + return null; } } + } diff --git a/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordTest.java b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordTest.java new file mode 100644 index 000000000000..16df2895f323 --- /dev/null +++ b/indexing-service/src/test/java/io/druid/indexing/overlord/http/OverlordTest.java @@ -0,0 +1,413 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.indexing.overlord.http; + +import com.google.common.base.Function; +import com.google.common.base.Optional; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; +import com.google.common.collect.Lists; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; +import com.metamx.common.Pair; +import com.metamx.common.guava.CloseQuietly; +import com.metamx.emitter.EmittingLogger; +import com.metamx.emitter.service.ServiceEmitter; +import io.druid.concurrent.Execs; +import io.druid.curator.PotentiallyGzippedCompressionProvider; +import io.druid.curator.discovery.NoopServiceAnnouncer; +import io.druid.indexing.common.TaskLocation; +import io.druid.indexing.common.TaskStatus; +import io.druid.indexing.common.actions.TaskActionClientFactory; +import io.druid.indexing.common.config.TaskStorageConfig; +import io.druid.indexing.common.task.NoopTask; +import io.druid.indexing.common.task.Task; +import io.druid.indexing.overlord.HeapMemoryTaskStorage; +import io.druid.indexing.overlord.TaskLockbox; +import io.druid.indexing.overlord.TaskMaster; +import io.druid.indexing.overlord.TaskRunner; +import io.druid.indexing.overlord.TaskRunnerFactory; +import io.druid.indexing.overlord.TaskRunnerListener; +import io.druid.indexing.overlord.TaskRunnerWorkItem; +import io.druid.indexing.overlord.TaskStorage; +import io.druid.indexing.overlord.TaskStorageQueryAdapter; +import io.druid.indexing.overlord.autoscaling.ScalingStats; +import io.druid.indexing.overlord.config.TaskQueueConfig; +import io.druid.server.DruidNode; +import io.druid.server.initialization.IndexerZkConfig; +import io.druid.server.initialization.ZkPathsConfig; +import io.druid.server.metrics.NoopServiceEmitter; +import io.druid.server.security.AuthConfig; +import org.apache.curator.framework.CuratorFramework; +import org.apache.curator.framework.CuratorFrameworkFactory; +import org.apache.curator.retry.RetryOneTime; +import org.apache.curator.test.TestingServer; +import org.apache.curator.test.Timing; +import org.easymock.EasyMock; +import org.joda.time.Period; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; + +import javax.annotation.Nullable; +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.core.Response; +import java.util.ArrayList; +import java.util.Collection; +import java.util.List; +import java.util.Map; +import java.util.concurrent.Callable; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executor; + +public class OverlordTest +{ + private static final TaskLocation TASK_LOCATION = new TaskLocation("dummy", 1000); + + private TestingServer server; + private Timing timing; + private CuratorFramework curator; + private TaskMaster taskMaster; + private TaskLockbox taskLockbox; + private TaskStorage taskStorage; + private TaskActionClientFactory taskActionClientFactory; + private CountDownLatch announcementLatch; + private DruidNode druidNode; + private OverlordResource overlordResource; + private CountDownLatch[] taskCompletionCountDownLatches; + private CountDownLatch[] runTaskCountDownLatches; + private HttpServletRequest req; + + private void setupServerAndCurator() throws Exception + { + server = new TestingServer(); + timing = new Timing(); + curator = CuratorFrameworkFactory + .builder() + .connectString(server.getConnectString()) + .sessionTimeoutMs(timing.session()) + .connectionTimeoutMs(timing.connection()) + .retryPolicy(new RetryOneTime(1)) + .compressionProvider(new PotentiallyGzippedCompressionProvider(true)) + .build(); + } + + private void tearDownServerAndCurator() + { + CloseQuietly.close(curator); + CloseQuietly.close(server); + } + + @Before + public void setUp() throws Exception + { + req = EasyMock.createStrictMock(HttpServletRequest.class); + taskLockbox = EasyMock.createStrictMock(TaskLockbox.class); + taskLockbox.syncFromStorage(); + EasyMock.expectLastCall().atLeastOnce(); + taskLockbox.add(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + taskLockbox.remove(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + + // for second Noop Task directly added to deep storage. + taskLockbox.add(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + taskLockbox.remove(EasyMock.anyObject()); + EasyMock.expectLastCall().atLeastOnce(); + + taskActionClientFactory = EasyMock.createStrictMock(TaskActionClientFactory.class); + EasyMock.expect(taskActionClientFactory.create(EasyMock.anyObject())) + .andReturn(null).anyTimes(); + EasyMock.replay(taskLockbox, taskActionClientFactory); + + taskStorage = new HeapMemoryTaskStorage(new TaskStorageConfig(null)); + runTaskCountDownLatches = new CountDownLatch[2]; + runTaskCountDownLatches[0] = new CountDownLatch(1); + runTaskCountDownLatches[1] = new CountDownLatch(1); + taskCompletionCountDownLatches = new CountDownLatch[2]; + taskCompletionCountDownLatches[0] = new CountDownLatch(1); + taskCompletionCountDownLatches[1] = new CountDownLatch(1); + announcementLatch = new CountDownLatch(1); + IndexerZkConfig indexerZkConfig = new IndexerZkConfig(new ZkPathsConfig(), null, null, null, null, null); + setupServerAndCurator(); + curator.start(); + curator.blockUntilConnected(); + curator.create().creatingParentsIfNeeded().forPath(indexerZkConfig.getLeaderLatchPath()); + druidNode = new DruidNode("hey", "what", 1234); + ServiceEmitter serviceEmitter = new NoopServiceEmitter(); + taskMaster = new TaskMaster( + new TaskQueueConfig(null, new Period(1), null, new Period(10)), + taskLockbox, + taskStorage, + taskActionClientFactory, + druidNode, + indexerZkConfig, + new TaskRunnerFactory() + { + @Override + public MockTaskRunner build() + { + return new MockTaskRunner(runTaskCountDownLatches, taskCompletionCountDownLatches); + } + }, + curator, + new NoopServiceAnnouncer() + { + @Override + public void announce(DruidNode node) + { + announcementLatch.countDown(); + } + }, + serviceEmitter + ); + EmittingLogger.registerEmitter(serviceEmitter); + } + + @Test(timeout = 2000L) + public void testOverlordRun() throws Exception + { + // basic task master lifecycle test + taskMaster.start(); + announcementLatch.await(); + while (!taskMaster.isLeading()) { + // I believe the control will never reach here and thread will never sleep but just to be on safe side + Thread.sleep(10); + } + Assert.assertEquals(taskMaster.getLeader(), druidNode.getHostAndPort()); + // Test Overlord resource stuff + overlordResource = new OverlordResource( + taskMaster, + new TaskStorageQueryAdapter(taskStorage), + null, + null, + null, + new AuthConfig() + ); + Response response = overlordResource.getLeader(); + Assert.assertEquals(druidNode.getHostAndPort(), response.getEntity()); + + final String taskId_0 = "0"; + NoopTask task_0 = new NoopTask(taskId_0, 0, 0, null, null, null); + response = overlordResource.taskPost(task_0, req); + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals(ImmutableMap.of("task", taskId_0), response.getEntity()); + + // Duplicate task - should fail + response = overlordResource.taskPost(task_0, req); + Assert.assertEquals(400, response.getStatus()); + + // Task payload for task_0 should be present in taskStorage + response = overlordResource.getTaskPayload(taskId_0); + Assert.assertEquals(task_0, ((Map) response.getEntity()).get("payload")); + + // Task not present in taskStorage - should fail + response = overlordResource.getTaskPayload("whatever"); + Assert.assertEquals(404, response.getStatus()); + + // Task status of the submitted task should be running + response = overlordResource.getTaskStatus(taskId_0); + Assert.assertEquals(taskId_0, ((Map) response.getEntity()).get("task")); + Assert.assertEquals( + TaskStatus.running(taskId_0).getStatusCode(), + ((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode() + ); + + // Simulate completion of task_0 + taskCompletionCountDownLatches[Integer.parseInt(taskId_0)].countDown(); + // Wait for taskQueue to handle success status of task_0 + waitForTaskStatus(taskId_0, TaskStatus.Status.SUCCESS); + + // Manually insert task in taskStorage + // Verifies sync from storage + final String taskId_1 = "1"; + NoopTask task_1 = new NoopTask(taskId_1, 0, 0, null, null, null); + taskStorage.insert(task_1, TaskStatus.running(taskId_1)); + // Wait for task runner to run task_1 + runTaskCountDownLatches[Integer.parseInt(taskId_1)].await(); + + response = overlordResource.getRunningTasks(req); + // 1 task that was manually inserted should be in running state + Assert.assertEquals(1, (((List) response.getEntity()).size())); + final OverlordResource.TaskResponseObject taskResponseObject = ((List) response + .getEntity()).get(0); + Assert.assertEquals(taskId_1, taskResponseObject.toJson().get("id")); + Assert.assertEquals(TASK_LOCATION, taskResponseObject.toJson().get("location")); + + // Simulate completion of task_1 + taskCompletionCountDownLatches[Integer.parseInt(taskId_1)].countDown(); + // Wait for taskQueue to handle success status of task_1 + waitForTaskStatus(taskId_1, TaskStatus.Status.SUCCESS); + + // should return number of tasks which are not in running state + response = overlordResource.getCompleteTasks(req); + Assert.assertEquals(2, (((List) response.getEntity()).size())); + taskMaster.stop(); + Assert.assertFalse(taskMaster.isLeading()); + EasyMock.verify(taskLockbox, taskActionClientFactory); + } + + /* Wait until the task with given taskId has the given Task Status + * These method will not timeout until the condition is met so calling method should ensure timeout + * This method also assumes that the task with given taskId is present + * */ + private void waitForTaskStatus(String taskId, TaskStatus.Status status) throws InterruptedException + { + while (true) { + Response response = overlordResource.getTaskStatus(taskId); + if (status.equals(((TaskStatus) ((Map) response.getEntity()).get("status")).getStatusCode())) { + break; + } + Thread.sleep(10); + } + } + + @After + public void tearDown() throws Exception + { + tearDownServerAndCurator(); + } + + public static class MockTaskRunner implements TaskRunner + { + private CountDownLatch[] completionLatches; + private CountDownLatch[] runLatches; + private ConcurrentHashMap taskRunnerWorkItems; + private List runningTasks; + + public MockTaskRunner(CountDownLatch[] runLatches, CountDownLatch[] completionLatches) + { + this.runLatches = runLatches; + this.completionLatches = completionLatches; + this.taskRunnerWorkItems = new ConcurrentHashMap<>(); + this.runningTasks = new ArrayList<>(); + } + + @Override + public List>> restore() + { + return ImmutableList.of(); + } + + @Override + public void registerListener(TaskRunnerListener listener, Executor executor) + { + // Overlord doesn't call this method + throw new UnsupportedOperationException(); + } + + @Override + public void stop() + { + // Do nothing + } + + @Override + public synchronized ListenableFuture run(final Task task) + { + final String taskId = task.getId(); + ListenableFuture future = MoreExecutors.listeningDecorator( + Execs.singleThreaded( + "noop_test_task_exec_%s" + ) + ).submit( + new Callable() + { + @Override + public TaskStatus call() throws Exception + { + // adding of task to list of runningTasks should be done before count down as + // getRunningTasks may not include the task for which latch has been counted down + // Count down to let know that task is actually running + // this is equivalent of getting process holder to run task in ForkingTaskRunner + runningTasks.add(taskId); + if (runLatches != null) { + runLatches[Integer.parseInt(taskId)].countDown(); + } + // Wait for completion count down + if (completionLatches != null) { + completionLatches[Integer.parseInt(taskId)].await(); + } + taskRunnerWorkItems.remove(taskId); + runningTasks.remove(taskId); + return TaskStatus.success(taskId); + } + } + ); + TaskRunnerWorkItem taskRunnerWorkItem = new TaskRunnerWorkItem(taskId, future) + { + @Override + public TaskLocation getLocation() + { + return TASK_LOCATION; + } + }; + taskRunnerWorkItems.put(taskId, taskRunnerWorkItem); + return future; + } + + @Override + public void shutdown(String taskid) {} + + @Override + public synchronized Collection getRunningTasks() + { + return Lists.transform( + runningTasks, + new Function() + { + @Nullable + @Override + public TaskRunnerWorkItem apply(String input) + { + return taskRunnerWorkItems.get(input); + } + } + ); + } + + @Override + public Collection getPendingTasks() + { + return ImmutableList.of(); + } + + @Override + public Collection getKnownTasks() + { + return taskRunnerWorkItems.values(); + } + + @Override + public Optional getScalingStats() + { + return Optional.absent(); + } + + @Override + public void start() + { + //Do nothing + } + } +} diff --git a/indexing-service/src/test/java/io/druid/indexing/overlord/http/security/SecurityResourceFilterTest.java b/indexing-service/src/test/java/io/druid/indexing/overlord/http/security/SecurityResourceFilterTest.java new file mode 100644 index 000000000000..a0aa98458cf1 --- /dev/null +++ b/indexing-service/src/test/java/io/druid/indexing/overlord/http/security/SecurityResourceFilterTest.java @@ -0,0 +1,146 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.indexing.overlord.http.security; + +import com.google.common.base.Optional; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Injector; +import com.sun.jersey.spi.container.ResourceFilter; +import io.druid.indexing.common.task.NoopTask; +import io.druid.indexing.common.task.Task; +import io.druid.indexing.overlord.TaskStorageQueryAdapter; +import io.druid.indexing.overlord.http.OverlordResource; +import io.druid.indexing.worker.http.WorkerResource; +import io.druid.server.http.security.AbstractResourceFilter; +import io.druid.server.http.security.ResourceFilterTestHelper; +import org.easymock.EasyMock; +import org.junit.After; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; +import java.util.Collection; + +@RunWith(Parameterized.class) +public class SecurityResourceFilterTest extends ResourceFilterTestHelper +{ + + @Parameterized.Parameters + public static Collection data() + { + return ImmutableList.copyOf( + Iterables.concat( + getRequestPaths(OverlordResource.class, ImmutableList.>of(TaskStorageQueryAdapter.class)), + getRequestPaths(WorkerResource.class) + ) + ); + } + + private final String requestPath; + private final String requestMethod; + private final ResourceFilter resourceFilter; + private final Injector injector; + private final Task noopTask = new NoopTask(null, 0, 0, null, null, null); + + private static boolean mockedOnce; + private TaskStorageQueryAdapter tsqa; + + public SecurityResourceFilterTest( + String requestPath, + String requestMethod, + ResourceFilter resourceFilter, + Injector injector + ) + { + this.requestPath = requestPath; + this.requestMethod = requestMethod; + this.resourceFilter = resourceFilter; + this.injector = injector; + } + + @Before + public void setUp() throws Exception + { + if (resourceFilter instanceof TaskResourceFilter && !mockedOnce) { + // Since we are creating the mocked tsqa object only once and getting that object from Guice here therefore + // if the mockedOnce check is not done then we will call EasyMock.expect and EasyMock.replay on the mocked object + // multiple times and it will throw exceptions + tsqa = injector.getInstance(TaskStorageQueryAdapter.class); + EasyMock.expect(tsqa.getTask(EasyMock.anyString())).andReturn(Optional.of(noopTask)).anyTimes(); + EasyMock.replay(tsqa); + mockedOnce = true; + } + setUp(resourceFilter); + } + + @Test + public void testDatasourcesResourcesFilteringAccess() + { + setUpMockExpectations(requestPath, true, requestMethod); + EasyMock.expect(request.getEntity(Task.class)).andReturn(noopTask).anyTimes(); + // As request object is a strict mock the ordering of expected calls matters + // therefore adding the expectation below again as getEntity is called before getMethod + EasyMock.expect(request.getMethod()).andReturn(requestMethod).anyTimes(); + EasyMock.replay(req, request, authorizationInfo); + resourceFilter.getRequestFilter().filter(request); + Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + } + + @Test(expected = WebApplicationException.class) + public void testDatasourcesResourcesFilteringNoAccess() + { + setUpMockExpectations(requestPath, false, requestMethod); + EasyMock.expect(request.getEntity(Task.class)).andReturn(noopTask).anyTimes(); + EasyMock.expect(request.getMethod()).andReturn(requestMethod).anyTimes(); + EasyMock.replay(req, request, authorizationInfo); + Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + try { + resourceFilter.getRequestFilter().filter(request); + } + catch (WebApplicationException e) { + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), e.getResponse().getStatus()); + throw e; + } + } + + @Test + public void testDatasourcesResourcesFilteringBadPath() + { + final String badRequestPath = requestPath.replaceAll("\\w+", "droid"); + EasyMock.expect(request.getPath()).andReturn(badRequestPath).anyTimes(); + EasyMock.replay(req, request, authorizationInfo); + Assert.assertFalse(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(badRequestPath)); + } + + @After + public void tearDown() + { + EasyMock.verify(req, request, authorizationInfo); + if (tsqa != null) { + EasyMock.verify(tsqa); + } + } + +} diff --git a/server/src/main/java/io/druid/guice/security/DruidAuthModule.java b/server/src/main/java/io/druid/guice/security/DruidAuthModule.java new file mode 100644 index 000000000000..e89c8ca23679 --- /dev/null +++ b/server/src/main/java/io/druid/guice/security/DruidAuthModule.java @@ -0,0 +1,44 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.guice.security; + +import com.fasterxml.jackson.databind.Module; +import com.google.inject.Binder; +import io.druid.guice.JsonConfigProvider; +import io.druid.initialization.DruidModule; +import io.druid.server.security.AuthConfig; + +import java.util.Collections; +import java.util.List; + +public class DruidAuthModule implements DruidModule +{ + @Override + public List getJacksonModules() + { + return Collections.emptyList(); + } + + @Override + public void configure(Binder binder) + { + JsonConfigProvider.bind(binder, "druid.auth", AuthConfig.class); + } +} diff --git a/server/src/main/java/io/druid/initialization/Initialization.java b/server/src/main/java/io/druid/initialization/Initialization.java index 0bfc8c0c7bdd..0752036575e2 100644 --- a/server/src/main/java/io/druid/initialization/Initialization.java +++ b/server/src/main/java/io/druid/initialization/Initialization.java @@ -57,6 +57,7 @@ import io.druid.guice.annotations.Json; import io.druid.guice.annotations.Smile; import io.druid.guice.http.HttpClientModule; +import io.druid.guice.security.DruidAuthModule; import io.druid.metadata.storage.derby.DerbyMetadataStorageDruidModule; import io.druid.server.initialization.EmitterModule; import io.druid.server.initialization.jetty.JettyServerModule; @@ -318,7 +319,9 @@ public static Injector makeInjectorWithModules(final Injector baseInjector, Iter { final ModuleList defaultModules = new ModuleList(baseInjector); defaultModules.addModules( + // New modules should be added after Log4jShutterDownerModule new Log4jShutterDownerModule(), + new DruidAuthModule(), new LifecycleModule(), EmitterModule.class, HttpClientModule.global(), diff --git a/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java b/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java index 415c5c9101bf..ff6cb39e8e23 100644 --- a/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java +++ b/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java @@ -143,7 +143,7 @@ public class EventReceiverFirehose implements ChatHandler, Firehose, EventReceiv public EventReceiverFirehose(MapInputRowParser parser) { - this.buffer = new ArrayBlockingQueue(bufferSize); + this.buffer = new ArrayBlockingQueue<>(bufferSize); this.parser = parser; } diff --git a/server/src/main/java/io/druid/server/ClientInfoResource.java b/server/src/main/java/io/druid/server/ClientInfoResource.java index e3a653fe3716..9b800b891d79 100644 --- a/server/src/main/java/io/druid/server/ClientInfoResource.java +++ b/server/src/main/java/io/druid/server/ClientInfoResource.java @@ -19,13 +19,17 @@ package io.druid.server; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; import com.google.common.collect.ImmutableMap; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Maps; import com.google.common.collect.Sets; import com.google.inject.Inject; +import com.metamx.common.Pair; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.FilteredServerInventoryView; @@ -34,6 +38,13 @@ import io.druid.client.selector.ServerSelector; import io.druid.query.TableDataSource; import io.druid.query.metadata.SegmentMetadataQueryConfig; +import io.druid.server.http.security.DatasourceResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import io.druid.timeline.DataSegment; import io.druid.timeline.TimelineLookup; import io.druid.timeline.TimelineObjectHolder; @@ -41,14 +52,17 @@ import org.joda.time.DateTime; import org.joda.time.Interval; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import java.util.Collections; import java.util.Comparator; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; @@ -67,18 +81,21 @@ public class ClientInfoResource private FilteredServerInventoryView serverInventoryView; private TimelineServerView timelineServerView; private SegmentMetadataQueryConfig segmentMetadataQueryConfig; + private final AuthConfig authConfig; @Inject public ClientInfoResource( FilteredServerInventoryView serverInventoryView, TimelineServerView timelineServerView, - SegmentMetadataQueryConfig segmentMetadataQueryConfig + SegmentMetadataQueryConfig segmentMetadataQueryConfig, + AuthConfig authConfig ) { this.serverInventoryView = serverInventoryView; this.timelineServerView = timelineServerView; this.segmentMetadataQueryConfig = (segmentMetadataQueryConfig == null) ? new SegmentMetadataQueryConfig() : segmentMetadataQueryConfig; + this.authConfig = authConfig; } private Map> getSegmentsForDatasources() @@ -98,14 +115,41 @@ private Map> getSegmentsForDatasources() @GET @Produces(MediaType.APPLICATION_JSON) - public Iterable getDataSources() + public Iterable getDataSources(@Context final HttpServletRequest request) { - return getSegmentsForDatasources().keySet(); + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) request.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + return Collections2.filter( + getSegmentsForDatasources().keySet(), + new Predicate() + { + @Override + public boolean apply(String input) + { + Resource resource = new Resource(input, ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ); + } else { + return getSegmentsForDatasources().keySet(); + } } @GET @Path("/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Map getDatasource( @PathParam("dataSourceName") String dataSourceName, @QueryParam("interval") String interval, @@ -193,6 +237,7 @@ public int compare(Interval o1, Interval o2) @GET @Path("/{dataSourceName}/dimensions") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Iterable getDatasourceDimensions( @PathParam("dataSourceName") String dataSourceName, @QueryParam("interval") String interval @@ -225,6 +270,7 @@ public Iterable getDatasourceDimensions( @GET @Path("/{dataSourceName}/metrics") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Iterable getDatasourceMetrics( @PathParam("dataSourceName") String dataSourceName, @QueryParam("interval") String interval diff --git a/server/src/main/java/io/druid/server/QueryManager.java b/server/src/main/java/io/druid/server/QueryManager.java index 3e2b3b510791..49252c8c0ad6 100644 --- a/server/src/main/java/io/druid/server/QueryManager.java +++ b/server/src/main/java/io/druid/server/QueryManager.java @@ -27,20 +27,28 @@ import io.druid.query.Query; import io.druid.query.QueryWatcher; +import java.util.List; import java.util.Set; public class QueryManager implements QueryWatcher { - final SetMultimap queries; + + private final SetMultimap queries; + private final SetMultimap queryDatasources; public QueryManager() { this.queries = Multimaps.synchronizedSetMultimap( HashMultimap.create() ); + this.queryDatasources = Multimaps.synchronizedSetMultimap( + HashMultimap.create() + ); } - public boolean cancelQuery(String id) { + public boolean cancelQuery(String id) + { + queryDatasources.removeAll(id); Set futures = queries.removeAll(id); boolean success = true; for (ListenableFuture future : futures) { @@ -52,7 +60,9 @@ public boolean cancelQuery(String id) { public void registerQuery(Query query, final ListenableFuture future) { final String id = query.getId(); + final List datasources = query.getDataSource().getNames(); queries.put(id, future); + queryDatasources.putAll(id, datasources); future.addListener( new Runnable() { @@ -60,9 +70,17 @@ public void registerQuery(Query query, final ListenableFuture future) public void run() { queries.remove(id, future); + for (String datasource : datasources) { + queryDatasources.remove(id, datasource); + } } }, MoreExecutors.sameThreadExecutor() ); } + + public Set getQueryDatasources(final String queryId) + { + return queryDatasources.get(queryId); + } } diff --git a/server/src/main/java/io/druid/server/QueryResource.java b/server/src/main/java/io/druid/server/QueryResource.java index 0b9ac2b0fa50..63e37e338da3 100644 --- a/server/src/main/java/io/druid/server/QueryResource.java +++ b/server/src/main/java/io/druid/server/QueryResource.java @@ -22,11 +22,13 @@ import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.databind.ObjectWriter; import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; +import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; import com.google.common.collect.MapMaker; import com.google.common.io.CountingOutputStream; import com.google.inject.Inject; +import com.metamx.common.ISE; import com.metamx.common.guava.Sequence; import com.metamx.common.guava.Sequences; import com.metamx.common.guava.Yielder; @@ -42,6 +44,12 @@ import io.druid.query.QuerySegmentWalker; import io.druid.server.initialization.ServerConfig; import io.druid.server.log.RequestLogger; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import org.joda.time.DateTime; import javax.servlet.http.HttpServletRequest; @@ -61,6 +69,7 @@ import java.io.InputStream; import java.io.OutputStream; import java.util.Map; +import java.util.Set; import java.util.UUID; /** @@ -81,6 +90,7 @@ public class QueryResource private final ServiceEmitter emitter; private final RequestLogger requestLogger; private final QueryManager queryManager; + private final AuthConfig authConfig; @Inject public QueryResource( @@ -90,7 +100,8 @@ public QueryResource( QuerySegmentWalker texasRanger, ServiceEmitter emitter, RequestLogger requestLogger, - QueryManager queryManager + QueryManager queryManager, + AuthConfig authConfig ) { this.config = config; @@ -100,16 +111,39 @@ public QueryResource( this.emitter = emitter; this.requestLogger = requestLogger; this.queryManager = queryManager; + this.authConfig = authConfig; } @DELETE @Path("{id}") @Produces(MediaType.APPLICATION_JSON) - public Response getServer(@PathParam("id") String queryId) + public Response getServer(@PathParam("id") String queryId, @Context final HttpServletRequest req) { if (log.isDebugEnabled()) { log.debug("Received cancel request for query [%s]", queryId); } + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + Set datasources = queryManager.getQueryDatasources(queryId); + if (datasources == null) { + log.warn("QueryId [%s] not registered with QueryManager, cannot cancel", queryId); + } else { + for (String dataSource : datasources) { + Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSource, ResourceType.DATASOURCE), + Action.WRITE + ); + if (!authResult.isAllowed()) { + return Response.status(Response.Status.FORBIDDEN).header("Access-Check-Result", authResult).build(); + } + } + } + } queryManager.cancelQuery(queryId); return Response.status(Response.Status.ACCEPTED).build(); } @@ -120,7 +154,7 @@ public Response getServer(@PathParam("id") String queryId) public Response doPost( InputStream in, @QueryParam("pretty") String pretty, - @Context final HttpServletRequest req // used only to get request content-type and remote address + @Context final HttpServletRequest req // used to get request content-type, remote address and AuthorizationInfo ) throws IOException { final long start = System.currentTimeMillis(); @@ -160,6 +194,24 @@ public Response doPost( log.debug("Got query [%s]", query); } + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + if (authorizationInfo != null) { + for (String dataSource : query.getDataSource().getNames()) { + Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSource, ResourceType.DATASOURCE), + Action.READ + ); + if (!authResult.isAllowed()) { + return Response.status(Response.Status.FORBIDDEN).header("Access-Check-Result", authResult).build(); + } + } + } else { + throw new ISE("WTF?! Security is enabled but no authorization info found in the request"); + } + } + final Map responseContext = new MapMaker().makeMap(); final Sequence res = query.run(texasRanger, responseContext); final Sequence results; diff --git a/server/src/main/java/io/druid/server/StatusResource.java b/server/src/main/java/io/druid/server/StatusResource.java index f5012daafec7..edbd65b4fdb7 100644 --- a/server/src/main/java/io/druid/server/StatusResource.java +++ b/server/src/main/java/io/druid/server/StatusResource.java @@ -21,8 +21,10 @@ import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonProperty; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.initialization.DruidModule; import io.druid.initialization.Initialization; +import io.druid.server.http.security.StateResourceFilter; import javax.ws.rs.GET; import javax.ws.rs.Path; @@ -35,6 +37,7 @@ /** */ @Path("/status") +@ResourceFilters(StateResourceFilter.class) public class StatusResource { @GET diff --git a/server/src/main/java/io/druid/server/http/BrokerResource.java b/server/src/main/java/io/druid/server/http/BrokerResource.java index 7e9701a39b7c..7adc968e402b 100644 --- a/server/src/main/java/io/druid/server/http/BrokerResource.java +++ b/server/src/main/java/io/druid/server/http/BrokerResource.java @@ -21,7 +21,9 @@ import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.BrokerServerView; +import io.druid.server.http.security.StateResourceFilter; import javax.ws.rs.GET; import javax.ws.rs.Path; @@ -30,6 +32,7 @@ import javax.ws.rs.core.Response; @Path("/druid/broker/v1") +@ResourceFilters(StateResourceFilter.class) public class BrokerResource { private final BrokerServerView brokerServerView; diff --git a/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java b/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java index 0d955b915bf2..c4e572a15a5c 100644 --- a/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java +++ b/server/src/main/java/io/druid/server/http/CoordinatorDynamicConfigsResource.java @@ -19,15 +19,15 @@ package io.druid.server.http; +import com.google.common.collect.ImmutableMap; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.common.config.JacksonConfigManager; import io.druid.server.coordinator.CoordinatorDynamicConfig; - +import io.druid.server.http.security.ConfigResourceFilter; import org.joda.time.Interval; -import com.google.common.collect.ImmutableMap; - import javax.inject.Inject; import javax.servlet.http.HttpServletRequest; import javax.ws.rs.Consumes; @@ -45,6 +45,7 @@ /** */ @Path("/druid/coordinator/v1/config") +@ResourceFilters(ConfigResourceFilter.class) public class CoordinatorDynamicConfigsResource { private final JacksonConfigManager manager; diff --git a/server/src/main/java/io/druid/server/http/CoordinatorResource.java b/server/src/main/java/io/druid/server/http/CoordinatorResource.java index ac13e9ec22f9..20f6805dae12 100644 --- a/server/src/main/java/io/druid/server/http/CoordinatorResource.java +++ b/server/src/main/java/io/druid/server/http/CoordinatorResource.java @@ -24,8 +24,10 @@ import com.google.common.collect.ImmutableMap; import com.google.common.collect.Maps; import com.google.inject.Inject; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.server.coordinator.DruidCoordinator; import io.druid.server.coordinator.LoadQueuePeon; +import io.druid.server.http.security.StateResourceFilter; import io.druid.timeline.DataSegment; import javax.ws.rs.GET; @@ -38,6 +40,7 @@ /** */ @Path("/druid/coordinator/v1") +@ResourceFilters(StateResourceFilter.class) public class CoordinatorResource { private final DruidCoordinator coordinator; diff --git a/server/src/main/java/io/druid/server/http/DatasourcesResource.java b/server/src/main/java/io/druid/server/http/DatasourcesResource.java index 8aa035f96694..274e03492c5a 100644 --- a/server/src/main/java/io/druid/server/http/DatasourcesResource.java +++ b/server/src/main/java/io/druid/server/http/DatasourcesResource.java @@ -31,6 +31,7 @@ import com.metamx.common.guava.Comparators; import com.metamx.common.guava.FunctionalIterable; import com.metamx.common.logger.Logger; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.CoordinatorServerView; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; @@ -39,6 +40,9 @@ import io.druid.client.indexing.IndexingServiceClient; import io.druid.metadata.MetadataSegmentManager; import io.druid.query.TableDataSource; +import io.druid.server.http.security.DatasourceResourceFilter; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; import io.druid.timeline.DataSegment; import io.druid.timeline.TimelineLookup; import io.druid.timeline.TimelineObjectHolder; @@ -47,6 +51,7 @@ import org.joda.time.Interval; import javax.annotation.Nullable; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.Consumes; import javax.ws.rs.DELETE; import javax.ws.rs.GET; @@ -55,6 +60,7 @@ import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.util.Comparator; @@ -73,28 +79,38 @@ public class DatasourcesResource private final CoordinatorServerView serverInventoryView; private final MetadataSegmentManager databaseSegmentManager; private final IndexingServiceClient indexingServiceClient; + private final AuthConfig authConfig; @Inject public DatasourcesResource( CoordinatorServerView serverInventoryView, MetadataSegmentManager databaseSegmentManager, - @Nullable IndexingServiceClient indexingServiceClient + @Nullable IndexingServiceClient indexingServiceClient, + AuthConfig authConfig ) { this.serverInventoryView = serverInventoryView; this.databaseSegmentManager = databaseSegmentManager; this.indexingServiceClient = indexingServiceClient; + this.authConfig = authConfig; } @GET @Produces(MediaType.APPLICATION_JSON) public Response getQueryableDataSources( @QueryParam("full") String full, - @QueryParam("simple") String simple + @QueryParam("simple") String simple, + @Context final HttpServletRequest req ) { Response.ResponseBuilder builder = Response.ok(); - final Set datasources = InventoryViewUtils.getDataSources(serverInventoryView); + final Set datasources = authConfig.isEnabled() ? + InventoryViewUtils.getSecuredDataSources( + serverInventoryView, + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN) + ) : + InventoryViewUtils.getDataSources(serverInventoryView); + if (full != null) { return builder.entity(datasources).build(); } else if (simple != null) { @@ -135,12 +151,14 @@ public String apply(DruidDataSource dataSource) @GET @Path("/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getTheDataSource( @PathParam("dataSourceName") final String dataSourceName, @QueryParam("full") final String full ) { DruidDataSource dataSource = getDataSource(dataSourceName); + if (dataSource == null) { return Response.noContent().build(); } @@ -155,6 +173,7 @@ public Response getTheDataSource( @POST @Path("/{dataSourceName}") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response enableDataSource( @PathParam("dataSourceName") final String dataSourceName ) @@ -175,6 +194,7 @@ public Response enableDataSource( @DELETE @Deprecated @Path("/{dataSourceName}") + @ResourceFilters(DatasourceResourceFilter.class) @Produces(MediaType.APPLICATION_JSON) public Response deleteDataSource( @PathParam("dataSourceName") final String dataSourceName, @@ -253,6 +273,7 @@ public Response deleteDataSourceSpecificInterval( @GET @Path("/{dataSourceName}/intervals") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceIntervals( @PathParam("dataSourceName") String dataSourceName, @QueryParam("simple") String simple, @@ -313,6 +334,7 @@ public Response getSegmentDataSourceIntervals( @GET @Path("/{dataSourceName}/intervals/{interval}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSpecificInterval( @PathParam("dataSourceName") String dataSourceName, @PathParam("interval") String interval, @@ -380,6 +402,7 @@ public Response getSegmentDataSourceSpecificInterval( @GET @Path("/{dataSourceName}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSegments( @PathParam("dataSourceName") String dataSourceName, @QueryParam("full") String full @@ -413,6 +436,7 @@ public Object apply(DataSegment segment) @GET @Path("/{dataSourceName}/segments/{segmentId}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId @@ -436,6 +460,7 @@ public Response getSegmentDataSourceSegment( @DELETE @Path("/{dataSourceName}/segments/{segmentId}") + @ResourceFilters(DatasourceResourceFilter.class) public Response deleteDatasourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId @@ -451,6 +476,7 @@ public Response deleteDatasourceSegment( @POST @Path("/{dataSourceName}/segments/{segmentId}") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response enableDatasourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId @@ -466,6 +492,7 @@ public Response enableDatasourceSegment( @GET @Path("/{dataSourceName}/tiers") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceTiers( @PathParam("dataSourceName") String dataSourceName ) @@ -624,6 +651,7 @@ private Map> getSimpleDatasource(String dataSourceNa @GET @Path("/{dataSourceName}/intervals/{interval}/serverview") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getSegmentDataSourceSpecificInterval( @PathParam("dataSourceName") String dataSourceName, @PathParam("interval") String interval, diff --git a/server/src/main/java/io/druid/server/http/HistoricalResource.java b/server/src/main/java/io/druid/server/http/HistoricalResource.java index 4680cf29c6c5..bc77ce0fc056 100644 --- a/server/src/main/java/io/druid/server/http/HistoricalResource.java +++ b/server/src/main/java/io/druid/server/http/HistoricalResource.java @@ -20,7 +20,9 @@ package io.druid.server.http; import com.google.common.collect.ImmutableMap; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.server.coordination.ZkCoordinator; +import io.druid.server.http.security.StateResourceFilter; import javax.inject.Inject; import javax.ws.rs.GET; @@ -30,6 +32,7 @@ import javax.ws.rs.core.Response; @Path("/druid/historical/v1") +@ResourceFilters(StateResourceFilter.class) public class HistoricalResource { private final ZkCoordinator coordinator; diff --git a/server/src/main/java/io/druid/server/http/IntervalsResource.java b/server/src/main/java/io/druid/server/http/IntervalsResource.java index 103330fc50ab..29c8a1f4f86f 100644 --- a/server/src/main/java/io/druid/server/http/IntervalsResource.java +++ b/server/src/main/java/io/druid/server/http/IntervalsResource.java @@ -25,14 +25,18 @@ import com.metamx.common.guava.Comparators; import io.druid.client.DruidDataSource; import io.druid.client.InventoryView; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; import io.druid.timeline.DataSegment; import org.joda.time.Interval; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.GET; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.util.Comparator; @@ -45,35 +49,43 @@ public class IntervalsResource { private final InventoryView serverInventoryView; + private final AuthConfig authConfig; @Inject public IntervalsResource( - InventoryView serverInventoryView + InventoryView serverInventoryView, + AuthConfig authConfig ) { this.serverInventoryView = serverInventoryView; + this.authConfig = authConfig; } @GET @Produces(MediaType.APPLICATION_JSON) - public Response getIntervals() + public Response getIntervals(@Context final HttpServletRequest req) { - final Comparator comparator = Comparators.inverse(Comparators.intervalsByStartThenEnd()); - final Set datasources = InventoryViewUtils.getDataSources(serverInventoryView); - - final Map>> retVal = Maps.newTreeMap(comparator); - for (DruidDataSource dataSource : datasources) { - for (DataSegment dataSegment : dataSource.getSegments()) { - Map> interval = retVal.get(dataSegment.getInterval()); - if (interval == null) { - Map> tmp = Maps.newHashMap(); - retVal.put(dataSegment.getInterval(), tmp); - } - setProperties(retVal, dataSource, dataSegment); + final Comparator comparator = Comparators.inverse(Comparators.intervalsByStartThenEnd()); + final Set datasources = authConfig.isEnabled() ? + InventoryViewUtils.getSecuredDataSources( + serverInventoryView, + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN) + ) : + InventoryViewUtils.getDataSources(serverInventoryView); + + final Map>> retVal = Maps.newTreeMap(comparator); + for (DruidDataSource dataSource : datasources) { + for (DataSegment dataSegment : dataSource.getSegments()) { + Map> interval = retVal.get(dataSegment.getInterval()); + if (interval == null) { + Map> tmp = Maps.newHashMap(); + retVal.put(dataSegment.getInterval(), tmp); } + setProperties(retVal, dataSource, dataSegment); } + } - return Response.ok(retVal).build(); + return Response.ok(retVal).build(); } @GET @@ -82,13 +94,20 @@ public Response getIntervals() public Response getSpecificIntervals( @PathParam("interval") String interval, @QueryParam("simple") String simple, - @QueryParam("full") String full + @QueryParam("full") String full, + @Context final HttpServletRequest req ) { final Interval theInterval = new Interval(interval.replace("_", "/")); - final Set datasources = InventoryViewUtils.getDataSources(serverInventoryView); + final Set datasources = authConfig.isEnabled() ? + InventoryViewUtils.getSecuredDataSources( + serverInventoryView, + (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN) + ) : + InventoryViewUtils.getDataSources(serverInventoryView); final Comparator comparator = Comparators.inverse(Comparators.intervalsByStartThenEnd()); + if (full != null) { final Map>> retVal = Maps.newTreeMap(comparator); for (DruidDataSource dataSource : datasources) { diff --git a/server/src/main/java/io/druid/server/http/InventoryViewUtils.java b/server/src/main/java/io/druid/server/http/InventoryViewUtils.java index df39f5e70c13..62cb5109eadb 100644 --- a/server/src/main/java/io/druid/server/http/InventoryViewUtils.java +++ b/server/src/main/java/io/druid/server/http/InventoryViewUtils.java @@ -20,18 +20,30 @@ package io.druid.server.http; import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableSet; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.common.collect.Sets; +import com.metamx.common.ISE; +import com.metamx.common.Pair; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import java.util.Comparator; +import java.util.HashMap; +import java.util.Map; import java.util.Set; import java.util.TreeSet; -public class InventoryViewUtils { +public class InventoryViewUtils +{ public static Set getDataSources(InventoryView serverInventoryView) { @@ -64,4 +76,38 @@ public Iterable apply(DruidServer input) ); return dataSources; } + + public static Set getSecuredDataSources( + InventoryView inventoryView, + final AuthorizationInfo authorizationInfo + ) + { + if (authorizationInfo == null) { + throw new ISE("Invalid to call a secured method with null AuthorizationInfo!!"); + } else { + final Map, Access> resourceAccessMap = new HashMap<>(); + return ImmutableSet.copyOf( + Iterables.filter( + getDataSources(inventoryView), + new Predicate() + { + @Override + public boolean apply(DruidDataSource input) + { + Resource resource = new Resource(input.getName(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ) + ); + } + } } diff --git a/server/src/main/java/io/druid/server/http/MetadataResource.java b/server/src/main/java/io/druid/server/http/MetadataResource.java index 294165402f32..e480121b8b9f 100644 --- a/server/src/main/java/io/druid/server/http/MetadataResource.java +++ b/server/src/main/java/io/druid/server/http/MetadataResource.java @@ -20,26 +20,42 @@ package io.druid.server.http; import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.inject.Inject; +import com.metamx.common.Pair; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidDataSource; import io.druid.indexing.overlord.IndexerMetadataStorageCoordinator; import io.druid.metadata.MetadataSegmentManager; +import io.druid.server.http.security.DatasourceResourceFilter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; import io.druid.timeline.DataSegment; import org.joda.time.Interval; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.GET; import javax.ws.rs.POST; import javax.ws.rs.Path; import javax.ws.rs.PathParam; import javax.ws.rs.Produces; import javax.ws.rs.QueryParam; +import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; import java.io.IOException; +import java.util.Collection; import java.util.Collections; +import java.util.HashMap; import java.util.List; +import java.util.Map; /** */ @@ -48,15 +64,18 @@ public class MetadataResource { private final MetadataSegmentManager metadataSegmentManager; private final IndexerMetadataStorageCoordinator metadataStorageCoordinator; + private final AuthConfig authConfig; @Inject public MetadataResource( MetadataSegmentManager metadataSegmentManager, - IndexerMetadataStorageCoordinator metadataStorageCoordinator + IndexerMetadataStorageCoordinator metadataStorageCoordinator, + AuthConfig authConfig ) { this.metadataSegmentManager = metadataSegmentManager; this.metadataStorageCoordinator = metadataStorageCoordinator; + this.authConfig = authConfig; } @GET @@ -64,20 +83,88 @@ public MetadataResource( @Produces(MediaType.APPLICATION_JSON) public Response getDatabaseDataSources( @QueryParam("full") String full, - @QueryParam("includeDisabled") String includeDisabled + @QueryParam("includeDisabled") String includeDisabled, + @Context final HttpServletRequest req ) { Response.ResponseBuilder builder = Response.status(Response.Status.OK); + + final Collection druidDataSources; + if (authConfig.isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final Map, Access> resourceAccessMap = new HashMap<>(); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) req.getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + if (includeDisabled != null) { + return builder.entity( + Collections2.filter( + metadataSegmentManager.getAllDatasourceNames(), + new Predicate() + { + @Override + public boolean apply(String input) + { + Resource resource = new Resource(input, ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + )).build(); + } else { + druidDataSources = + Collections2.filter( + metadataSegmentManager.getInventory(), + new Predicate() + { + @Override + public boolean apply(DruidDataSource input) + { + Resource resource = new Resource(input.getName(), ResourceType.DATASOURCE); + Action action = Action.READ; + Pair key = new Pair<>(resource, action); + if (resourceAccessMap.containsKey(key)) { + return resourceAccessMap.get(key).isAllowed(); + } else { + Access access = authorizationInfo.isAuthorized(key.lhs, key.rhs); + resourceAccessMap.put(key, access); + return access.isAllowed(); + } + } + } + ); + } + } else { + druidDataSources = metadataSegmentManager.getInventory(); + } + if (includeDisabled != null) { - return builder.entity(metadataSegmentManager.getAllDatasourceNames()).build(); + return builder.entity( + Collections2.transform( + druidDataSources, + new Function() + { + @Override + public String apply(DruidDataSource input) + { + return input.getName(); + } + } + ) + ).build(); } if (full != null) { - return builder.entity(metadataSegmentManager.getInventory()).build(); + return builder.entity(druidDataSources).build(); } List dataSourceNames = Lists.newArrayList( Iterables.transform( - metadataSegmentManager.getInventory(), + druidDataSources, new Function() { @Override @@ -97,6 +184,7 @@ public String apply(DruidDataSource dataSource) @GET @Path("/datasources/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSource( @PathParam("dataSourceName") final String dataSourceName ) @@ -112,6 +200,7 @@ public Response getDatabaseSegmentDataSource( @GET @Path("/datasources/{dataSourceName}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSourceSegments( @PathParam("dataSourceName") String dataSourceName, @QueryParam("full") String full @@ -145,13 +234,14 @@ public String apply(DataSegment segment) @POST @Path("/datasources/{dataSourceName}/segments") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSourceSegments( @PathParam("dataSourceName") String dataSourceName, @QueryParam("full") String full, List intervals ) { - List segments = null; + List segments; try { segments = metadataStorageCoordinator.getUsedSegmentsForIntervals(dataSourceName, intervals); } @@ -182,6 +272,7 @@ public String apply(DataSegment segment) @GET @Path("/datasources/{dataSourceName}/segments/{segmentId}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(DatasourceResourceFilter.class) public Response getDatabaseSegmentDataSourceSegment( @PathParam("dataSourceName") String dataSourceName, @PathParam("segmentId") String segmentId diff --git a/server/src/main/java/io/druid/server/http/RulesResource.java b/server/src/main/java/io/druid/server/http/RulesResource.java index fdacb228ea63..1d93d61df7d8 100644 --- a/server/src/main/java/io/druid/server/http/RulesResource.java +++ b/server/src/main/java/io/druid/server/http/RulesResource.java @@ -21,13 +21,14 @@ import com.google.common.collect.ImmutableMap; import com.google.inject.Inject; - +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.audit.AuditEntry; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.metadata.MetadataRuleManager; import io.druid.server.coordinator.rules.Rule; - +import io.druid.server.http.security.RulesResourceFilter; +import io.druid.server.http.security.StateResourceFilter; import org.joda.time.Interval; import javax.servlet.http.HttpServletRequest; @@ -43,7 +44,6 @@ import javax.ws.rs.core.Context; import javax.ws.rs.core.MediaType; import javax.ws.rs.core.Response; - import java.util.List; /** @@ -66,6 +66,7 @@ public RulesResource( @GET @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getRules() { return Response.ok(databaseRuleManager.getAllRules()).build(); @@ -74,6 +75,7 @@ public Response getRules() @GET @Path("/{dataSourceName}") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(RulesResourceFilter.class) public Response getDatasourceRules( @PathParam("dataSourceName") final String dataSourceName, @QueryParam("full") final String full @@ -91,6 +93,7 @@ public Response getDatasourceRules( @POST @Path("/{dataSourceName}") @Consumes(MediaType.APPLICATION_JSON) + @ResourceFilters(RulesResourceFilter.class) public Response setDatasourceRules( @PathParam("dataSourceName") final String dataSourceName, final List rules, @@ -112,6 +115,7 @@ public Response setDatasourceRules( @GET @Path("/{dataSourceName}/history") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(RulesResourceFilter.class) public Response getDatasourceRuleHistory( @PathParam("dataSourceName") final String dataSourceName, @QueryParam("interval") final String interval, @@ -131,6 +135,7 @@ public Response getDatasourceRuleHistory( @GET @Path("/history") @Produces(MediaType.APPLICATION_JSON) + @ResourceFilters(StateResourceFilter.class) public Response getDatasourceRuleHistory( @QueryParam("interval") final String interval, @QueryParam("count") final Integer count diff --git a/server/src/main/java/io/druid/server/http/ServersResource.java b/server/src/main/java/io/druid/server/http/ServersResource.java index 33665fda81d2..70308eb8ebb0 100644 --- a/server/src/main/java/io/druid/server/http/ServersResource.java +++ b/server/src/main/java/io/druid/server/http/ServersResource.java @@ -25,8 +25,10 @@ import com.google.common.collect.Iterables; import com.google.common.collect.Lists; import com.google.inject.Inject; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.http.security.StateResourceFilter; import io.druid.timeline.DataSegment; import javax.ws.rs.GET; @@ -41,6 +43,7 @@ /** */ @Path("/druid/coordinator/v1/servers") +@ResourceFilters(StateResourceFilter.class) public class ServersResource { private static Map makeSimpleServer(DruidServer input) diff --git a/server/src/main/java/io/druid/server/http/TiersResource.java b/server/src/main/java/io/druid/server/http/TiersResource.java index 6990dae2839a..db9189e56e5c 100644 --- a/server/src/main/java/io/druid/server/http/TiersResource.java +++ b/server/src/main/java/io/druid/server/http/TiersResource.java @@ -28,9 +28,11 @@ import com.google.common.collect.Table; import com.google.inject.Inject; import com.metamx.common.MapUtils; +import com.sun.jersey.spi.container.ResourceFilters; import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.http.security.StateResourceFilter; import io.druid.timeline.DataSegment; import org.joda.time.Interval; @@ -47,6 +49,7 @@ /** */ @Path("/druid/coordinator/v1/tiers") +@ResourceFilters(StateResourceFilter.class) public class TiersResource { private final InventoryView serverInventoryView; diff --git a/server/src/main/java/io/druid/server/http/security/AbstractResourceFilter.java b/server/src/main/java/io/druid/server/http/security/AbstractResourceFilter.java new file mode 100644 index 000000000000..a8a1fb4cb4e1 --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/AbstractResourceFilter.java @@ -0,0 +1,89 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import com.sun.jersey.spi.container.ContainerRequestFilter; +import com.sun.jersey.spi.container.ContainerResponseFilter; +import com.sun.jersey.spi.container.ResourceFilter; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.core.Context; + +public abstract class AbstractResourceFilter implements ResourceFilter, ContainerRequestFilter +{ + //https://jsr311.java.net/nonav/releases/1.1/spec/spec3.html#x3-520005 + @Context + private HttpServletRequest req; + + private final AuthConfig authConfig; + + @Inject + public AbstractResourceFilter(AuthConfig authConfig) + { + this.authConfig = authConfig; + } + + @Override + public ContainerRequestFilter getRequestFilter() + { + return this; + } + + @Override + public ContainerResponseFilter getResponseFilter() + { + return null; + } + + public HttpServletRequest getReq() + { + return req; + } + + public AuthConfig getAuthConfig() + { + return authConfig; + } + + public AbstractResourceFilter setReq(HttpServletRequest req) + { + this.req = req; + return this; + } + + protected Action getAction(ContainerRequest request) + { + Action action; + switch (request.getMethod()) { + case "GET": + case "HEAD": + action = Action.READ; + break; + default: + action = Action.WRITE; + } + return action; + } + + public abstract boolean isApplicable(String requestPath); +} diff --git a/server/src/main/java/io/druid/server/http/security/ConfigResourceFilter.java b/server/src/main/java/io/druid/server/http/security/ConfigResourceFilter.java new file mode 100644 index 000000000000..61fc28f16269 --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/ConfigResourceFilter.java @@ -0,0 +1,85 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; + +/** + * Use this ResourceFilter at end points where Druid Cluster configuration is read or written + * Here are some example paths where this filter is used - + * - druid/worker/v1 + * - druid/indexer/v1 + * - druid/coordinator/v1/config + * Note - Currently the resource name for all end points is set to "CONFIG" however if more fine grained access control + * is required the resource name can be set to specific config properties. + */ +public class ConfigResourceFilter extends AbstractResourceFilter +{ + @Inject + public ConfigResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String resourceName = "CONFIG"; + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + + final Access authResult = authorizationInfo.isAuthorized( + new Resource(resourceName, ResourceType.CONFIG), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + return requestPath.startsWith("druid/worker/v1") || + requestPath.startsWith("druid/indexer/v1") || + requestPath.startsWith("druid/coordinator/v1/config"); + } +} diff --git a/server/src/main/java/io/druid/server/http/security/DatasourceResourceFilter.java b/server/src/main/java/io/druid/server/http/security/DatasourceResourceFilter.java new file mode 100644 index 000000000000..ccbeab866008 --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/DatasourceResourceFilter.java @@ -0,0 +1,110 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.PathSegment; +import javax.ws.rs.core.Response; +import java.util.List; + +/** + * Use this ResourceFilter when the datasource information is present after "datasources" segment in the request Path + * Here are some example paths where this filter is used - + * - druid/coordinator/v1/datasources/{dataSourceName}/... + * - druid/coordinator/v1/metadata/datasources/{dataSourceName}/... + * - druid/v2/datasources/{dataSourceName}/... + */ +public class DatasourceResourceFilter extends AbstractResourceFilter +{ + @Inject + public DatasourceResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String dataSourceName = request.getPathSegments() + .get( + Iterables.indexOf( + request.getPathSegments(), + new Predicate() + { + @Override + public boolean apply(PathSegment input) + { + return input.getPath().equals("datasources"); + } + } + ) + 1 + ).getPath(); + Preconditions.checkNotNull(dataSourceName); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + final Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSourceName, ResourceType.DATASOURCE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + List applicablePaths = ImmutableList.of( + "druid/coordinator/v1/datasources/", + "druid/coordinator/v1/metadata/datasources/", + "druid/v2/datasources/" + ); + for (String path : applicablePaths) { + if(requestPath.startsWith(path) && !requestPath.equals(path)) { + return true; + } + } + return false; + } +} diff --git a/server/src/main/java/io/druid/server/http/security/RulesResourceFilter.java b/server/src/main/java/io/druid/server/http/security/RulesResourceFilter.java new file mode 100644 index 000000000000..0e87fab200fe --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/RulesResourceFilter.java @@ -0,0 +1,106 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.common.base.Predicate; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.PathSegment; +import javax.ws.rs.core.Response; +import java.util.List; + + +/** + * Use this ResourceFilter when the datasource information is present after "rules" segment in the request Path + * Here are some example paths where this filter is used - + * - druid/coordinator/v1/rules/ + * */ + +public class RulesResourceFilter extends AbstractResourceFilter +{ + @Inject + public RulesResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String dataSourceName = request.getPathSegments() + .get( + Iterables.indexOf( + request.getPathSegments(), + new Predicate() + { + @Override + public boolean apply(PathSegment input) + { + return input.getPath().equals("rules"); + } + } + ) + 1 + ).getPath(); + Preconditions.checkNotNull(dataSourceName); + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + final Access authResult = authorizationInfo.isAuthorized( + new Resource(dataSourceName, ResourceType.DATASOURCE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + + return request; + } + + @Override + public boolean isApplicable(String requestPath) + { + List applicablePaths = ImmutableList.of("druid/coordinator/v1/rules/"); + for (String path : applicablePaths) { + if(requestPath.startsWith(path) && !requestPath.equals(path)) { + return true; + } + } + return false; + } +} diff --git a/server/src/main/java/io/druid/server/http/security/StateResourceFilter.java b/server/src/main/java/io/druid/server/http/security/StateResourceFilter.java new file mode 100644 index 000000000000..b4d9d40195f4 --- /dev/null +++ b/server/src/main/java/io/druid/server/http/security/StateResourceFilter.java @@ -0,0 +1,97 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Preconditions; +import com.google.inject.Inject; +import com.sun.jersey.spi.container.ContainerRequest; +import io.druid.server.security.Access; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import io.druid.server.security.ResourceType; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; + +/** + * Use this ResourceFilter at end points where Druid Cluster State is read or written + * Here are some example paths where this filter is used - + * - druid/broker/v1 + * - druid/coordinator/v1 + * - druid/historical/v1 + * - druid/indexer/v1 + * - druid/coordinator/v1/rules + * - druid/coordinator/v1/tiers + * - druid/worker/v1 + * - druid/coordinator/v1/servers + * - status + * Note - Currently the resource name for all end points is set to "STATE" however if more fine grained access control + * is required the resource name can be set to specific state properties. + */ +public class StateResourceFilter extends AbstractResourceFilter +{ + @Inject + public StateResourceFilter(AuthConfig authConfig) + { + super(authConfig); + } + + @Override + public ContainerRequest filter(ContainerRequest request) + { + if (getAuthConfig().isEnabled()) { + // This is an experimental feature, see - https://github.com/druid-io/druid/pull/2424 + final String resourceName = "STATE"; + final AuthorizationInfo authorizationInfo = (AuthorizationInfo) getReq().getAttribute(AuthConfig.DRUID_AUTH_TOKEN); + Preconditions.checkNotNull( + authorizationInfo, + "Security is enabled but no authorization info found in the request" + ); + + final Access authResult = authorizationInfo.isAuthorized( + new Resource(resourceName, ResourceType.STATE), + getAction(request) + ); + if (!authResult.isAllowed()) { + throw new WebApplicationException( + Response.status(Response.Status.FORBIDDEN) + .entity(String.format("Access-Check-Result: %s", authResult.toString())) + .build() + ); + } + } + + return request; + } + + public boolean isApplicable(String requestPath) + { + return requestPath.startsWith("druid/broker/v1") || + requestPath.startsWith("druid/coordinator/v1") || + requestPath.startsWith("druid/historical/v1") || + requestPath.startsWith("druid/indexer/v1") || + requestPath.startsWith("druid/coordinator/v1/rules") || + requestPath.startsWith("druid/coordinator/v1/tiers") || + requestPath.startsWith("druid/worker/v1") || + requestPath.startsWith("druid/coordinator/v1/servers") || + requestPath.startsWith("status"); + } +} diff --git a/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java b/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java index a0ad9b765b19..66fd4c1a6fda 100644 --- a/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java +++ b/server/src/main/java/io/druid/server/metrics/EventReceiverFirehoseMonitor.java @@ -28,11 +28,9 @@ import com.metamx.metrics.KeyedDiff; import com.metamx.metrics.MonitorUtils; import io.druid.query.DruidMetrics; -import io.druid.segment.realtime.firehose.EventReceiverFirehoseFactory; import java.util.Map; import java.util.Properties; -import java.util.concurrent.atomic.AtomicLong; public class EventReceiverFirehoseMonitor extends AbstractMonitor { diff --git a/server/src/main/java/io/druid/server/security/Access.java b/server/src/main/java/io/druid/server/security/Access.java new file mode 100644 index 000000000000..a70e579f3a4c --- /dev/null +++ b/server/src/main/java/io/druid/server/security/Access.java @@ -0,0 +1,51 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public class Access +{ + private final boolean allowed; + private String message; + + public Access(boolean allowed) { + this(allowed, ""); + } + + public Access(boolean allowed, String message) { + this.allowed = allowed; + this.message = message; + } + + public boolean isAllowed() { + return allowed; + } + + public Access setMessage(String message) + { + this.message = message; + return this; + } + + @Override + public String toString() + { + return String.format("Allowed:%s, Message:%s", allowed, message); + } +} diff --git a/server/src/main/java/io/druid/server/security/Action.java b/server/src/main/java/io/druid/server/security/Action.java new file mode 100644 index 000000000000..2b7606b58dd8 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/Action.java @@ -0,0 +1,26 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public enum Action +{ + READ, + WRITE +} diff --git a/server/src/main/java/io/druid/server/security/AuthConfig.java b/server/src/main/java/io/druid/server/security/AuthConfig.java new file mode 100644 index 000000000000..8ade4ce6c415 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/AuthConfig.java @@ -0,0 +1,85 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +import com.fasterxml.jackson.annotation.JsonCreator; +import com.fasterxml.jackson.annotation.JsonProperty; + +public class AuthConfig +{ + /** + * Use this String as the attribute name for the request attribute to pass {@link AuthorizationInfo} + * from the servlet filter to the jersey resource + * */ + public static final String DRUID_AUTH_TOKEN = "Druid-Auth-Token"; + + public AuthConfig() { + this(false); + } + + @JsonCreator + public AuthConfig( + @JsonProperty("enabled") boolean enabled + ){ + this.enabled = enabled; + } + /** + * If druid.auth.enabled is set to true then an implementation of AuthorizationInfo + * must be provided and it must be set as a request attribute possibly inside the servlet filter + * injected in the filter chain using your own extension + * */ + @JsonProperty + private final boolean enabled; + + public boolean isEnabled() + { + return enabled; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + AuthConfig that = (AuthConfig) o; + + return enabled == that.enabled; + + } + + @Override + public int hashCode() + { + return (enabled ? 1 : 0); + } + + @Override + public String toString() + { + return "AuthConfig{" + + "enabled=" + enabled + + '}'; + } +} diff --git a/server/src/main/java/io/druid/server/security/AuthorizationInfo.java b/server/src/main/java/io/druid/server/security/AuthorizationInfo.java new file mode 100644 index 000000000000..31097a935477 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/AuthorizationInfo.java @@ -0,0 +1,44 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +/** + * This interface should be used to store as well as process Authorization Information + * An extension can be used to inject servlet filter which will create objects of this type + * and set it as a request attribute with attribute name as {@link AuthConfig#DRUID_AUTH_TOKEN}. + * In the jersey resources if the authorization is enabled depending on {@link AuthConfig#enabled} + * the {@link #isAuthorized(Resource, Action)} method will be used to perform authorization checks + * */ +public interface AuthorizationInfo +{ + /** + * Perform authorization checks for the given {@link Resource} and {@link Action}. + * resource and action objects should be instantiated depending on + * the specific endPoint where the check is being performed. + * Modeling Principal and specific way of performing authorization checks is + * entirely implementation dependent. + * + * @param resource information about resource that is being accessed + * @param action action to be performed on the resource + * @return a {@link Access} object having {@link Access#allowed} set to true if authorized otherwise set to false + * and optionally {@link Access#message} set to appropriate message + * */ + Access isAuthorized(Resource resource, Action action); +} diff --git a/server/src/main/java/io/druid/server/security/Resource.java b/server/src/main/java/io/druid/server/security/Resource.java new file mode 100644 index 000000000000..d3c74fb52899 --- /dev/null +++ b/server/src/main/java/io/druid/server/security/Resource.java @@ -0,0 +1,69 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public class Resource +{ + private final String name; + private final ResourceType type; + + public Resource(String name, ResourceType type) + { + this.name = name; + this.type = type; + } + + public String getName() + { + return name; + } + + public ResourceType getType() + { + return type; + } + + @Override + public boolean equals(Object o) + { + if (this == o) { + return true; + } + if (o == null || getClass() != o.getClass()) { + return false; + } + + Resource resource = (Resource) o; + + if (!name.equals(resource.name)) { + return false; + } + return type == resource.type; + + } + + @Override + public int hashCode() + { + int result = name.hashCode(); + result = 31 * result + type.hashCode(); + return result; + } +} diff --git a/server/src/main/java/io/druid/server/security/ResourceType.java b/server/src/main/java/io/druid/server/security/ResourceType.java new file mode 100644 index 000000000000..818bf9ca947d --- /dev/null +++ b/server/src/main/java/io/druid/server/security/ResourceType.java @@ -0,0 +1,27 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.security; + +public enum ResourceType +{ + DATASOURCE, + CONFIG, + STATE +} diff --git a/server/src/test/java/io/druid/server/ClientInfoResourceTest.java b/server/src/test/java/io/druid/server/ClientInfoResourceTest.java index a81938a7284f..1436ab2534b2 100644 --- a/server/src/test/java/io/druid/server/ClientInfoResourceTest.java +++ b/server/src/test/java/io/druid/server/ClientInfoResourceTest.java @@ -47,6 +47,7 @@ import io.druid.client.selector.ServerSelector; import io.druid.query.TableDataSource; import io.druid.query.metadata.SegmentMetadataQueryConfig; +import io.druid.server.security.AuthConfig; import io.druid.timeline.DataSegment; import io.druid.timeline.VersionedIntervalTimeline; import io.druid.timeline.partition.NumberedShardSpec; @@ -411,7 +412,7 @@ private ClientInfoResource getResourceTestHelper( SegmentMetadataQueryConfig segmentMetadataQueryConfig ) { - return new ClientInfoResource(serverInventoryView, timelineServerView, segmentMetadataQueryConfig) + return new ClientInfoResource(serverInventoryView, timelineServerView, segmentMetadataQueryConfig, new AuthConfig()) { @Override protected DateTime getCurrentTime() diff --git a/server/src/test/java/io/druid/server/QueryResourceTest.java b/server/src/test/java/io/druid/server/QueryResourceTest.java index ed2b3f1091f4..dabd6b575e8c 100644 --- a/server/src/test/java/io/druid/server/QueryResourceTest.java +++ b/server/src/test/java/io/druid/server/QueryResourceTest.java @@ -20,9 +20,13 @@ package io.druid.server; import com.fasterxml.jackson.databind.ObjectMapper; +import com.google.common.base.Throwables; +import com.google.common.util.concurrent.ListenableFuture; +import com.google.common.util.concurrent.MoreExecutors; import com.metamx.common.guava.Sequence; import com.metamx.common.guava.Sequences; import com.metamx.emitter.service.ServiceEmitter; +import io.druid.concurrent.Execs; import io.druid.jackson.DefaultObjectMapper; import io.druid.query.Query; import io.druid.query.QueryRunner; @@ -31,9 +35,15 @@ import io.druid.server.initialization.ServerConfig; import io.druid.server.log.NoopRequestLogger; import io.druid.server.metrics.NoopServiceEmitter; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; import org.easymock.EasyMock; import org.joda.time.Interval; import org.joda.time.Period; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.BeforeClass; @@ -45,6 +55,8 @@ import java.io.ByteArrayInputStream; import java.io.IOException; import java.util.Map; +import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; /** * @@ -97,6 +109,9 @@ public QueryRunner getQueryRunnerForSegments( private static final ServiceEmitter noopServiceEmitter = new NoopServiceEmitter(); + private QueryResource queryResource; + private QueryManager queryManager; + @BeforeClass public static void staticSetup() { @@ -106,9 +121,19 @@ public static void staticSetup() @Before public void setup() { - EasyMock.expect(testServletRequest.getContentType()).andReturn(MediaType.APPLICATION_JSON); + EasyMock.expect(testServletRequest.getContentType()).andReturn(MediaType.APPLICATION_JSON).anyTimes(); EasyMock.expect(testServletRequest.getRemoteAddr()).andReturn("localhost").anyTimes(); - EasyMock.replay(testServletRequest); + queryManager = new QueryManager(); + queryResource = new QueryResource( + serverConfig, + jsonMapper, + jsonMapper, + testSegmentWalker, + new NoopServiceEmitter(), + new NoopRequestLogger(), + queryManager, + new AuthConfig() + ); } private static final String simpleTimeSeriesQuery = "{\n" @@ -129,42 +154,273 @@ public void setup() @Test public void testGoodQuery() throws IOException { - QueryResource queryResource = new QueryResource( + EasyMock.replay(testServletRequest); + Response response = queryResource.doPost( + new ByteArrayInputStream(simpleTimeSeriesQuery.getBytes("UTF-8")), + null /*pretty*/, + testServletRequest + ); + Assert.assertNotNull(response); + } + + @Test + public void testBadQuery() throws IOException + { + EasyMock.replay(testServletRequest); + Response response = queryResource.doPost( + new ByteArrayInputStream("Meka Leka Hi Meka Hiney Ho".getBytes("UTF-8")), + null /*pretty*/, + testServletRequest + ); + Assert.assertNotNull(response); + Assert.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); + } + + @Test + public void testSecuredQuery() throws Exception + { + EasyMock.expect(testServletRequest.getAttribute(EasyMock.anyString())).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + if (resource.getName().equals("allow")) { + return new Access(true); + } else { + return new Access(false); + } + } + } + ).times(2); + EasyMock.replay(testServletRequest); + + queryResource = new QueryResource( serverConfig, jsonMapper, jsonMapper, testSegmentWalker, new NoopServiceEmitter(), new NoopRequestLogger(), - new QueryManager() + queryManager, + new AuthConfig(true) ); - Response respone = queryResource.doPost( + + Response response = queryResource.doPost( new ByteArrayInputStream(simpleTimeSeriesQuery.getBytes("UTF-8")), null /*pretty*/, testServletRequest ); - Assert.assertNotNull(respone); + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); + + response = queryResource.doPost( + new ByteArrayInputStream("{\"queryType\":\"timeBoundary\", \"dataSource\":\"allow\"}".getBytes("UTF-8")), + null /*pretty*/, + testServletRequest + ); + + Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + } - @Test - public void testBadQuery() throws IOException + @Test(timeout = 60_000L) + public void testSecuredGetServer() throws Exception + { + final CountDownLatch waitForCancellationLatch = new CountDownLatch(1); + final CountDownLatch waitFinishLatch = new CountDownLatch(2); + final CountDownLatch startAwaitLatch = new CountDownLatch(1); + final CountDownLatch cancelledCountDownLatch = new CountDownLatch(1); + + EasyMock.expect(testServletRequest.getAttribute(EasyMock.anyString())).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + // READ action corresponds to the query + // WRITE corresponds to cancellation of query + if (action.equals(Action.READ)) { + try { + waitForCancellationLatch.await(); + } + catch (InterruptedException e) { + // When the query is cancelled the control will reach here, + // countdown the latch and rethrow the exception so that error response is returned for the query + cancelledCountDownLatch.countDown(); + Throwables.propagate(e); + } + return new Access(true); + } else { + return new Access(true); + } + } + } + ).times(2); + EasyMock.replay(testServletRequest); + + queryResource = new QueryResource( + serverConfig, + jsonMapper, + jsonMapper, + testSegmentWalker, + new NoopServiceEmitter(), + new NoopRequestLogger(), + queryManager, + new AuthConfig(true) + ); + + final String queryString = "{\"queryType\":\"timeBoundary\", \"dataSource\":\"allow\"," + + "\"context\":{\"queryId\":\"id_1\"}}"; + ObjectMapper mapper = new DefaultObjectMapper(); + Query query = mapper.readValue(queryString, Query.class); + + ListenableFuture future = MoreExecutors.listeningDecorator( + Execs.singleThreaded("test_query_resource_%s") + ).submit( + new Runnable() + { + @Override + public void run() + { + try { + startAwaitLatch.countDown(); + Response response = queryResource.doPost( + new ByteArrayInputStream(queryString.getBytes("UTF-8")), + null, + testServletRequest + ); + + Assert.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), response.getStatus()); + } + catch (IOException e) { + Throwables.propagate(e); + } + waitFinishLatch.countDown(); + } + } + ); + + queryManager.registerQuery(query, future); + startAwaitLatch.await(); + + Executors.newSingleThreadExecutor().submit( + new Runnable() + { + @Override + public void run() + { + Response response = queryResource.getServer("id_1", testServletRequest); + Assert.assertEquals(Response.Status.ACCEPTED.getStatusCode(), response.getStatus()); + waitForCancellationLatch.countDown(); + waitFinishLatch.countDown(); + } + } + ); + waitFinishLatch.await(); + cancelledCountDownLatch.await(); + } + + @Test(timeout = 60_000L) + public void testDenySecuredGetServer() throws Exception { + final CountDownLatch waitForCancellationLatch = new CountDownLatch(1); + final CountDownLatch waitFinishLatch = new CountDownLatch(2); + final CountDownLatch startAwaitLatch = new CountDownLatch(1); + + EasyMock.expect(testServletRequest.getAttribute(EasyMock.anyString())).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + // READ action corresponds to the query + // WRITE corresponds to cancellation of query + if (action.equals(Action.READ)) { + try { + waitForCancellationLatch.await(); + } + catch (InterruptedException e) { + Throwables.propagate(e); + } + return new Access(true); + } else { + // Deny access to cancel the query + return new Access(false); + } + } + } + ).times(2); + EasyMock.replay(testServletRequest); - QueryResource queryResource = new QueryResource( + queryResource = new QueryResource( serverConfig, jsonMapper, jsonMapper, testSegmentWalker, new NoopServiceEmitter(), new NoopRequestLogger(), - new QueryManager() + queryManager, + new AuthConfig(true) ); - Response respone = queryResource.doPost( - new ByteArrayInputStream("Meka Leka Hi Meka Hiney Ho".getBytes("UTF-8")), - null /*pretty*/, - testServletRequest + + final String queryString = "{\"queryType\":\"timeBoundary\", \"dataSource\":\"allow\"," + + "\"context\":{\"queryId\":\"id_1\"}}"; + ObjectMapper mapper = new DefaultObjectMapper(); + Query query = mapper.readValue(queryString, Query.class); + + ListenableFuture future = MoreExecutors.listeningDecorator( + Execs.singleThreaded("test_query_resource_%s") + ).submit( + new Runnable() + { + @Override + public void run() + { + try { + startAwaitLatch.countDown(); + Response response = queryResource.doPost( + new ByteArrayInputStream(queryString.getBytes("UTF-8")), + null, + testServletRequest + ); + Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + } + catch (IOException e) { + Throwables.propagate(e); + } + waitFinishLatch.countDown(); + } + } + ); + + queryManager.registerQuery(query, future); + startAwaitLatch.await(); + + Executors.newSingleThreadExecutor().submit( + new Runnable() + { + @Override + public void run() + { + Response response = queryResource.getServer("id_1", testServletRequest); + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); + waitForCancellationLatch.countDown(); + waitFinishLatch.countDown(); + } + } ); - Assert.assertNotNull(respone); - Assert.assertEquals(Response.Status.INTERNAL_SERVER_ERROR.getStatusCode(), respone.getStatus()); + waitFinishLatch.await(); + } + + @After + public void tearDown() + { + EasyMock.verify(testServletRequest); } } diff --git a/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java b/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java index 51f5cbb88527..71147cdaa7bb 100644 --- a/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java +++ b/server/src/test/java/io/druid/server/http/DatasourcesResourceTest.java @@ -25,6 +25,11 @@ import io.druid.client.DruidDataSource; import io.druid.client.DruidServer; import io.druid.client.indexing.IndexingServiceClient; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; import io.druid.timeline.DataSegment; import org.easymock.EasyMock; import org.joda.time.Interval; @@ -32,6 +37,7 @@ import org.junit.Before; import org.junit.Test; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; import java.util.ArrayList; import java.util.HashMap; @@ -47,10 +53,12 @@ public class DatasourcesResourceTest private DruidServer server; private List listDataSources; private List dataSegmentList; + private HttpServletRequest request; @Before public void setUp() { + request = EasyMock.createStrictMock(HttpServletRequest.class); inventoryView = EasyMock.createStrictMock(CoordinatorServerView.class); server = EasyMock.createStrictMock(DruidServer.class); dataSegmentList = new ArrayList<>(); @@ -94,8 +102,12 @@ public void setUp() ) ); listDataSources = new ArrayList<>(); - listDataSources.add(new DruidDataSource("datasource1", new HashMap()).addSegment("part1", dataSegmentList.get(0))); - listDataSources.add(new DruidDataSource("datasource2", new HashMap()).addSegment("part1", dataSegmentList.get(1))); + listDataSources.add( + new DruidDataSource("datasource1", new HashMap()).addSegment("part1", dataSegmentList.get(0)) + ); + listDataSources.add( + new DruidDataSource("datasource2", new HashMap()).addSegment("part1", dataSegmentList.get(1)) + ); } @Test @@ -108,8 +120,8 @@ public void testGetFullQueryableDataSources() throws Exception ImmutableList.of(server) ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); - Response response = datasourcesResource.getQueryableDataSources("full", null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); + Response response = datasourcesResource.getQueryableDataSources("full", null, request); Set result = (Set) response.getEntity(); DruidDataSource[] resultantDruidDataSources = new DruidDataSource[result.size()]; result.toArray(resultantDruidDataSources); @@ -117,7 +129,7 @@ public void testGetFullQueryableDataSources() throws Exception Assert.assertEquals(2, resultantDruidDataSources.length); Assert.assertArrayEquals(listDataSources.toArray(), resultantDruidDataSources); - response = datasourcesResource.getQueryableDataSources(null, null); + response = datasourcesResource.getQueryableDataSources(null, null, request); List result1 = (List) response.getEntity(); Assert.assertEquals(200, response.getStatus()); Assert.assertEquals(2, result1.size()); @@ -126,6 +138,53 @@ public void testGetFullQueryableDataSources() throws Exception EasyMock.verify(inventoryView, server); } + @Test + public void testSecuredGetFullQueryableDataSources() throws Exception + { + EasyMock.expect(server.getDataSources()).andReturn( + ImmutableList.of(listDataSources.get(0), listDataSources.get(1)) + ).atLeastOnce(); + EasyMock.expect(inventoryView.getInventory()).andReturn( + ImmutableList.of(server) + ).atLeastOnce(); + EasyMock.expect(request.getAttribute(AuthConfig.DRUID_AUTH_TOKEN)).andReturn( + new AuthorizationInfo() + { + @Override + public Access isAuthorized( + Resource resource, Action action + ) + { + if (resource.getName().equals("datasource1")) { + return new Access(true); + } else { + return new Access(false); + } + } + } + ).times(2); + EasyMock.replay(inventoryView, server, request); + + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig(true)); + Response response = datasourcesResource.getQueryableDataSources("full", null, request); + Set result = (Set) response.getEntity(); + DruidDataSource[] resultantDruidDataSources = new DruidDataSource[result.size()]; + result.toArray(resultantDruidDataSources); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals(1, resultantDruidDataSources.length); + Assert.assertArrayEquals(listDataSources.subList(0, 1).toArray(), resultantDruidDataSources); + + response = datasourcesResource.getQueryableDataSources(null, null, request); + List result1 = (List) response.getEntity(); + + Assert.assertEquals(200, response.getStatus()); + Assert.assertEquals(1, result1.size()); + Assert.assertTrue(result1.contains("datasource1")); + + EasyMock.verify(inventoryView, server, request); + } + @Test public void testGetSimpleQueryableDataSources() throws Exception { @@ -145,8 +204,8 @@ public void testGetSimpleQueryableDataSources() throws Exception ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); - Response response = datasourcesResource.getQueryableDataSources(null, "simple"); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); + Response response = datasourcesResource.getQueryableDataSources(null, "simple", request); Assert.assertEquals(200, response.getStatus()); List> results = (List>) response.getEntity(); int index = 0; @@ -172,7 +231,7 @@ public void testFullGetTheDataSource() throws Exception ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getTheDataSource("datasource1", "full"); DruidDataSource result = (DruidDataSource) response.getEntity(); Assert.assertEquals(200, response.getStatus()); @@ -189,7 +248,7 @@ public void testNullGetTheDataSource() throws Exception ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Assert.assertEquals(204, datasourcesResource.getTheDataSource("none", null).getStatus()); EasyMock.verify(inventoryView, server); } @@ -211,7 +270,7 @@ public void testSimpleGetTheDataSource() throws Exception ).atLeastOnce(); EasyMock.replay(inventoryView, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getTheDataSource("datasource1", null); Assert.assertEquals(200, response.getStatus()); Map> result = (Map>) response.getEntity(); @@ -250,7 +309,7 @@ public void testSimpleGetTheDataSourceManyTiers() throws Exception ).atLeastOnce(); EasyMock.replay(inventoryView, server, server2, server3); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getTheDataSource("datasource1", null); Assert.assertEquals(200, response.getStatus()); Map> result = (Map>) response.getEntity(); @@ -281,7 +340,7 @@ public void testGetSegmentDataSourceIntervals() List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-22T00:00:00.000Z/2010-01-23T00:00:00.000Z")); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getSegmentDataSourceIntervals("invalidDataSource", null, null); Assert.assertEquals(response.getEntity(), null); @@ -328,7 +387,7 @@ public void testGetSegmentDataSourceSpecificInterval() ).atLeastOnce(); EasyMock.replay(inventoryView); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, null, new AuthConfig()); Response response = datasourcesResource.getSegmentDataSourceSpecificInterval( "invalidDataSource", "2010-01-01/P1D", @@ -395,7 +454,7 @@ public void testDeleteDataSourceSpecificInterval() throws Exception EasyMock.expectLastCall().once(); EasyMock.replay(indexingServiceClient, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient, new AuthConfig()); Response response = datasourcesResource.deleteDataSourceSpecificInterval("datasource1", interval); Assert.assertEquals(200, response.getStatus()); @@ -407,7 +466,7 @@ public void testDeleteDataSourceSpecificInterval() throws Exception public void testDeleteDataSource() { IndexingServiceClient indexingServiceClient = EasyMock.createStrictMock(IndexingServiceClient.class); EasyMock.replay(indexingServiceClient, server); - DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient); + DatasourcesResource datasourcesResource = new DatasourcesResource(inventoryView, null, indexingServiceClient, new AuthConfig()); Response response = datasourcesResource.deleteDataSource("datasource", "true", "???"); Assert.assertEquals(400, response.getStatus()); Assert.assertNotNull(response.getEntity()); diff --git a/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java b/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java index b77842bff8dd..4fb50795c85a 100644 --- a/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java +++ b/server/src/test/java/io/druid/server/http/IntervalsResourceTest.java @@ -22,13 +22,16 @@ import com.google.common.collect.ImmutableList; import io.druid.client.DruidServer; import io.druid.client.InventoryView; +import io.druid.server.security.AuthConfig; import io.druid.timeline.DataSegment; import org.easymock.EasyMock; import org.joda.time.Interval; +import org.junit.After; import org.junit.Assert; import org.junit.Before; import org.junit.Test; +import javax.servlet.http.HttpServletRequest; import javax.ws.rs.core.Response; import java.util.ArrayList; import java.util.List; @@ -40,12 +43,15 @@ public class IntervalsResourceTest private InventoryView inventoryView; private DruidServer server; private List dataSegmentList; + private HttpServletRequest request; @Before public void setUp() { inventoryView = EasyMock.createStrictMock(InventoryView.class); server = EasyMock.createStrictMock(DruidServer.class); + request = EasyMock.createStrictMock(HttpServletRequest.class); + dataSegmentList = new ArrayList<>(); dataSegmentList.add( new DataSegment( @@ -103,9 +109,9 @@ public void testGetIntervals() List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); expectedIntervals.add(new Interval("2010-01-22T00:00:00.000Z/2010-01-23T00:00:00.000Z")); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getIntervals(); + Response response = intervalsResource.getIntervals(request); TreeMap>> actualIntervals = (TreeMap) response.getEntity(); Assert.assertEquals(2, actualIntervals.size()); Assert.assertEquals(expectedIntervals.get(1), actualIntervals.firstKey()); @@ -117,7 +123,6 @@ public void testGetIntervals() Assert.assertEquals(5L, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("size")); Assert.assertEquals(1, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("count")); - EasyMock.verify(inventoryView); } @Test @@ -130,16 +135,15 @@ public void testSimpleGetSpecificIntervals() List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", "simple", null); + Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", "simple", null, request); Map> actualIntervals = (Map) response.getEntity(); Assert.assertEquals(1, actualIntervals.size()); Assert.assertTrue(actualIntervals.containsKey(expectedIntervals.get(0))); Assert.assertEquals(25L, actualIntervals.get(expectedIntervals.get(0)).get("size")); Assert.assertEquals(2, actualIntervals.get(expectedIntervals.get(0)).get("count")); - EasyMock.verify(inventoryView); } @Test @@ -152,9 +156,9 @@ public void testFullGetSpecificIntervals() List expectedIntervals = new ArrayList<>(); expectedIntervals.add(new Interval("2010-01-01T00:00:00.000Z/2010-01-02T00:00:00.000Z")); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, "full"); + Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, "full", request); TreeMap>> actualIntervals = (TreeMap) response.getEntity(); Assert.assertEquals(1, actualIntervals.size()); Assert.assertEquals(expectedIntervals.get(0), actualIntervals.firstKey()); @@ -163,7 +167,6 @@ public void testFullGetSpecificIntervals() Assert.assertEquals(5L, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("size")); Assert.assertEquals(1, actualIntervals.get(expectedIntervals.get(0)).get("datasource2").get("count")); - EasyMock.verify(inventoryView); } @Test @@ -174,14 +177,19 @@ public void testGetSpecificIntervals() ).atLeastOnce(); EasyMock.replay(inventoryView); - IntervalsResource intervalsResource = new IntervalsResource(inventoryView); + IntervalsResource intervalsResource = new IntervalsResource(inventoryView, new AuthConfig()); - Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, null); + Response response = intervalsResource.getSpecificIntervals("2010-01-01T00:00:00.000Z/P1D", null, null, request); Map actualIntervals = (Map) response.getEntity(); Assert.assertEquals(2, actualIntervals.size()); Assert.assertEquals(25L, actualIntervals.get("size")); Assert.assertEquals(2, actualIntervals.get("count")); + } + + @After + public void tearDown() { EasyMock.verify(inventoryView); } + } diff --git a/server/src/test/java/io/druid/server/http/RulesResourceTest.java b/server/src/test/java/io/druid/server/http/RulesResourceTest.java index 283026f82cf6..d153397cee95 100644 --- a/server/src/test/java/io/druid/server/http/RulesResourceTest.java +++ b/server/src/test/java/io/druid/server/http/RulesResourceTest.java @@ -20,12 +20,10 @@ package io.druid.server.http; import com.google.common.collect.ImmutableList; - import io.druid.audit.AuditEntry; import io.druid.audit.AuditInfo; import io.druid.audit.AuditManager; import io.druid.metadata.MetadataRuleManager; - import org.easymock.EasyMock; import org.joda.time.DateTime; import org.joda.time.Interval; @@ -34,7 +32,6 @@ import org.junit.Test; import javax.ws.rs.core.Response; - import java.util.List; import java.util.Map; @@ -255,4 +252,5 @@ public void testGetAllDatasourcesRuleHistoryWithWrongCount() EasyMock.verify(auditManager); } + } diff --git a/server/src/test/java/io/druid/server/http/security/ResourceFilterTestHelper.java b/server/src/test/java/io/druid/server/http/security/ResourceFilterTestHelper.java new file mode 100644 index 000000000000..ae317314b21e --- /dev/null +++ b/server/src/test/java/io/druid/server/http/security/ResourceFilterTestHelper.java @@ -0,0 +1,245 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.base.Function; +import com.google.common.base.Predicate; +import com.google.common.collect.Collections2; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Binder; +import com.google.inject.Guice; +import com.google.inject.Injector; +import com.google.inject.Key; +import com.google.inject.Module; +import com.sun.jersey.spi.container.ContainerRequest; +import com.sun.jersey.spi.container.ResourceFilter; +import com.sun.jersey.spi.container.ResourceFilters; +import io.druid.server.security.Access; +import io.druid.server.security.Action; +import io.druid.server.security.AuthConfig; +import io.druid.server.security.AuthorizationInfo; +import io.druid.server.security.Resource; +import org.easymock.EasyMock; + +import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.DELETE; +import javax.ws.rs.GET; +import javax.ws.rs.POST; +import javax.ws.rs.Path; +import javax.ws.rs.core.MultivaluedMap; +import javax.ws.rs.core.PathSegment; +import java.lang.reflect.Method; +import java.util.Arrays; +import java.util.Collection; +import java.util.Collections; +import java.util.List; + +public class ResourceFilterTestHelper +{ + public HttpServletRequest req; + public AuthorizationInfo authorizationInfo; + public ContainerRequest request; + + public void setUp(ResourceFilter resourceFilter) throws Exception + { + req = EasyMock.createStrictMock(HttpServletRequest.class); + request = EasyMock.createStrictMock(ContainerRequest.class); + authorizationInfo = EasyMock.createStrictMock(AuthorizationInfo.class); + + // Memory barrier + synchronized (this) { + ((AbstractResourceFilter) resourceFilter).setReq(req); + } + } + + public void setUpMockExpectations( + String requestPath, + boolean authCheckResult, + String requestMethod + ) + { + EasyMock.expect(request.getPath()).andReturn(requestPath).anyTimes(); + EasyMock.expect(request.getPathSegments()).andReturn( + ImmutableList.copyOf( + Iterables.transform( + Arrays.asList(requestPath.split("/")), + new Function() + { + @Override + public PathSegment apply(final String input) + { + return new PathSegment() + { + @Override + public String getPath() + { + return input; + } + + @Override + public MultivaluedMap getMatrixParameters() + { + return null; + } + }; + } + } + ) + ) + ).anyTimes(); + EasyMock.expect(request.getMethod()).andReturn(requestMethod).anyTimes(); + EasyMock.expect(req.getAttribute(EasyMock.anyString())).andReturn(authorizationInfo).atLeastOnce(); + EasyMock.expect(authorizationInfo.isAuthorized( + EasyMock.anyObject(Resource.class), + EasyMock.anyObject(Action.class) + )).andReturn( + new Access(authCheckResult) + ).atLeastOnce(); + + } + + public static Collection getRequestPaths(final Class clazz) + { + return getRequestPaths(clazz, ImmutableList.>of(), ImmutableList.>of()); + } + + public static Collection getRequestPaths( + final Class clazz, + final Iterable> mockableInjections + ) + { + return getRequestPaths(clazz, mockableInjections, ImmutableList.>of()); + } + + public static Collection getRequestPaths( + final Class clazz, + final Iterable> mockableInjections, + final Iterable> mockableKeys + ) + { + return getRequestPaths(clazz, mockableInjections, mockableKeys, ImmutableList.of()); + } + + // Feeds in an array of [ PathName, MethodName, ResourceFilter , Injector] + public static Collection getRequestPaths( + final Class clazz, + final Iterable> mockableInjections, + final Iterable> mockableKeys, + final Iterable injectedObjs + ) + { + final Injector injector = Guice.createInjector( + new Module() + { + @Override + public void configure(Binder binder) + { + for (Class clazz : mockableInjections) { + binder.bind(clazz).toInstance(EasyMock.createNiceMock(clazz)); + } + for (Object obj : injectedObjs) { + binder.bind((Class) obj.getClass()).toInstance(obj); + } + for (Key key : mockableKeys) { + binder.bind((Key) key).toInstance(EasyMock.createNiceMock(key.getTypeLiteral().getRawType())); + } + binder.bind(AuthConfig.class).toInstance(new AuthConfig(true)); + } + } + ); + final String basepath = ((Path) clazz.getAnnotation(Path.class)).value().substring(1); //Ignore the first "/" + final List> baseResourceFilters = + clazz.getAnnotation(ResourceFilters.class) == null ? Collections.>emptyList() : + ImmutableList.copyOf(((ResourceFilters) clazz.getAnnotation(ResourceFilters.class)).value()); + + return ImmutableList.copyOf( + Iterables.concat( + // Step 3 - Merge all the Objects arrays for each endpoints + Iterables.transform( + // Step 2 - + // For each endpoint, make an Object array containing + // - Request Path like "druid/../../.." + // - Request Method like "GET" or "POST" or "DELETE" + // - Resource Filter instance for the endpoint + Iterables.filter( + // Step 1 - + // Filter out non resource endpoint methods + // and also the endpoints that does not have any + // ResourceFilters applied to them + ImmutableList.copyOf(clazz.getDeclaredMethods()), + new Predicate() + { + @Override + public boolean apply(Method input) + { + return input.getAnnotation(GET.class) != null + || input.getAnnotation(POST.class) != null + || input.getAnnotation(DELETE.class) != null + && (input.getAnnotation(ResourceFilters.class) != null + || !baseResourceFilters.isEmpty()); + } + } + ), + new Function>() + { + @Override + public Collection apply(final Method method) + { + final List> resourceFilters = + method.getAnnotation(ResourceFilters.class) == null ? baseResourceFilters : + ImmutableList.copyOf(method.getAnnotation(ResourceFilters.class).value()); + + return Collections2.transform( + resourceFilters, + new Function, Object[]>() + { + @Override + public Object[] apply(Class input) + { + if (method.getAnnotation(Path.class) != null) { + return new Object[]{ + String.format("%s%s", basepath, method.getAnnotation(Path.class).value()), + input.getAnnotation(GET.class) == null ? (method.getAnnotation(DELETE.class) == null + ? "POST" + : "DELETE") : "GET", + injector.getInstance(input), + injector + }; + } else { + return new Object[]{ + basepath, + input.getAnnotation(GET.class) == null ? (method.getAnnotation(DELETE.class) == null + ? "POST" + : "DELETE") : "GET", + injector.getInstance(input), + injector + }; + } + } + } + ); + } + } + ) + ) + ); + } +} diff --git a/server/src/test/java/io/druid/server/http/security/SecurityResourceFilterTest.java b/server/src/test/java/io/druid/server/http/security/SecurityResourceFilterTest.java new file mode 100644 index 000000000000..4a7cd0de8258 --- /dev/null +++ b/server/src/test/java/io/druid/server/http/security/SecurityResourceFilterTest.java @@ -0,0 +1,134 @@ +/* + * Licensed to Metamarkets Group Inc. (Metamarkets) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. Metamarkets licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package io.druid.server.http.security; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Iterables; +import com.google.inject.Injector; +import com.sun.jersey.spi.container.ResourceFilter; +import io.druid.server.ClientInfoResource; +import io.druid.server.QueryResource; +import io.druid.server.StatusResource; +import io.druid.server.http.BrokerResource; +import io.druid.server.http.CoordinatorDynamicConfigsResource; +import io.druid.server.http.CoordinatorResource; +import io.druid.server.http.DatasourcesResource; +import io.druid.server.http.HistoricalResource; +import io.druid.server.http.IntervalsResource; +import io.druid.server.http.MetadataResource; +import io.druid.server.http.RulesResource; +import io.druid.server.http.ServersResource; +import io.druid.server.http.TiersResource; +import org.easymock.EasyMock; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; + +import javax.ws.rs.WebApplicationException; +import javax.ws.rs.core.Response; +import java.util.Collection; + +@RunWith(Parameterized.class) +public class SecurityResourceFilterTest extends ResourceFilterTestHelper +{ + @Parameterized.Parameters + public static Collection data() + { + return ImmutableList.copyOf( + Iterables.concat( + getRequestPaths(CoordinatorResource.class), + getRequestPaths(DatasourcesResource.class), + getRequestPaths(BrokerResource.class), + getRequestPaths(HistoricalResource.class), + getRequestPaths(IntervalsResource.class), + getRequestPaths(MetadataResource.class), + getRequestPaths(RulesResource.class), + getRequestPaths(ServersResource.class), + getRequestPaths(TiersResource.class), + getRequestPaths(ClientInfoResource.class), + getRequestPaths(CoordinatorDynamicConfigsResource.class), + getRequestPaths(QueryResource.class), + getRequestPaths(StatusResource.class) + ) + ); + } + + private final String requestPath; + private final String requestMethod; + private final ResourceFilter resourceFilter; + private final Injector injector; + + public SecurityResourceFilterTest( + String requestPath, + String requestMethod, + ResourceFilter resourceFilter, + Injector injector + ) + { + this.requestPath = requestPath; + this.requestMethod = requestMethod; + this.resourceFilter = resourceFilter; + this.injector = injector; + } + + @Before + public void setUp() throws Exception + { + setUp(resourceFilter); + } + + @Test + public void testDatasourcesResourcesFilteringAccess() + { + setUpMockExpectations(requestPath, true, requestMethod); + EasyMock.replay(req, request, authorizationInfo); + Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + resourceFilter.getRequestFilter().filter(request); + EasyMock.verify(req, request, authorizationInfo); + } + + @Test(expected = WebApplicationException.class) + public void testDatasourcesResourcesFilteringNoAccess() + { + setUpMockExpectations(requestPath, false, requestMethod); + EasyMock.replay(req, request, authorizationInfo); + //Assert.assertTrue(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(requestPath)); + try { + resourceFilter.getRequestFilter().filter(request); + } + catch (WebApplicationException e) { + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), e.getResponse().getStatus()); + throw e; + } + EasyMock.verify(req, request, authorizationInfo); + } + + @Test + public void testDatasourcesResourcesFilteringBadPath() + { + EasyMock.replay(req, request, authorizationInfo); + final String badRequestPath = requestPath.replaceAll("\\w+", "droid"); + Assert.assertFalse(((AbstractResourceFilter) resourceFilter.getRequestFilter()).isApplicable(badRequestPath)); + EasyMock.verify(req, request, authorizationInfo); + } + +}