Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[CCI] [Extensions] Added the ExtensionActionUtil class #6969

Merged
merged 17 commits into from
Jun 21, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.extensions.action;

import org.opensearch.action.ActionRequest;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.Writeable;

import java.io.IOException;
import java.lang.reflect.Constructor;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;
import java.util.Arrays;

/**
* ExtensionActionUtil - a class for creating and processing remote requests using byte arrays.
*/
public class ExtensionActionUtil {

/**
* The Unicode UNIT SEPARATOR used to separate the Request class name and parameter bytes
*/
public static final byte UNIT_SEPARATOR = (byte) '\u001F';

/**
* @param request an instance of a request extending {@link ActionRequest}, containing information about the
* request being sent to the remote server. It is used to create a byte array containing the request data,
* which will be sent to the remote server.
* @return An Extension ActionRequest object that represents the deserialized data.
* If an error occurred during the deserialization process, the method will return {@code null}.
* @throws RuntimeException If a RuntimeException occurs while creating the proxy request bytes.
*/
public static byte[] createProxyRequestBytes(ActionRequest request) throws RuntimeException {
byte[] requestClassBytes = request.getClass().getName().getBytes(StandardCharsets.UTF_8);
byte[] requestBytes;

try {
requestBytes = convertParamsToBytes(request);
assert requestBytes != null;
return ByteBuffer.allocate(requestClassBytes.length + 1 + requestBytes.length)
.put(requestClassBytes)
.put(UNIT_SEPARATOR)
.put(requestBytes)
.array();
} catch (RuntimeException e) {
throw new RuntimeException("RuntimeException occurred while creating proxyRequestBytes");
}
}

/**
* @param requestBytes is a byte array containing the request data, used by the "createActionRequest"
* method to create an "ActionRequest" object, which represents the request model to be processed on the server.
* @return an "Action Request" object representing the request model for processing on the server,
* or {@code null} if the request data is invalid or null.
* @throws ReflectiveOperationException if an exception occurs during the reflective operation, such as when
* resolving the request class, accessing the constructor, or creating an instance using reflection
* @throws NullPointerException if a null pointer exception occurs during the creation of the ActionRequest object
*/
public static ActionRequest createActionRequest(byte[] requestBytes) throws ReflectiveOperationException {
int delimPos = delimPos(requestBytes);
String requestClassName = new String(Arrays.copyOfRange(requestBytes, 0, delimPos + 1), StandardCharsets.UTF_8).stripTrailing();
try {
Class<?> clazz = Class.forName(requestClassName);
Constructor<?> constructor = clazz.getConstructor(StreamInput.class);
StreamInput requestByteStream = StreamInput.wrap(Arrays.copyOfRange(requestBytes, delimPos + 1, requestBytes.length));
return (ActionRequest) constructor.newInstance(requestByteStream);
} catch (ReflectiveOperationException e) {
throw new ReflectiveOperationException(
"ReflectiveOperationException occurred while creating extensionAction request from bytes",
e
);
} catch (NullPointerException e) {
throw new NullPointerException(
"NullPointerException occurred while creating extensionAction request from bytes" + e.getMessage()
);
}
}

/**
* Converts the given object of type T, which implements the {@link Writeable} interface, to a byte array.
* @param <T> the type of the object to be converted to bytes, which must implement the {@link Writeable} interface.
* @param writeableObject the object of type T to be converted to bytes.
* @return a byte array containing the serialized bytes of the given object, or {@code null} if the input is invalid or null.
* @throws IllegalStateException if a failure occurs while writing the data
*/
public static <T extends Writeable> byte[] convertParamsToBytes(T writeableObject) throws IllegalStateException {
try (BytesStreamOutput out = new BytesStreamOutput()) {
writeableObject.writeTo(out);
return BytesReference.toBytes(out.bytes());
} catch (IOException ieo) {
throw new IllegalStateException("Failure writing bytes", ieo);
}
}

/**
* Searches for the position of the unit separator byte in the given byte array.
*
* @param bytes the byte array to search for the unit separator byte.
* @return the index of the unit separator byte in the byte array, or -1 if it was not found.
*/
public static int delimPos(byte[] bytes) {
for (int offset = 0; offset < bytes.length; ++offset) {
if (bytes[offset] == ExtensionActionUtil.UNIT_SEPARATOR) {
return offset;
}
}
return -1;
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,106 @@
/*
* SPDX-License-Identifier: Apache-2.0
*
* The OpenSearch Contributors require contributions made to
* this file be licensed under the Apache-2.0 license or a
* compatible open source license.
*/

package org.opensearch.extensions.action;

import org.junit.Before;
import org.mockito.Mockito;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.io.stream.BytesStreamOutput;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.common.io.stream.Writeable;
import org.opensearch.test.OpenSearchTestCase;

import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.charset.StandardCharsets;

import static org.opensearch.extensions.action.ExtensionActionUtil.UNIT_SEPARATOR;
import static org.opensearch.extensions.action.ExtensionActionUtil.createProxyRequestBytes;

public class ExtensionActionUtilTests extends OpenSearchTestCase {
private byte[] myBytes;
private final String actionName = "org.opensearch.action.MyExampleRequest";
private final byte[] actionNameBytes = MyExampleRequest.class.getName().getBytes(StandardCharsets.UTF_8);

@Before
public void setup() throws IOException {
BytesStreamOutput out = new BytesStreamOutput();
MyExampleRequest exampleRequest = new MyExampleRequest(actionName, actionNameBytes);
exampleRequest.writeTo(out);

byte[] requestBytes = BytesReference.toBytes(out.bytes());
byte[] requestClass = MyExampleRequest.class.getName().getBytes(StandardCharsets.UTF_8);
this.myBytes = ByteBuffer.allocate(requestClass.length + 1 + requestBytes.length)
.put(requestClass)
.put(UNIT_SEPARATOR)
.put(requestBytes)
.array();
}

public void testCreateProxyRequestBytes() throws IOException {
BytesStreamOutput out = new BytesStreamOutput();
MyExampleRequest exampleRequest = new MyExampleRequest(actionName, actionNameBytes);
exampleRequest.writeTo(out);

byte[] result = createProxyRequestBytes(exampleRequest);
assertArrayEquals(this.myBytes, result);
assertThrows(RuntimeException.class, () -> ExtensionActionUtil.createProxyRequestBytes(new MyExampleRequest(null, null)));
}

public void testCreateActionRequest() throws ReflectiveOperationException {
ActionRequest actionRequest = ExtensionActionUtil.createActionRequest(myBytes);
assertThrows(NullPointerException.class, () -> ExtensionActionUtil.createActionRequest(null));
assertThrows(ReflectiveOperationException.class, () -> ExtensionActionUtil.createActionRequest(actionNameBytes));
assertNotNull(actionRequest);
assertFalse(actionRequest.getShouldStoreResult());
}

public void testConvertParamsToBytes() throws IOException {
Writeable mockWriteableObject = Mockito.mock(Writeable.class);
Mockito.doThrow(new IOException("Test IOException")).when(mockWriteableObject).writeTo(Mockito.any());
assertThrows(IllegalStateException.class, () -> ExtensionActionUtil.convertParamsToBytes(mockWriteableObject));
}

public void testDelimPos() {
assertTrue(ExtensionActionUtil.delimPos(myBytes) > 0);
assertTrue(ExtensionActionUtil.delimPos(actionNameBytes) < 0);
assertEquals(-1, ExtensionActionUtil.delimPos(actionNameBytes));
}

private static class MyExampleRequest extends ActionRequest {
private final String param1;
private final byte[] param2;

public MyExampleRequest(String param1, byte[] param2) {
this.param1 = param1;
this.param2 = param2;
}

public MyExampleRequest(StreamInput in) throws IOException {
super(in);
param1 = in.readString();
param2 = in.readByteArray();
}

@Override
public void writeTo(StreamOutput out) throws IOException {
super.writeTo(out);
out.writeString(param1);
out.writeByteArray(param2);
}

@Override
public ActionRequestValidationException validate() {
return null;
}
}
}