Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions fe/fe-core/src/main/java/org/apache/doris/analysis/LoadStmt.java
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@
import com.google.common.base.Joiner;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.checkerframework.checker.nullness.qual.Nullable;
Expand Down Expand Up @@ -522,7 +523,6 @@ public void analyze(Analyzer analyzer) throws UserException {
user = ConnectContext.get().getQualifiedUser();
}


private String getProviderFromEndpoint() {
Map<String, String> properties = brokerDesc.getProperties();
for (Map.Entry<String, String> entry : properties.entrySet()) {
Expand All @@ -534,13 +534,18 @@ private String getProviderFromEndpoint() {
return S3Properties.S3_PROVIDER;
}

private String getBucketFromFilePath(String filePath) throws Exception {
private Pair<String, String> getBucketAndObjectFromPath(String filePath) throws UserException {
String[] parts = filePath.split("\\/\\/");
if (parts.length < 2) {
throw new Exception("filePath is not valid");
throw new UserException("Invalid file path format: " + filePath);
}

String[] bucketAndObject = parts[1].split("\\/", 2);
if (bucketAndObject.length < 2) {
throw new UserException("Cannot extract bucket and object from path: " + filePath);
}
String buckt = parts[1].split("\\/")[0];
return buckt;

return Pair.of(bucketAndObject[0], bucketAndObject[1]);
}

public String getComment() {
Expand Down Expand Up @@ -601,7 +606,11 @@ public RedirectStatus getRedirectStatus() {
private void checkEndpoint(String endpoint) throws UserException {
HttpURLConnection connection = null;
try {
String urlStr = "http://" + endpoint;
String urlStr = endpoint;
// Add default protocol if not specified
if (!endpoint.startsWith("http://") && !endpoint.startsWith("https://")) {
urlStr = "http://" + endpoint;
}
SecurityChecker.getInstance().startSSRFChecking(urlStr);
URL url = new URL(urlStr);
connection = (HttpURLConnection) url.openConnection();
Expand Down Expand Up @@ -636,9 +645,6 @@ public void checkS3Param() throws UserException {
&& brokerDescProperties.containsKey(S3Properties.Env.SECRET_KEY)
&& brokerDescProperties.containsKey(S3Properties.Env.REGION)) {
String endpoint = brokerDescProperties.get(S3Properties.Env.ENDPOINT);
endpoint = endpoint.replaceFirst("^http://", "");
endpoint = endpoint.replaceFirst("^https://", "");
brokerDescProperties.put(S3Properties.Env.ENDPOINT, endpoint);
checkWhiteList(endpoint);
if (AzureProperties.checkAzureProviderPropertyExist(brokerDescProperties)) {
return;
Expand All @@ -649,6 +655,8 @@ public void checkS3Param() throws UserException {
}

public void checkWhiteList(String endpoint) throws UserException {
endpoint = endpoint.replaceFirst("^http://", "");
endpoint = endpoint.replaceFirst("^https://", "");
List<String> whiteList = new ArrayList<>(Arrays.asList(Config.s3_load_endpoint_white_list));
whiteList.removeIf(String::isEmpty);
if (!whiteList.isEmpty() && !whiteList.contains(endpoint)) {
Expand All @@ -667,15 +675,19 @@ private void checkAkSk() throws UserException {
for (DataDescription dataDescription : dataDescriptions) {
for (String filePath : dataDescription.getFilePaths()) {
curFile = filePath;
String bucket = getBucketFromFilePath(filePath);
Pair<String, String> pair = getBucketAndObjectFromPath(filePath);
String bucket = pair.getLeft();
String object = pair.getRight();
objectInfo = new ObjectInfo(ObjectStoreInfoPB.Provider.valueOf(provider.toUpperCase()),
brokerDescProperties.get(S3Properties.Env.ACCESS_KEY),
brokerDescProperties.get(S3Properties.Env.SECRET_KEY),
bucket, brokerDescProperties.get(S3Properties.Env.ENDPOINT),
brokerDescProperties.get(S3Properties.Env.REGION), "");
remote = RemoteBase.newInstance(objectInfo);
// Verify read permissions by calling headObject() on the S3 object.
// RemoteBase#headObject does not throw exception if key does not exist.
remote.headObject("1");
remote.headObject(object);
// Verify list permissions by calling listObjects() on the S3 bucket.
remote.listObjects(null);
remote.close();
}
Expand Down
Loading