Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -17,21 +17,25 @@

package org.apache.uniffle.coordinator;

import java.util.Collections;
import java.util.Map;
import java.util.Set;

import com.google.common.collect.Sets;

public class AccessInfo {
private final String accessId;
private final Set<String> tags;
private final Map<String, String> extraProperties;

public AccessInfo(String accessId, Set<String> tags) {
public AccessInfo(String accessId, Set<String> tags, Map<String, String> extraProperties) {
this.accessId = accessId;
this.tags = tags;
this.extraProperties = extraProperties == null ? Collections.emptyMap() : extraProperties;
}

public AccessInfo(String accessId) {
this(accessId, Sets.newHashSet());
this(accessId, Sets.newHashSet(), Collections.emptyMap());
}

public String getAccessId() {
Expand All @@ -42,11 +46,16 @@ public Set<String> getTags() {
return tags;
}

public Map<String, String> getExtraProperties() {
return extraProperties;
}

@Override
public String toString() {
return "AccessInfo{"
+ "accessId='" + accessId + '\''
+ ", tags=" + tags
+ '}';
+ "accessId='" + accessId + '\''
+ ", tags=" + tags
+ ", extraProperties=" + extraProperties
+ '}';
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,12 @@ public void accessCluster(AccessClusterRequest request, StreamObserver<AccessClu
AccessClusterResponse response;
AccessManager accessManager = coordinatorServer.getAccessManager();

AccessInfo accessInfo = new AccessInfo(request.getAccessId(), Sets.newHashSet(request.getTagsList()));
AccessInfo accessInfo =
new AccessInfo(
request.getAccessId(),
Sets.newHashSet(request.getTagsList()),
request.getExtraPropertiesMap()
);
AccessCheckResult result = accessManager.handleAccessRequest(accessInfo);
if (!result.isSuccess()) {
statusCode = StatusCode.ACCESS_DENIED;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@

package org.apache.uniffle.coordinator;

import java.util.Collections;
import java.util.Random;

import com.google.common.collect.Sets;
Expand Down Expand Up @@ -66,7 +67,7 @@ public void test() throws Exception {
AccessManager accessManager = new AccessManager(conf, null, new Configuration());
assertTrue(accessManager.handleAccessRequest(
new AccessInfo(String.valueOf(new Random().nextInt()),
Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION)))
Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), Collections.emptyMap()))
.isSuccess());
accessManager.close();
// test mock checkers
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,30 +18,98 @@
package org.apache.uniffle.test;

import java.io.FileWriter;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.File;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.TimeUnit;

import com.google.common.collect.Sets;
import com.google.common.util.concurrent.Uninterruptibles;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.io.TempDir;

import org.apache.uniffle.client.api.CoordinatorClient;
import org.apache.uniffle.client.factory.CoordinatorClientFactory;
import org.apache.uniffle.client.request.RssAccessClusterRequest;
import org.apache.uniffle.client.response.ResponseStatusCode;
import org.apache.uniffle.client.response.RssAccessClusterResponse;
import org.apache.uniffle.common.util.Constants;
import org.apache.uniffle.coordinator.AccessCheckResult;
import org.apache.uniffle.coordinator.AccessChecker;
import org.apache.uniffle.coordinator.AccessInfo;
import org.apache.uniffle.coordinator.AccessManager;
import org.apache.uniffle.coordinator.CoordinatorConf;
import org.apache.uniffle.server.ShuffleServer;
import org.apache.uniffle.server.ShuffleServerConf;
import org.junit.jupiter.api.io.TempDir;

import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertTrue;

public class AccessClusterTest extends CoordinatorTestBase {

public static class MockedAccessChecker implements AccessChecker {
final String key = "key";
final List<String> legalNames = Arrays.asList("v1", "v2", "v3");

public MockedAccessChecker(AccessManager accessManager) throws Exception {
// ignore
}

@Override
public AccessCheckResult check(AccessInfo accessInfo) {
Map<String, String> reservedData = accessInfo.getExtraProperties();
if (legalNames.contains(reservedData.get(key))) {
return new AccessCheckResult(true, "");
}
return new AccessCheckResult(false, "");
}

@Override
public void close() throws IOException {
// ignore.
}
}

@Test
public void testUsingCustomExtraProperties() throws Exception {
CoordinatorConf coordinatorConf = getCoordinatorConf();
coordinatorConf.setString(
"rss.coordinator.access.checkers",
"org.apache.uniffle.test.AccessClusterTest$MockedAccessChecker");
createCoordinatorServer(coordinatorConf);
startServers();
Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS);

// case1: empty map
String accessID = "acessid";
RssAccessClusterRequest request = new RssAccessClusterRequest(
accessID, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 2000);
RssAccessClusterResponse response = coordinatorClient.accessCluster(request);
assertEquals(ResponseStatusCode.ACCESS_DENIED, response.getStatusCode());

// case2: illegal names
Map<String, String> extraProperties = new HashMap<>();
extraProperties.put("key", "illegalName");
request = new RssAccessClusterRequest(
accessID, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 2000, extraProperties);
response = coordinatorClient.accessCluster(request);
assertEquals(ResponseStatusCode.ACCESS_DENIED, response.getStatusCode());

// case3: legal names
extraProperties.clear();
extraProperties.put("key", "v1");
request = new RssAccessClusterRequest(
accessID, Sets.newHashSet(Constants.SHUFFLE_SERVER_VERSION), 2000, extraProperties);
response = coordinatorClient.accessCluster(request);
assertEquals(ResponseStatusCode.SUCCESS, response.getStatusCode());

shutdownServers();
}

@Test
public void test(@TempDir File tempDir) throws Exception {
File cfgFile = File.createTempFile("tmp", ".conf", tempDir);
Expand All @@ -57,14 +125,13 @@ public void test(@TempDir File tempDir) throws Exception {
coordinatorConf.setInteger("rss.coordinator.access.loadChecker.serverNum.threshold", 2);
coordinatorConf.setString("rss.coordinator.access.candidates.path", cfgFile.getAbsolutePath());
coordinatorConf.setString(
"rss.coordinator.access.checkers",
"org.apache.uniffle.coordinator.AccessCandidatesChecker,org.apache.uniffle.coordinator.AccessClusterLoadChecker");
"rss.coordinator.access.checkers",
"org.apache.uniffle.coordinator.AccessCandidatesChecker,org.apache.uniffle.coordinator.AccessClusterLoadChecker");
createCoordinatorServer(coordinatorConf);

ShuffleServerConf shuffleServerConf = getShuffleServerConf();
createShuffleServer(shuffleServerConf);
startServers();

Uninterruptibles.sleepUninterruptibly(3, TimeUnit.SECONDS);
String accessId = "111111";
RssAccessClusterRequest request = new RssAccessClusterRequest(
Expand Down Expand Up @@ -100,6 +167,7 @@ public void test(@TempDir File tempDir) throws Exception {
assertEquals(ResponseStatusCode.SUCCESS, response.getStatusCode());
assertTrue(response.getMessage().startsWith("SUCCESS"));
shuffleServer.stopServer();
shutdownServers();
}
}

Original file line number Diff line number Diff line change
Expand Up @@ -250,6 +250,7 @@ public RssAccessClusterResponse accessCluster(RssAccessClusterRequest request) {
.newBuilder()
.setAccessId(request.getAccessId())
.addAllTags(request.getTags())
.putAllExtraProperties(request.getExtraProperties())
.build();
AccessClusterResponse rpcResponse;
try {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,18 +17,37 @@

package org.apache.uniffle.client.request;

import java.util.Collections;
import java.util.Map;
import java.util.Set;

public class RssAccessClusterRequest {

private final String accessId;
private final Set<String> tags;
private final int timeoutMs;
/**
* The map is to pass the extra data to the coordinator and to
* extend more pluggable {@code AccessCheckers} easily.
*/
private final Map<String, String> extraProperties;

public RssAccessClusterRequest(String accessId, Set<String> tags, int timeoutMs) {
this.accessId = accessId;
this.tags = tags;
this.timeoutMs = timeoutMs;
this.extraProperties = Collections.emptyMap();
}

public RssAccessClusterRequest(
String accessId,
Set<String> tags,
int timeoutMs,
Map<String, String> extraProperties) {
this.accessId = accessId;
this.tags = tags;
this.timeoutMs = timeoutMs;
this.extraProperties = extraProperties;
}

public String getAccessId() {
Expand All @@ -42,4 +61,8 @@ public Set<String> getTags() {
public int getTimeoutMs() {
return timeoutMs;
}

public Map<String, String> getExtraProperties() {
return extraProperties;
}
}
1 change: 1 addition & 0 deletions proto/src/main/proto/Rss.proto
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,7 @@ message CheckServiceAvailableResponse {
message AccessClusterRequest {
string accessId = 1;
repeated string tags = 2;
map<string, string> extraProperties = 3;
}

message AccessClusterResponse {
Expand Down