Skip to content

Commit

Permalink
Simplify AnyModule constructors info
Browse files Browse the repository at this point in the history
  • Loading branch information
HGuillemet committed Jan 9, 2024
1 parent 822d08b commit 291033f
Showing 1 changed file with 13 additions and 12 deletions.
25 changes: 13 additions & 12 deletions pytorch/src/main/java/org/bytedeco/pytorch/presets/torch.java
Original file line number Diff line number Diff line change
Expand Up @@ -198,8 +198,6 @@ public void mapModule(InfoMap infoMap, String name, String base, String baseBase
mapModule(infoMap, name, base, baseBase, true);
}

String anyModuleConstructors = "";

public void mapModule(InfoMap infoMap, String name, String base, String baseBase, boolean anyModuleCompatible) {
if (baseBase != null) {
infoMap.put(new Info(baseBase).pointerTypes(name + "ImplBaseBase"));
Expand All @@ -217,11 +215,18 @@ public void mapModule(InfoMap infoMap, String name, String base, String baseBase
;

if (anyModuleCompatible) {
anyModuleConstructors +=
"public AnyModule(" + name + "Impl module) { super((Pointer)null); allocate(module); }\n" +
// We need a @Cast because AnyModule constructor is explicit
"private native void allocate(@SharedPtr @Cast({\"\", \"std::shared_ptr<torch::nn::" + name + "Impl>\"}) " + name + "Impl module);\n";
infoMap.put(new Info("torch::nn::SequentialImpl::push_back<torch::nn::" + name + "Impl>").javaNames("push_back"));
infoMap
// Parser queries parameter as ModuleType* instead of std::shared_ptr<ModuleType>
// First cppName is to answer template query, second one to generate instance
.put(new Info(
"torch::nn::AnyModule::AnyModule<torch::nn::" + name + "Impl>(ModuleType*)",
"torch::nn::AnyModule::AnyModule<torch::nn::" + name + "Impl>(torch::nn::" + name + "Impl*)"
).define().javaText(
"public AnyModule(" + name + "Impl module) { super((Pointer)null); allocate(module); }\n" +
// We need a @Cast because AnyModule constructor is explicit
"private native void allocate(@SharedPtr @Cast({\"\", \"std::shared_ptr<torch::nn::" + name + "Impl>\"}) " + name + "Impl module);\n"))
.put(new Info("torch::nn::SequentialImpl::push_back<torch::nn::" + name + "Impl>").javaNames("push_back"))
;
}
}

Expand Down Expand Up @@ -1727,11 +1732,7 @@ public void map(InfoMap infoMap) {
"public native @ByVal @Name(\"forward<std::tuple<torch::Tensor,torch::Tensor>>\") T_TensorTensor_T forwardT_TensorTensor_T(@Const @ByRef Tensor query, @Const @ByRef Tensor key, @Const @ByRef Tensor value, @Const @ByRef(nullValue = \"torch::Tensor{}\") Tensor key_padding_mask, @Cast(\"bool\") boolean need_weights/*=true*/, @Const @ByRef(nullValue = \"torch::Tensor{}\") Tensor attn_mask, @Cast(\"bool\") boolean average_attn_weights/*=true*/);\n" +
"public native @ByVal @Name(\"forward<torch::nn::ASMoutput>\") ASMoutput forwardASMoutput(@Const @ByRef Tensor input, @Const @ByRef Tensor target);\n"
))
.put(new Info("torch::nn::AnyModule(ModuleType*)")
// We cannot use template instantiation mechanism in Parser with something like
// new Info("torch::nn::AnyModule<torch::nn::" + name + "Impl>(ModuleType*)")
// because it doesn't work with javaText. And we need javaText because of @Cast.
.javaText(anyModuleConstructors));
;

for (String[] outputType : new String[][]{
{"at::Tensor", "Tensor"},
Expand Down

0 comments on commit 291033f

Please sign in to comment.