Skip to content

Commit

Permalink
add jar registration for spi
Browse files Browse the repository at this point in the history
Signed-off-by: zane-neo <zaniu@amazon.com>
  • Loading branch information
zane-neo committed Nov 29, 2023
1 parent 06c0380 commit bb55524
Show file tree
Hide file tree
Showing 3 changed files with 26 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ private Tool createTool(MLToolSpec toolSpec) {
if (!toolFactories.containsKey(toolSpec.getType())) {
throw new IllegalArgumentException("Tool not found: " + toolSpec.getType());
}
Tool tool = toolFactories.get(toolSpec.getType()).create(toolParams);
Tool.Factory factory = toolFactories.get(toolSpec.getType());
factory.initClient(client);
Tool tool = factory.create(toolParams);
if (toolSpec.getName() != null) {
tool.setName(toolSpec.getName());
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,16 +10,24 @@
import static org.opensearch.ml.common.CommonValue.ML_MODEL_INDEX;
import static org.opensearch.ml.common.CommonValue.ML_TASK_INDEX;

import java.io.File;
import java.net.URL;
import java.net.URLClassLoader;
import java.nio.file.Path;
import java.security.AccessController;
import java.security.PrivilegedExceptionAction;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.ServiceLoader;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.function.Supplier;

import org.apache.lucene.spatial3d.geom.Tools;
import org.opensearch.action.ActionRequest;
import org.opensearch.client.Client;
import org.opensearch.cluster.metadata.IndexNameExpressionResolver;
Expand Down Expand Up @@ -871,6 +879,12 @@ public Map<String, Processor.Factory<SearchResponseProcessor>> getResponseProces
@Override
public void loadExtensions(ExtensionLoader loader) {
externalToolFactories = new HashMap<>();
ServiceLoader<Tool> serviceLoader = ServiceLoader.load(Tool.class, Tool.class.getClassLoader());
for (Tool tool : serviceLoader) {
Tool.Factory<? extends Tool> factory = tool.getFactory();
externalToolFactories.put(tool.getType(), factory);
}

for (MLCommonsExtension extension : loader.loadExtensions(MLCommonsExtension.class)) {
List<Tool.Factory<? extends Tool>> toolFactories = extension.getToolFactories();
for (Tool.Factory<? extends Tool> toolFactory : toolFactories) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

package org.opensearch.ml.common.spi.tools;

import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import java.util.Map;

Expand Down Expand Up @@ -107,6 +108,10 @@ default boolean useOriginalInput() {
return false;
}

default Factory<? extends Tool> getFactory() {
return null;
}

/**
* Tool factory which can create instance of {@link Tool}.
* @param <T> The subclass this factory produces
Expand All @@ -120,6 +125,10 @@ interface Factory<T extends Tool> {
*/
T create(Map<String, Object> params);

default void initClient(Client client) {

}

/**
* Get the default description of this tool.
* @return the default description
Expand Down

0 comments on commit bb55524

Please sign in to comment.