Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Auth for scan #1

Open
wants to merge 15 commits into
base: master
Choose a base branch
from
2 changes: 1 addition & 1 deletion client/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>com.vesoft</groupId>
<artifactId>nebula</artifactId>
<version>3.0-SNAPSHOT</version>
<version>3.7.0-auth</version>
</parent>

<modelVersion>4.0.0</modelVersion>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,14 @@

import com.vesoft.nebula.HostAddr;
import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.ResultSet;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.data.ValueWrapper;
import com.vesoft.nebula.client.graph.exception.AuthFailedException;
import com.vesoft.nebula.client.graph.exception.ClientServerIncompatibleException;
import com.vesoft.nebula.client.graph.exception.IOErrorException;
import com.vesoft.nebula.client.graph.net.AuthResult;
import com.vesoft.nebula.client.graph.net.SyncConnection;
import com.vesoft.nebula.client.meta.MetaManager;
import com.vesoft.nebula.client.storage.scan.PartScanInfo;
import com.vesoft.nebula.client.storage.scan.ScanEdgeResultIterator;
Expand All @@ -19,10 +26,13 @@
import com.vesoft.nebula.storage.ScanVertexRequest;
import com.vesoft.nebula.storage.VertexProp;
import java.io.Serializable;
import java.io.UnsupportedEncodingException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
Expand All @@ -41,6 +51,13 @@ public class StorageClient implements Serializable {
private boolean enableSSL = false;
private SSLParam sslParam = null;

private String user = null;
private String password = null;

private String graphAddress = null;

// the write list for users with read permission
private Map<String, List<String>> spaceLabelWriteList = null;
private String version = null;

/**
Expand Down Expand Up @@ -94,6 +111,24 @@ public StorageClient(List<HostAddress> addresses, int timeout, int connectionRet
}
}

public StorageClient setUser(String user) {
this.user = user;
return this;
}

public StorageClient setPassword(String password) {
this.password = password;
return this;
}

public String getGraphAddress() {
return graphAddress;
}

public void setGraphAddress(String graphAddress) {
this.graphAddress = graphAddress;
}

public StorageClient setVersion(String version) {
this.version = version;
return this;
Expand All @@ -105,6 +140,7 @@ public StorageClient setVersion(String version) {
* @return true if connect successfully.
*/
public boolean connect() throws Exception {
authUser();
connection.open(addresses.get(0), timeout, enableSSL, sslParam);
StoragePoolConfig config = new StoragePoolConfig();
config.setEnableSSL(enableSSL);
Expand Down Expand Up @@ -561,6 +597,15 @@ private ScanVertexResultIterator scanVertex(String spaceName,
partScanInfoSet.add(new PartScanInfo(part, new HostAddress(leader.getHost(),
leader.getPort())));
}

// check the user permission after the 'getLeader',
// if the space is not exist, 'getLeader' can throw it first.
if (!checkWriteList(spaceName, tagName)) {
throw new IllegalArgumentException(
String.format("user %s has no read permission for %s.%s", user, spaceName,
tagName));
}

List<HostAddress> addrs = new ArrayList<>();
for (HostAddr addr : metaManager.listHosts()) {
addrs.add(new HostAddress(addr.getHost(), addr.getPort()));
Expand Down Expand Up @@ -1004,6 +1049,10 @@ private ScanEdgeResultIterator scanEdge(String spaceName,
long endTime,
boolean allowPartSuccess,
boolean allowReadFromFollower) {
if (!checkWriteList(spaceName, edgeName)) {
throw new IllegalArgumentException(
String.format("user has no read permission for %s.%s", spaceName, edgeName));
}
if (spaceName == null || spaceName.trim().isEmpty()) {
throw new IllegalArgumentException("space name is empty.");
}
Expand Down Expand Up @@ -1139,6 +1188,110 @@ private long getEdgeId(String spaceName, String edgeName) {
return metaManager.getEdge(spaceName, edgeName).getEdge_type();
}


/**
* auth user with graphd server, and get the space and labels WriteList with read permission
* for user
*/
private void authUser() throws AuthFailedException, IOErrorException,
ClientServerIncompatibleException, UnsupportedEncodingException {
if (user == null || password == null || graphAddress == null) {
throw new IllegalArgumentException(
"the user,password,graphAddress can not be null,"
+ " please config them first by setXXX()");
}
SyncConnection graphConnection = new SyncConnection();
String[] graphAddrAndPort = graphAddress.split(":");
if (graphAddrAndPort.length != 2) {
throw new IllegalArgumentException("the graph address is invalid.");
}
if (sslParam == null) {
graphConnection.open(new HostAddress(graphAddrAndPort[0].trim(),
Integer.valueOf(graphAddrAndPort[1].trim())), timeout, false,
new HashMap<>(),
version);
} else {
graphConnection.open(new HostAddress(graphAddrAndPort[0].trim(),
Integer.valueOf(graphAddrAndPort[1].trim())), timeout, sslParam, false,
new HashMap<>(),
version);
}
AuthResult authResult = graphConnection.authenticate(user, password);
long sessionId = authResult.getSessionId();

if (user.equals("root")) {
return;
}

spaceLabelWriteList = new HashMap<>();
ResultSet resultSet = new ResultSet(
graphConnection.execute(sessionId, "DESC USER " + user),
authResult.getTimezoneOffset());
if (!resultSet.isSucceeded()) {
throw new RuntimeException("get spaces for user " + user + " failed, "
+ resultSet.getErrorMessage());
}
if (resultSet.isEmpty()) {
throw new RuntimeException("there's no space for user " + user + " to have permission"
+ " to access.");
}

for (int i = 0; i < resultSet.getRows().size(); i++) {
List<ValueWrapper> values = resultSet.rowValues(i).values();
String role = values.get(0).asString();
String space = values.get(1).asString();
if (!role.equalsIgnoreCase("BASIC")) {
spaceLabelWriteList.put(space, null);
} else {
List<String> labels = new ArrayList<>();
// get the tags and edges that the user has read permission for
String showGrants = String.format("USE %s; show grants %s", space, user);
ResultSet userGrantResult = new ResultSet(graphConnection.execute(sessionId,
showGrants),
authResult.getTimezoneOffset());
if (!userGrantResult.isSucceeded()) {
throw new RuntimeException("get tags for user " + user
+ " failed, " + userGrantResult.getErrorMessage());
}
List<ValueWrapper> readTags = userGrantResult.colValues("READ(TAG)");
if (!readTags.isEmpty()) {
for (ValueWrapper v : readTags.get(0).asList()) {
labels.add(v.asString());
}
}
List<ValueWrapper> readEdges = userGrantResult.colValues("READ(EDGE)");
if (!readEdges.isEmpty()) {
for (ValueWrapper v : readEdges.get(0).asList()) {
labels.add(v.asString());
}
}
spaceLabelWriteList.put(space, labels);
}
}
}


/**
* check if the space and the label is in the WriteList
*
* @param spaceName space name
* @param label tag name or edge type name
* @return true if spaceName and label in the WriteList
*/
private boolean checkWriteList(String spaceName, String label) {
if (spaceLabelWriteList == null) {
return true;
}
if (!spaceLabelWriteList.containsKey(spaceName)) {
return false;
}
if (spaceLabelWriteList.get(spaceName) != null
&& !spaceLabelWriteList.get(spaceName).contains(label)) {
return false;
}
return true;
}

private static final int DEFAULT_LIMIT = 1000;
private static final long DEFAULT_START_TIME = 0;
private static final long DEFAULT_END_TIME = Long.MAX_VALUE;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@
import com.vesoft.nebula.client.graph.data.CASignedSSLParam;
import com.vesoft.nebula.client.graph.data.HostAddress;
import com.vesoft.nebula.client.graph.data.SSLParam;
import com.vesoft.nebula.client.graph.data.SelfSignedSSLParam;
import com.vesoft.nebula.client.storage.data.EdgeRow;
import com.vesoft.nebula.client.storage.data.EdgeTableRow;
import com.vesoft.nebula.client.storage.data.VertexRow;
Expand All @@ -17,8 +16,6 @@
import com.vesoft.nebula.client.storage.scan.ScanEdgeResultIterator;
import com.vesoft.nebula.client.storage.scan.ScanVertexResult;
import com.vesoft.nebula.client.storage.scan.ScanVertexResultIterator;
import com.vesoft.nebula.client.util.ProcessUtil;
import java.io.IOException;
import java.io.UnsupportedEncodingException;
import java.util.Arrays;
import java.util.List;
Expand All @@ -43,6 +40,9 @@ public void before() {
assert (false);
}
client = new StorageClient(address);
client.setGraphAddress(ip + ":9669");
client.setUser("root");
client.setPassword("nebula");
}

@After
Expand All @@ -57,6 +57,9 @@ public void testStorageClientWithVersionInWhiteList() {
List<HostAddress> address = Arrays.asList(new HostAddress(ip, 9559));
StorageClient storageClient = new StorageClient(address);
try {
storageClient.setGraphAddress("127.0.0.1:9669");
storageClient.setUser("root");
storageClient.setPassword("nebula");
storageClient.setVersion("3.0.0");
assert (storageClient.connect());

Expand All @@ -73,6 +76,9 @@ public void testStorageClientWithVersionNotInWhiteList() {
List<HostAddress> address = Arrays.asList(new HostAddress(ip, 9559));
StorageClient storageClient = new StorageClient(address);
try {
storageClient.setGraphAddress("127.0.0.1:9669");
storageClient.setUser("root");
storageClient.setPassword("nebula");
storageClient.setVersion("INVALID_VERSION");
storageClient.connect();
assert false;
Expand Down Expand Up @@ -417,13 +423,18 @@ public void testCASignedSSL() {
"src/test/resources/ssl/client.crt",
"src/test/resources/ssl/client.key");
sslClient = new StorageClient(address, 1000, 1, 1, true, sslParam);
sslClient.setGraphAddress("127.0.0.1:8669");
sslClient.setUser("root");
sslClient.setPassword("nebula");
sslClient.setVersion("3.0.0");
sslClient.connect();

ScanVertexResultIterator resultIterator = sslClient.scanVertex(
"testStorageCA",
"person");
assertIterator(resultIterator);
} catch (Exception e) {
System.out.println("scan failed for cs ssl." + e.getMessage());
e.printStackTrace();
assert (false);
} finally {
Expand Down
2 changes: 1 addition & 1 deletion examples/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<parent>
<groupId>com.vesoft</groupId>
<artifactId>nebula</artifactId>
<version>3.0-SNAPSHOT</version>
<version>3.7.0-auth</version>
</parent>

<modelVersion>4.0.0</modelVersion>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,10 @@ public static void main(String[] args) {
// input params are the metad's ip and port
StorageClient client = new StorageClient("127.0.0.1", 9559);
try {
client.setGraphAddress("127.0.0.1:9669");
client.setUser("root");
client.setPassword("nebula");
client.setVersion("test");
client.connect();
} catch (Exception e) {
LOGGER.error("storage client connect error, ", e);
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
<groupId>com.vesoft</groupId>
<artifactId>nebula</artifactId>
<packaging>pom</packaging>
<version>3.0-SNAPSHOT</version>
<version>3.7.0-auth</version>

<properties>
<project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
Expand Down
Loading