Skip to content

Commit

Permalink
Use nimbus-oauth2-oidc instead of scribe-java
Browse files Browse the repository at this point in the history
Nimbus OAuth 2.0 SDK, unlike scribe-java, includes first-class support
for OpenID connect extensions. Also, the API looks like a better fit for
Trino. Instead of proving the service layer with methods such as:
`getToken` `oauth2-oidc-sdk` provides a convenient way to build
a request which then can be easily translated to Airlift's HttpRequest.
  • Loading branch information
lukasz-walkiewicz authored and kokosing committed May 9, 2022
1 parent 49348fd commit 725cbb2
Show file tree
Hide file tree
Showing 17 changed files with 700 additions and 735 deletions.
20 changes: 10 additions & 10 deletions core/trino-main/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -207,16 +207,6 @@
<artifactId>oshi-core</artifactId>
</dependency>

<dependency>
<groupId>com.github.scribejava</groupId>
<artifactId>scribejava-apis</artifactId>
</dependency>

<dependency>
<groupId>com.github.scribejava</groupId>
<artifactId>scribejava-core</artifactId>
</dependency>

<dependency>
<groupId>com.google.code.findbugs</groupId>
<artifactId>jsr305</artifactId>
Expand All @@ -237,6 +227,16 @@
<artifactId>guice</artifactId>
</dependency>

<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>nimbus-jose-jwt</artifactId>
</dependency>

<dependency>
<groupId>com.nimbusds</groupId>
<artifactId>oauth2-oidc-sdk</artifactId>
</dependency>

<dependency>
<groupId>com.teradata</groupId>
<artifactId>re2j-td</artifactId>
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,137 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.server.security.oauth2;

import com.google.common.collect.ImmutableMultimap;
import com.nimbusds.jose.util.Resource;
import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.http.HTTPRequest;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;
import io.airlift.http.client.HttpClient;
import io.airlift.http.client.Request;
import io.airlift.http.client.Response;
import io.airlift.http.client.ResponseHandler;
import io.airlift.http.client.ResponseHandlerUtils;
import io.airlift.http.client.StringResponseHandler;

import javax.inject.Inject;
import javax.ws.rs.core.UriBuilder;

import java.io.IOException;
import java.net.URISyntaxException;
import java.net.URL;

import static com.google.common.net.HttpHeaders.CONTENT_TYPE;
import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.DELETE;
import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.GET;
import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.POST;
import static com.nimbusds.oauth2.sdk.http.HTTPRequest.Method.PUT;
import static io.airlift.http.client.Request.Builder.prepareGet;
import static io.airlift.http.client.StaticBodyGenerator.createStaticBodyGenerator;
import static io.airlift.http.client.StringResponseHandler.createStringResponseHandler;
import static java.lang.String.format;
import static java.nio.charset.StandardCharsets.UTF_8;
import static java.util.Objects.requireNonNull;

public class NimbusAirliftHttpClient
implements NimbusHttpClient
{
private final HttpClient httpClient;

@Inject
public NimbusAirliftHttpClient(@ForOAuth2 HttpClient httpClient)
{
this.httpClient = requireNonNull(httpClient, "httpClient is null");
}

@Override
public Resource retrieveResource(URL url)
throws IOException
{
try {
StringResponseHandler.StringResponse response = httpClient.execute(
prepareGet().setUri(url.toURI()).build(),
createStringResponseHandler());
return new Resource(response.getBody(), response.getHeader(CONTENT_TYPE));
}
catch (URISyntaxException e) {
throw new RuntimeException(e);
}
}

@Override
public <T> T execute(com.nimbusds.oauth2.sdk.Request nimbusRequest, Parser<T> parser)
{
HTTPRequest httpRequest = nimbusRequest.toHTTPRequest();
HTTPRequest.Method method = httpRequest.getMethod();

Request.Builder request = new Request.Builder()
.setMethod(method.name())
.setFollowRedirects(httpRequest.getFollowRedirects());

UriBuilder url = UriBuilder.fromUri(httpRequest.getURI());
if (method.equals(GET) || method.equals(DELETE)) {
httpRequest.getQueryParameters().forEach((key, value) -> url.queryParam(key, value.toArray()));
}

url.fragment(httpRequest.getFragment());

request.setUri(url.build());

ImmutableMultimap.Builder<String, String> headers = ImmutableMultimap.builder();
httpRequest.getHeaderMap().forEach(headers::putAll);
request.addHeaders(headers.build());

if (method.equals(POST) || method.equals(PUT)) {
String query = httpRequest.getQuery();
if (query != null) {
request.setBodyGenerator(createStaticBodyGenerator(httpRequest.getQuery(), UTF_8));
}
}
return httpClient.execute(request.build(), new NimbusResponseHandler<>(parser));
}

public static class NimbusResponseHandler<T>
implements ResponseHandler<T, RuntimeException>
{
private final StringResponseHandler handler = createStringResponseHandler();
private final Parser<T> parser;

public NimbusResponseHandler(Parser<T> parser)
{
this.parser = requireNonNull(parser, "parser is null");
}

@Override
public T handleException(Request request, Exception exception)
{
throw ResponseHandlerUtils.propagate(request, exception);
}

@Override
public T handle(Request request, Response response)
{
StringResponseHandler.StringResponse stringResponse = handler.handle(request, response);
HTTPResponse nimbusResponse = new HTTPResponse(response.getStatusCode());
response.getHeaders().asMap().forEach((name, values) -> nimbusResponse.setHeader(name.toString(), values.toArray(new String[0])));
nimbusResponse.setContent(stringResponse.getBody());
try {
return parser.parse(nimbusResponse);
}
catch (ParseException e) {
throw new RuntimeException(format("Unable to parse response status=[%d], body=[%s]", stringResponse.getStatusCode(), stringResponse.getBody()), e);
}
}
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
/*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package io.trino.server.security.oauth2;

import com.nimbusds.jose.util.ResourceRetriever;
import com.nimbusds.oauth2.sdk.ParseException;
import com.nimbusds.oauth2.sdk.Request;
import com.nimbusds.oauth2.sdk.http.HTTPResponse;

public interface NimbusHttpClient
extends ResourceRetriever
{
<T> T execute(Request nimbusRequest, Parser<T> parser);

interface Parser<T>
{
T parse(HTTPResponse response)
throws ParseException;
}
}
Loading

0 comments on commit 725cbb2

Please sign in to comment.