Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ISSUE #6968] fix grpc acl bug #6969

Merged
merged 5 commits into from
Jul 15, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -223,7 +223,7 @@ public static PlainAccessResource parse(GeneratedMessageV3 messageV3, Authentica
if (!request.hasGroup()) {
throw new AclException("Consumer heartbeat doesn't have group");
} else {
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
}
}
} else if (SendMessageRequest.getDescriptor().getFullName().equals(rpcFullName)) {
Expand All @@ -240,15 +240,15 @@ public static PlainAccessResource parse(GeneratedMessageV3 messageV3, Authentica
accessResource.addResourceAndPerm(topic, Permission.PUB);
} else if (ReceiveMessageRequest.getDescriptor().getFullName().equals(rpcFullName)) {
ReceiveMessageRequest request = (ReceiveMessageRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addResourceAndPerm(request.getMessageQueue().getTopic(), Permission.SUB);
} else if (AckMessageRequest.getDescriptor().getFullName().equals(rpcFullName)) {
AckMessageRequest request = (AckMessageRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addResourceAndPerm(request.getTopic(), Permission.SUB);
} else if (ForwardMessageToDeadLetterQueueRequest.getDescriptor().getFullName().equals(rpcFullName)) {
ForwardMessageToDeadLetterQueueRequest request = (ForwardMessageToDeadLetterQueueRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addResourceAndPerm(request.getTopic(), Permission.SUB);
} else if (EndTransactionRequest.getDescriptor().getFullName().equals(rpcFullName)) {
EndTransactionRequest request = (EndTransactionRequest) messageV3;
Expand All @@ -264,7 +264,7 @@ public static PlainAccessResource parse(GeneratedMessageV3 messageV3, Authentica
}
if (command.getSettings().hasSubscription()) {
Subscription subscription = command.getSettings().getSubscription();
accessResource.addResourceAndPerm(subscription.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(subscription.getGroup(), Permission.SUB);
for (SubscriptionEntry entry : subscription.getSubscriptionsList()) {
accessResource.addResourceAndPerm(entry.getTopic(), Permission.SUB);
}
Expand All @@ -275,17 +275,17 @@ public static PlainAccessResource parse(GeneratedMessageV3 messageV3, Authentica
}
} else if (NotifyClientTerminationRequest.getDescriptor().getFullName().equals(rpcFullName)) {
NotifyClientTerminationRequest request = (NotifyClientTerminationRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
} else if (QueryRouteRequest.getDescriptor().getFullName().equals(rpcFullName)) {
QueryRouteRequest request = (QueryRouteRequest) messageV3;
accessResource.addResourceAndPerm(request.getTopic(), Permission.ANY);
} else if (QueryAssignmentRequest.getDescriptor().getFullName().equals(rpcFullName)) {
QueryAssignmentRequest request = (QueryAssignmentRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addResourceAndPerm(request.getTopic(), Permission.SUB);
} else if (ChangeInvisibleDurationRequest.getDescriptor().getFullName().equals(rpcFullName)) {
ChangeInvisibleDurationRequest request = (ChangeInvisibleDurationRequest) messageV3;
accessResource.addResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addGroupResourceAndPerm(request.getGroup(), Permission.SUB);
accessResource.addResourceAndPerm(request.getTopic(), Permission.SUB);
}
} catch (Throwable t) {
Expand All @@ -299,6 +299,11 @@ private void addResourceAndPerm(Resource resource, byte permission) {
addResourceAndPerm(resourceName, permission);
}

private void addGroupResourceAndPerm(Resource resource, byte permission) {
String resourceName = NamespaceUtil.wrapNamespace(resource.getResourceNamespace(), resource.getName());
addResourceAndPerm(getRetryTopic(resourceName), permission);
}

public static PlainAccessResource build(PlainAccessConfig plainAccessConfig, RemoteAddressStrategy remoteAddressStrategy) {
PlainAccessResource plainAccessResource = new PlainAccessResource();
plainAccessResource.setAccessKey(plainAccessConfig.getAccessKey());
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,189 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.rocketmq.acl;

import java.io.File;
import java.io.IOException;
import java.nio.ByteBuffer;
import org.apache.rocketmq.acl.common.AclClientRPCHook;
import org.apache.rocketmq.acl.common.AclException;
import org.apache.rocketmq.acl.common.SessionCredentials;
import org.apache.rocketmq.acl.plain.AclTestHelper;
import org.apache.rocketmq.acl.plain.PlainAccessResource;
import org.apache.rocketmq.acl.plain.PlainAccessValidator;
import org.apache.rocketmq.remoting.exception.RemotingCommandException;
import org.apache.rocketmq.remoting.protocol.RemotingCommand;
import org.apache.rocketmq.remoting.protocol.RequestCode;
import org.apache.rocketmq.remoting.protocol.header.PullMessageRequestHeader;
import org.apache.rocketmq.remoting.protocol.header.SendMessageRequestHeader;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Test;

public class RemotingClientAccessTest {

private PlainAccessValidator plainAccessValidator;
private AclClientRPCHook aclClient;
private SessionCredentials sessionCredentials;

private File confHome;

private String clientAddress = "10.7.1.3";

@Before
public void init() throws IOException {
String folder = "access_acl_conf";
confHome = AclTestHelper.copyResources(folder, true);
System.setProperty("rocketmq.home.dir", confHome.getAbsolutePath());
System.setProperty("rocketmq.acl.plain.file", "/access_acl_conf/acl/plain_acl.yml".replace("/", File.separator));

plainAccessValidator = new PlainAccessValidator();
sessionCredentials = new SessionCredentials();
sessionCredentials.setAccessKey("rocketmq3");
sessionCredentials.setSecretKey("12345678");
aclClient = new AclClientRPCHook(sessionCredentials);
}

@After
public void cleanUp() {
AclTestHelper.recursiveDelete(confHome);
}

@Test(expected = AclException.class)
public void testProduceDenyTopic() {
SendMessageRequestHeader messageRequestHeader = new SendMessageRequestHeader();
messageRequestHeader.setTopic("topicD");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.SEND_MESSAGE, messageRequestHeader);
aclClient.doBeforeRequest(clientAddress, remotingCommand);

ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), clientAddress);
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}
}

@Test
public void testProduceAuthorizedTopic() {
SendMessageRequestHeader messageRequestHeader = new SendMessageRequestHeader();
messageRequestHeader.setTopic("topicA");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.SEND_MESSAGE, messageRequestHeader);
aclClient.doBeforeRequest(clientAddress, remotingCommand);

ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), clientAddress);
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}
}


@Test(expected = AclException.class)
public void testConsumeDenyTopic() {
PullMessageRequestHeader pullMessageRequestHeader = new PullMessageRequestHeader();
pullMessageRequestHeader.setTopic("topicD");
pullMessageRequestHeader.setConsumerGroup("groupB");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.PULL_MESSAGE, pullMessageRequestHeader);
aclClient.doBeforeRequest("", remotingCommand);
ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), "123.4.5.6");
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}

}

