Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -20,20 +20,41 @@
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.lang.reflect.Method;
import java.util.List;

/** For handling customized {@link Callback}. */
public interface CustomizedCallbackHandler {
class DefaultHandler implements CustomizedCallbackHandler{
@Override
public void handleCallback(List<Callback> callbacks, String username, char[] password)
public void handleCallbacks(List<Callback> callbacks, String username, char[] password)
throws UnsupportedCallbackException {
if (!callbacks.isEmpty()) {
throw new UnsupportedCallbackException(callbacks.get(0));
}
}
}

void handleCallback(List<Callback> callbacks, String name, char[] password)
static CustomizedCallbackHandler delegate(Object delegated) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

As of last week, DynMethod exists in hadoop common to assist here. Have look to see if it would help or are your needs better. A key aspect is can extract IOEs from the invocation

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@steveloughran , thanks for the info! Since we need to back port this to some earlier versions, we won't use DynMethod at the moment.

final String methodName = "handleCallbacks";
final Class<?> clazz = delegated.getClass();
final Method method;
try {
method = clazz.getMethod(methodName, List.class, String.class, char[].class);
} catch (NoSuchMethodException e) {
throw new IllegalStateException("Failed to get method " + methodName + " from " + clazz, e);
}

return (callbacks, name, password) -> {
try {
method.invoke(delegated, callbacks, name, password);
} catch (IllegalAccessException | InvocationTargetException e) {
throw new IOException("Failed to invoke " + method, e);
}
};
}

void handleCallbacks(List<Callback> callbacks, String name, char[] password)
throws UnsupportedCallbackException, IOException;
}
Original file line number Diff line number Diff line change
Expand Up @@ -225,14 +225,20 @@ static final class SaslServerCallbackHandler
SaslServerCallbackHandler(Configuration conf, PasswordFunction passwordFunction) {
this.passwordFunction = passwordFunction;

final Class<? extends CustomizedCallbackHandler> clazz = conf.getClass(
final Class<?> clazz = conf.getClass(
HdfsClientConfigKeys.DFS_DATA_TRANSFER_SASL_CUSTOMIZEDCALLBACKHANDLER_CLASS_KEY,
CustomizedCallbackHandler.DefaultHandler.class, CustomizedCallbackHandler.class);
CustomizedCallbackHandler.DefaultHandler.class);
final Object callbackHandler;
try {
this.customizedCallbackHandler = clazz.newInstance();
callbackHandler = clazz.newInstance();
} catch (Exception e) {
throw new IllegalStateException("Failed to create a new instance of " + clazz, e);
}
if (callbackHandler instanceof CustomizedCallbackHandler) {
customizedCallbackHandler = (CustomizedCallbackHandler) callbackHandler;
} else {
customizedCallbackHandler = CustomizedCallbackHandler.delegate(callbackHandler);
}
}

