Skip to content

Commit

Permalink
Added ability to override default timeouts for OkHttpClient (#206)
Browse files Browse the repository at this point in the history
* Added ability to override default timeouts for OkHttpClient

* Update ability to override OkHttp connect, read, and write timeouts individually

* fix formatting on test params
  • Loading branch information
nolivermke authored and lbalmaceda committed Dec 3, 2018
1 parent 511f2e0 commit 34bc106
Show file tree
Hide file tree
Showing 6 changed files with 118 additions and 16 deletions.
43 changes: 42 additions & 1 deletion auth0/src/main/java/com/auth0/android/Auth0.java
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,9 @@ public class Auth0 {
private boolean oidcConformant;
private boolean loggingEnabled;
private boolean tls12Enforced;

private int connectTimeoutInSeconds;
private int readTimeoutInSeconds;
private int writeTimeoutInSeconds;
/**
* Creates a new Auth0 instance with the 'com_auth0_client_id' and 'com_auth0_domain' values
* defined in the project String resources file.
Expand Down Expand Up @@ -141,6 +143,21 @@ public Telemetry getTelemetry() {
return telemetry;
}

/**
* @return Auth0 request connectTimeoutInSeconds
*/
public int getConnectTimeoutInSeconds(){ return connectTimeoutInSeconds; }

/**
* @return Auth0 request readTimeoutInSeconds
*/
public int getReadTimeoutInSeconds(){ return readTimeoutInSeconds; }

/**
* @return Auth0 request writeTimeoutInSeconds
*/
public int getWriteTimeoutInSeconds(){ return writeTimeoutInSeconds; }

/**
* Setter for the Telemetry to send in every request to Auth0.
*
Expand Down Expand Up @@ -223,6 +240,30 @@ public void setTLS12Enforced(boolean enforced) {
tls12Enforced = enforced;
}

/**
* Override default connection timeout for requests
* @param timeout
*/
public void setConnectTimeoutInSeconds(int timeout){
this.connectTimeoutInSeconds = timeout;
}

/**
* Override default read timeout for requests
* @param timeout
*/
public void setReadTimeoutInSeconds(int timeout){
this.readTimeoutInSeconds = timeout;
}

/**
* Override default write timeout for requests
* @param timeout
*/
public void setWriteTimeoutInSeconds(int timeout){
this.writeTimeoutInSeconds = timeout;
}

private HttpUrl resolveConfiguration(@Nullable String configurationDomain, @NonNull HttpUrl domainUrl) {
HttpUrl url = ensureValidUrl(configurationDomain);
if (url == null) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,11 @@ public AuthenticationAPIClient(Context context) {

private AuthenticationAPIClient(Auth0 auth0, RequestFactory factory, OkHttpClientFactory clientFactory, Gson gson) {
this.auth0 = auth0;
this.client = clientFactory.createClient(auth0.isLoggingEnabled(), auth0.isTLS12Enforced());
this.client = clientFactory.createClient(auth0.isLoggingEnabled(),
auth0.isTLS12Enforced(),
auth0.getConnectTimeoutInSeconds(),
auth0.getReadTimeoutInSeconds(),
auth0.getWriteTimeoutInSeconds());
this.gson = gson;
this.factory = factory;
this.authErrorBuilder = new AuthenticationErrorBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -101,7 +101,11 @@ public UsersAPIClient(Context context, String token) {

private UsersAPIClient(Auth0 auth0, RequestFactory factory, OkHttpClientFactory clientFactory, Gson gson) {
this.auth0 = auth0;
client = clientFactory.createClient(auth0.isLoggingEnabled(), auth0.isTLS12Enforced());
client = clientFactory.createClient(auth0.isLoggingEnabled(),
auth0.isTLS12Enforced(),
auth0.getConnectTimeoutInSeconds(),
auth0.getReadTimeoutInSeconds(),
auth0.getWriteTimeoutInSeconds());
this.gson = gson;
this.factory = factory;
this.mgmtErrorBuilder = new ManagementErrorBuilder();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.TimeUnit;

import javax.net.ssl.SSLContext;

Expand All @@ -34,20 +35,32 @@ public class OkHttpClientFactory {
*
* @param loggingEnabled Enable logging in the created OkHttpClient.
* @param tls12Enforced Enforce TLS 1.2 in the created OkHttpClient on devices with API 16-21
* @param connectTimeout Override default connect timeout for OkHttpClient
* @param readTimeout Override default read timeout for OkHttpClient
* @param writeTimeout Override default write timeout for OkHttpClient
* @return new OkHttpClient instance created according to the parameters.
*/
public OkHttpClient createClient(boolean loggingEnabled, boolean tls12Enforced) {
return modifyClient(new OkHttpClient(), loggingEnabled, tls12Enforced);
public OkHttpClient createClient(boolean loggingEnabled, boolean tls12Enforced, int connectTimeout, int readTimeout, int writeTimeout) {
return modifyClient(new OkHttpClient(), loggingEnabled, tls12Enforced, connectTimeout, readTimeout, writeTimeout);
}

@VisibleForTesting
OkHttpClient modifyClient(OkHttpClient client, boolean loggingEnabled, boolean tls12Enforced) {
OkHttpClient modifyClient(OkHttpClient client, boolean loggingEnabled, boolean tls12Enforced, int connectTimeout, int readTimeout, int writeTimeout) {
if (loggingEnabled) {
enableLogging(client);
}
if (tls12Enforced) {
enforceTls12(client);
}
if(connectTimeout > 0){
client.setConnectTimeout(connectTimeout, TimeUnit.SECONDS);
}
if(readTimeout > 0){
client.setReadTimeout(readTimeout, TimeUnit.SECONDS);
}
if(writeTimeout > 0){
client.setWriteTimeout(writeTimeout, TimeUnit.SECONDS);
}
client.setProtocols(Arrays.asList(Protocol.HTTP_1_1, Protocol.SPDY_3));
return client;
}
Expand Down
24 changes: 24 additions & 0 deletions auth0/src/test/java/com/auth0/android/Auth0Test.java
Original file line number Diff line number Diff line change
Expand Up @@ -129,6 +129,30 @@ public void shouldNotHaveTLS12Enforced() throws Exception {
assertThat(auth0.isTLS12Enforced(), is(false));
}

@Test
public void shouldHaveConnectTimeout() throws Exception {
Auth0 auth0 = new Auth0(CLIENT_ID, DOMAIN);
auth0.setConnectTimeoutInSeconds(5);

assertThat(auth0.getConnectTimeoutInSeconds(), is(5));
}

@Test
public void shouldReadHaveTimeout() throws Exception {
Auth0 auth0 = new Auth0(CLIENT_ID, DOMAIN);
auth0.setReadTimeoutInSeconds(15);

assertThat(auth0.getReadTimeoutInSeconds(), is(15));
}

@Test
public void shouldHaveWriteTimeout() throws Exception {
Auth0 auth0 = new Auth0(CLIENT_ID, DOMAIN);
auth0.setWriteTimeoutInSeconds(20);

assertThat(auth0.getWriteTimeoutInSeconds(), is(20));
}

@Test
public void shouldNotHaveLoggingEnabledByDefault() throws Exception {
Auth0 auth0 = new Auth0(CLIENT_ID, DOMAIN);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -48,24 +48,40 @@ public void setUp() {
@Test
// Verify that there's no error when creating a new OkHttpClient instance
public void shouldCreateNewClient() {
factory.createClient(false, false);
factory.createClient(false, false, 0, 0, 0);
}

@Test
public void shouldNotUseHttp2Protocol() {
OkHttpClient client = factory.createClient(false, false);
OkHttpClient client = factory.createClient(false, false, 0, 0, 0);
//Doesn't use default protocols
assertThat(client.getProtocols(), is(notNullValue()));
assertThat(client.getProtocols().contains(Protocol.HTTP_1_1), is(true));
assertThat(client.getProtocols().contains(Protocol.SPDY_3), is(true));
assertThat(client.getProtocols().contains(Protocol.HTTP_2), is(false));
}

@Test
public void shouldUseDefaultTimeoutWhenTimeoutZero() {
OkHttpClient client = factory.createClient(false, false, 0, 0, 0);
assertThat(client.getConnectTimeout(), is(10000));
assertThat(client.getReadTimeout(), is(10000));
assertThat(client.getWriteTimeout(), is(10000));
}

@Test
public void shouldUsePassedInTimeout() {
OkHttpClient client = factory.createClient(false, false, 5, 15, 20);
assertThat(client.getConnectTimeout(), is(5000));
assertThat(client.getReadTimeout(), is(15000));
assertThat(client.getWriteTimeout(), is(20000));
}

@Test
@Config(sdk = 21)
public void shouldEnableLoggingTLS12Enforced() {
List list = generateInterceptorsMockList(mockClient);
OkHttpClient client = factory.modifyClient(mockClient, true, true);
OkHttpClient client = factory.modifyClient(mockClient, true, true, 0, 0, 0);
verifyLoggingEnabled(client, list);
verifyTLS12Enforced(client);
}
Expand All @@ -74,7 +90,7 @@ public void shouldEnableLoggingTLS12Enforced() {
@Config(sdk = 21)
public void shouldEnableLoggingTLS12NotEnforced() {
List list = generateInterceptorsMockList(mockClient);
OkHttpClient client = factory.modifyClient(mockClient, true, false);
OkHttpClient client = factory.modifyClient(mockClient, true, false, 0, 0, 0);
verifyLoggingEnabled(client, list);
verifyTLS12NotEnforced(client);
}
Expand All @@ -83,7 +99,7 @@ public void shouldEnableLoggingTLS12NotEnforced() {
@Config(sdk = 21)
public void shouldDisableLoggingTLS12Enforced() {
List list = generateInterceptorsMockList(mockClient);
OkHttpClient client = factory.modifyClient(mockClient, false, true);
OkHttpClient client = factory.modifyClient(mockClient, false, true, 0, 0, 0);
verifyLoggingDisabled(client, list);
verifyTLS12Enforced(client);
}
Expand All @@ -92,7 +108,7 @@ public void shouldDisableLoggingTLS12Enforced() {
@Config(sdk = 21)
public void shouldDisableLoggingTLS12NotEnforced() {
List list = generateInterceptorsMockList(mockClient);
OkHttpClient client = factory.modifyClient(mockClient, false, false);
OkHttpClient client = factory.modifyClient(mockClient, false, false, 0, 0, 0);
verifyLoggingDisabled(client, list);
verifyTLS12NotEnforced(client);
}
Expand All @@ -101,7 +117,7 @@ public void shouldDisableLoggingTLS12NotEnforced() {
@Config(sdk = 22)
public void shouldEnableLoggingTLS12Enforced_postLollipopTLS12NoEffect() {
List list = generateInterceptorsMockList(mockClient);
OkHttpClient client = factory.modifyClient(mockClient, true, true);
OkHttpClient client = factory.modifyClient(mockClient, true, true, 0, 0, 0);
verifyLoggingEnabled(client, list);
verifyTLS12NotEnforced(client);
}
Expand All @@ -110,7 +126,7 @@ public void shouldEnableLoggingTLS12Enforced_postLollipopTLS12NoEffect() {
@Config(sdk = 22)
public void shouldEnableLoggingTLS12NotEnforced_posLollipop() {
List list = generateInterceptorsMockList(mockClient);
OkHttpClient client = factory.modifyClient(mockClient, true, false);
OkHttpClient client = factory.modifyClient(mockClient, true, false, 0, 0, 0);
verifyLoggingEnabled(client, list);
verifyTLS12NotEnforced(client);
}
Expand All @@ -119,7 +135,7 @@ public void shouldEnableLoggingTLS12NotEnforced_posLollipop() {
@Config(sdk = 22)
public void shouldDisableLoggingTLS12Enforced_postLollipopTLS12NoEffect() {
List list = generateInterceptorsMockList(mockClient);
OkHttpClient client = factory.modifyClient(mockClient, false, true);
OkHttpClient client = factory.modifyClient(mockClient, false, true, 0, 0, 0);
verifyLoggingDisabled(client, list);
verifyTLS12NotEnforced(client);
}
Expand All @@ -128,7 +144,7 @@ public void shouldDisableLoggingTLS12Enforced_postLollipopTLS12NoEffect() {
@Config(sdk = 22)
public void shouldDisableLoggingTLS12NotEnforced_postLollipop() {
List list = generateInterceptorsMockList(mockClient);
OkHttpClient client = factory.modifyClient(mockClient, false, false);
OkHttpClient client = factory.modifyClient(mockClient, false, false, 0, 0, 0);
verifyLoggingDisabled(client, list);
verifyTLS12NotEnforced(client);
}
Expand Down

0 comments on commit 34bc106

Please sign in to comment.