@Test
public void testConsumeAuthorizedTopic() {
PullMessageRequestHeader pullMessageRequestHeader = new PullMessageRequestHeader();
pullMessageRequestHeader.setTopic("topicB");
pullMessageRequestHeader.setConsumerGroup("groupB");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.PULL_MESSAGE, pullMessageRequestHeader);
aclClient.doBeforeRequest("", remotingCommand);
ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), "123.4.5.6");
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}
}

@Test(expected = AclException.class)
public void testConsumeInDeniedGroup() {
PullMessageRequestHeader pullMessageRequestHeader = new PullMessageRequestHeader();
pullMessageRequestHeader.setTopic("topicB");
pullMessageRequestHeader.setConsumerGroup("groupD");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.PULL_MESSAGE, pullMessageRequestHeader);
aclClient.doBeforeRequest("", remotingCommand);
ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), "123.4.5.6");
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}
}

@Test
public void testConsumeInAuthorizedGroup() {
PullMessageRequestHeader pullMessageRequestHeader = new PullMessageRequestHeader();
pullMessageRequestHeader.setTopic("topicB");
pullMessageRequestHeader.setConsumerGroup("groupB");
RemotingCommand remotingCommand = RemotingCommand.createRequestCommand(RequestCode.PULL_MESSAGE, pullMessageRequestHeader);
aclClient.doBeforeRequest("", remotingCommand);
ByteBuffer buf = remotingCommand.encodeHeader();
buf.getInt();
buf = ByteBuffer.allocate(buf.limit() - buf.position()).put(buf);
buf.position(0);
try {
PlainAccessResource accessResource = (PlainAccessResource) plainAccessValidator.parse(RemotingCommand.decode(buf), "123.4.5.6");
plainAccessValidator.validate(accessResource);
} catch (RemotingCommandException e) {
e.printStackTrace();
Assert.fail("Should not throw IOException");
}
}

}
31 changes: 31 additions & 0 deletions acl/src/test/resources/access_acl_conf/acl/plain_acl.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

accounts:
- accessKey: rocketmq3
secretKey: 12345678
admin: false
defaultTopicPerm: DENY
defaultGroupPerm: DENY
topicPerms:
- topicA=PUB
- topicB=SUB
- topicC=PUB|SUB
- topicD=DENY
groupPerms:
- groupB=SUB
- groupC=PUB|SUB
- groupD=DENY

1 change: 0 additions & 1 deletion acl/src/test/resources/conf/acl/plain_acl.yml
Original file line number Diff line number Diff line change
Expand Up @@ -41,4 +41,3 @@ accounts:
whiteRemoteAddress: 192.168.1.*
# if it is admin, it could access all resources
admin: true