@Override
Expand Down Expand Up @@ -271,7 +277,7 @@ public void handle(Callback[] callbacks) throws IOException,
if (unknownCallbacks != null) {
final String name = nc != null ? nc.getDefaultName() : null;
final char[] password = name != null ? passwordFunction.apply(name) : null;
customizedCallbackHandler.handleCallback(unknownCallbacks, name, password);
customizedCallbackHandler.handleCallbacks(unknownCallbacks, name, password);
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,25 +20,45 @@
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hdfs.client.HdfsClientConfigKeys;
import org.apache.hadoop.hdfs.protocol.datatransfer.sasl.SaslDataTransferServer.SaslServerCallbackHandler;
import org.apache.hadoop.test.LambdaTestUtils;
import org.junit.Assert;
import org.junit.Test;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import javax.security.auth.callback.Callback;
import javax.security.auth.callback.UnsupportedCallbackException;
import java.util.Arrays;
import java.io.IOException;
import java.util.List;
import java.util.concurrent.atomic.AtomicReference;

/** For testing {@link CustomizedCallbackHandler}. */
public class TestCustomizedCallbackHandler {
public static final Logger LOG = LoggerFactory.getLogger(TestCustomizedCallbackHandler.class);
static final Logger LOG = LoggerFactory.getLogger(TestCustomizedCallbackHandler.class);

static final AtomicReference<List<Callback>> LAST_CALLBACKS = new AtomicReference<>();

static void runHandleCallbacks(Object caller, List<Callback> callbacks, String name) {
LOG.info("{}: handling {} for {}", caller.getClass().getSimpleName(), callbacks, name);
LAST_CALLBACKS.set(callbacks);
}

/** Assert if the callbacks in {@link #LAST_CALLBACKS} are the same as the expected callbacks. */
static void assertCallbacks(Callback[] expected) {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a javadoc explaining what it does

final List<Callback> computed = LAST_CALLBACKS.getAndSet(null);
Assert.assertNotNull(computed);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think you could do this in a single AssertJ assertion, especially lines 49-51, provides expected was actually a list.

AssertJ.assertThat(computed)
 .describedAs("computed callbacks")
 .isNotNull()
 .hasSameElementsAs(expected)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Tried it but AssertJ was not available for import. So I will keep the current code and avoid adding the dependency.

Assert.assertEquals(expected.length, computed.size());
for (int i = 0; i < expected.length; i++) {
Assert.assertSame(expected[i], computed.get(i));
}
}

static class MyCallback implements Callback { }

static class MyCallbackHandler implements CustomizedCallbackHandler {
@Override
public void handleCallback(List<Callback> callbacks, String name, char[] password) {
LOG.info("{}: handling {} for {}", getClass().getSimpleName(), callbacks, name);
public void handleCallbacks(List<Callback> callbacks, String name, char[] password) {
runHandleCallbacks(this, callbacks, name);
}
}

Expand All @@ -48,16 +68,52 @@ public void testCustomizedCallbackHandler() throws Exception {
final Callback[] callbacks = {new MyCallback()};

// without setting conf, expect UnsupportedCallbackException
try {
new SaslServerCallbackHandler(conf, String::toCharArray).handle(callbacks);
Assert.fail("Expected UnsupportedCallbackException for " + Arrays.asList(callbacks));
} catch (UnsupportedCallbackException e) {
LOG.info("The failure is expected", e);
}
LambdaTestUtils.intercept(UnsupportedCallbackException.class, () -> runTest(conf, callbacks));

// set conf and expect success
conf.setClass(HdfsClientConfigKeys.DFS_DATA_TRANSFER_SASL_CUSTOMIZEDCALLBACKHANDLER_CLASS_KEY,
MyCallbackHandler.class, CustomizedCallbackHandler.class);
new SaslServerCallbackHandler(conf, String::toCharArray).handle(callbacks);
assertCallbacks(callbacks);
}

static class MyCallbackMethod {
public void handleCallbacks(List<Callback> callbacks, String name, char[] password)
throws UnsupportedCallbackException {
runHandleCallbacks(this, callbacks, name);
}
}

static class MyExceptionMethod {
public void handleCallbacks(List<Callback> callbacks, String name, char[] password)
throws UnsupportedCallbackException {
runHandleCallbacks(this, callbacks, name);
throw new UnsupportedCallbackException(callbacks.get(0));
}
}

@Test
public void testCustomizedCallbackMethod() throws Exception {
final Configuration conf = new Configuration();
final Callback[] callbacks = {new MyCallback()};

// without setting conf, expect UnsupportedCallbackException
LambdaTestUtils.intercept(UnsupportedCallbackException.class, () -> runTest(conf, callbacks));

// set conf and expect success
conf.setClass(HdfsClientConfigKeys.DFS_DATA_TRANSFER_SASL_CUSTOMIZEDCALLBACKHANDLER_CLASS_KEY,
MyCallbackMethod.class, Object.class);
new SaslServerCallbackHandler(conf, String::toCharArray).handle(callbacks);
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add a static counter somewhere to verify the callback was actually invoked.

add another test to raise an exception in the callback to verify it gets reported.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure, will add both.

assertCallbacks(callbacks);

// set conf and expect exception
conf.setClass(HdfsClientConfigKeys.DFS_DATA_TRANSFER_SASL_CUSTOMIZEDCALLBACKHANDLER_CLASS_KEY,
MyExceptionMethod.class, Object.class);
LambdaTestUtils.intercept(IOException.class, () -> runTest(conf, callbacks));
}

static void runTest(Configuration conf, Callback... callbacks)
throws IOException, UnsupportedCallbackException {
new SaslServerCallbackHandler(conf, String::toCharArray).handle(callbacks);
}
}