1 change: 1 addition & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -147,6 +147,7 @@
<awaitility.version>4.1.0</awaitility.version>
<truth.version>0.30</truth.version>
<s3mock-junit4.version>2.11.0</s3mock-junit4.version>
<rocketmq-client-java.version>5.0.5</rocketmq-client-java.version>

<!-- Build plugin dependencies -->
<versions-maven-plugin.version>2.2</versions-maven-plugin.version>
Expand Down
17 changes: 15 additions & 2 deletions proxy/src/main/java/org/apache/rocketmq/proxy/ProxyStartup.java
Original file line number Diff line number Diff line change
Expand Up @@ -29,11 +29,14 @@
import org.apache.commons.cli.Option;
import org.apache.commons.cli.Options;
import org.apache.commons.lang3.StringUtils;
import org.apache.rocketmq.acl.AccessValidator;
import org.apache.rocketmq.acl.plain.PlainAccessValidator;
import org.apache.rocketmq.broker.BrokerController;
import org.apache.rocketmq.broker.BrokerStartup;
import org.apache.rocketmq.common.MixAll;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.common.thread.ThreadPoolMonitor;
import org.apache.rocketmq.common.utils.ServiceProvider;
import org.apache.rocketmq.logging.org.slf4j.Logger;
import org.apache.rocketmq.logging.org.slf4j.LoggerFactory;
import org.apache.rocketmq.common.utils.AbstractStartAndShutdown;
Expand Down Expand Up @@ -75,16 +78,17 @@ public static void main(String[] args) {

MessagingProcessor messagingProcessor = createMessagingProcessor();

List<AccessValidator> accessValidators = loadAccessValidators();
// create grpcServer
GrpcServer grpcServer = GrpcServerBuilder.newBuilder(executor, ConfigurationManager.getProxyConfig().getGrpcServerPort())
.addService(createServiceProcessor(messagingProcessor))
.addService(ChannelzService.newInstance(100))
.addService(ProtoReflectionService.newInstance())
.configInterceptor()
.configInterceptor(accessValidators)
.build();
PROXY_START_AND_SHUTDOWN.appendStartAndShutdown(grpcServer);

RemotingProtocolServer remotingServer = new RemotingProtocolServer(messagingProcessor);
RemotingProtocolServer remotingServer = new RemotingProtocolServer(messagingProcessor, accessValidators);
PROXY_START_AND_SHUTDOWN.appendStartAndShutdown(remotingServer);

// start servers one by one.
Expand All @@ -109,6 +113,15 @@ public static void main(String[] args) {
log.info(new Date() + " rocketmq-proxy startup successfully");
}

protected static List<AccessValidator> loadAccessValidators() {
List<AccessValidator> accessValidators = ServiceProvider.load(AccessValidator.class);
if (accessValidators.isEmpty()) {
log.info("ServiceProvider loaded no AccessValidator, using default org.apache.rocketmq.acl.plain.PlainAccessValidator");
accessValidators.add(new PlainAccessValidator());
}
return accessValidators;
}

protected static void initConfiguration(CommandLineArgument commandLineArgument) throws Exception {
if (StringUtils.isNotBlank(commandLineArgument.getProxyConfigPath())) {
System.setProperty(Configuration.CONFIG_PATH_PROPERTY, commandLineArgument.getProxyConfigPath());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -28,9 +28,7 @@
import java.util.concurrent.ThreadPoolExecutor;
import java.util.concurrent.TimeUnit;
import org.apache.rocketmq.acl.AccessValidator;
import org.apache.rocketmq.acl.plain.PlainAccessValidator;
import org.apache.rocketmq.common.constant.LoggerName;
import org.apache.rocketmq.common.utils.ServiceProvider;
import org.apache.rocketmq.logging.org.slf4j.Logger;
import org.apache.rocketmq.logging.org.slf4j.LoggerFactory;
import org.apache.rocketmq.proxy.config.ConfigurationManager;
Expand Down Expand Up @@ -98,14 +96,8 @@ public GrpcServer build() {
return new GrpcServer(this.serverBuilder.build());
}

public GrpcServerBuilder configInterceptor() {
public GrpcServerBuilder configInterceptor(List<AccessValidator> accessValidators) {
// grpc interceptors, including acl, logging etc.
List<AccessValidator> accessValidators = ServiceProvider.load(AccessValidator.class);
if (accessValidators.isEmpty()) {
log.info("ServiceProvider loaded no AccessValidator, using default org.apache.rocketmq.acl.plain.PlainAccessValidator");
accessValidators.add(new PlainAccessValidator());
}

this.serverBuilder.intercept(new AuthenticationInterceptor(accessValidators));

this.serverBuilder
Expand Down
Loading
Loading