
Adding new presets for pytorch c++ frontend? #623

Closed · wumo opened this issue Oct 13, 2018 · 80 comments

wumo commented Oct 13, 2018

PyTorch recently released preview version 1.0, which features an experimental C++ frontend. The PyTorch C++ frontend is easier to use and more complete than TensorFlow's.

The C++ frontend is a pure C++ interface to the PyTorch backend that follows the API and architecture of the established Python frontend. It is intended to enable research in high performance, low latency and bare metal C++ applications. It provides equivalents to torch.nn, torch.optim, torch.data and other components of the Python frontend. Here is a minimal side-by-side comparison of the two language frontends:

Python:

import torch

model = torch.nn.Linear(5, 1)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
prediction = model.forward(torch.randn(3, 5))
loss = torch.nn.functional.mse_loss(prediction, torch.ones(3, 1))
loss.backward()
optimizer.step()

C++:

#include <torch/torch.h>

torch::nn::Linear model(5, 1);
torch::optim::SGD optimizer(model->parameters(), /*lr=*/0.1);
torch::Tensor prediction = model->forward(torch::randn({3, 5}));
auto loss = torch::mse_loss(prediction, torch::ones({3, 1}));
loss.backward();
optimizer.step();

Will you consider adding presets for the PyTorch C++ frontend?
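For reference, here is a hedged sketch of how the same snippet might look through JavaCPP presets. The class names (LinearImpl, SGDOptions, the global torch class) follow JavaCPP naming conventions, but they are assumptions, not a published API:

import org.bytedeco.pytorch.*;
import org.bytedeco.pytorch.global.torch;

public class LinearExample {
    public static void main(String[] args) {
        // Mirrors the Python/C++ snippets above; exact names and signatures are assumptions
        LinearImpl model = new LinearImpl(5, 1);
        SGD optimizer = new SGD(model.parameters(), new SGDOptions(/*lr=*/0.1));
        Tensor prediction = model.forward(torch.randn(3, 5));
        Tensor loss = torch.mse_loss(prediction, torch.ones(3, 1));
        loss.backward();
        optimizer.step();
    }
}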


saudet commented Oct 13, 2018 via email


saudet commented Oct 16, 2018

BTW, @Neiko2002, is this something you would also be interested in having?


saudet commented Oct 16, 2018

/cc @cypof

@Neiko2002

@saudet To be honest, I have never tried PyTorch yet. Their Python API looks similar to Java, so a JavaCPP version of it might be easier to use than the official C++ API. Right now we deploy javacpp-tensorflow in production systems but train networks in Python. PyTorch, on the other hand, has a simple C++ API to train and run a network.

Looking at the C++ API, it does not feel like it offers more than the TensorFlow API.


saudet commented Oct 21, 2018

I think there is great potential in being able to (re)train on a platform that's more efficient than CPython, so I do hope this pans out. In any case, TensorFlow has obviously given up on that; let's see where PyTorch is willing to go! They used to promote Lua for this exact reason, although not Java for whatever political reason, so IMO there is a chance they will redo something similar for C++11:
http://bytedeco.org/news/2015/03/14/java-meets-caffe/

@karllessard If you have any opinions about this, I would love to hear them!

@karllessard

> @karllessard If you have any opinions about this, I would love to hear them!

For the moment, no, I don't, but this is very interesting. Thanks for the tip, @saudet!


saudet commented Dec 27, 2018

I've been looking at the C++ API over the past couple of days, and it looks like a mess. It mixes types from Caffe2, ATen, and C10 with no clear understanding of how they fit together: pytorch/pytorch#14850

Moreover, it's currently unable to train for MNIST using simple examples on either CPU or GPU:
https://discuss.pytorch.org/t/in-c-api-1-0-0-mnist-results-are-unstable/32823
pytorch/pytorch#15522
I've personally confirmed those results by building from source both PyTorch 1.0.0 and current master, both with and without CUDA/cuDNN.

We should probably wait until all this gets cleared up and fixed before starting to map anything to Java...

@iwroldtan


Could you give me a hand? I cannot find the file torch/torch.h.


saudet commented Jan 4, 2019


jxtps commented Jul 26, 2019

It looks like that specific instability issue has been resolved, as pytorch/pytorch#15522 is now closed.

Looking at their release log, they seem to have been pretty busy: https://github.com/pytorch/pytorch/releases

Though the C++ API is still labeled as being in beta: https://pytorch.org/cppdocs/

Might it make sense to package something up?


saudet commented Jul 26, 2019

@jxtps It still looks like a mess to me. There's still very little information about what C10 is supposed to be, apart from this: https://github.com/pytorch/pytorch/wiki/Software-Architecture-for-c10 But if you're willing to figure out what we need to map and how, that would be a good first step, I think.


jxtps commented Jul 27, 2019

Yeah, it seems like they've got a lot of history to deal with in that project.

Judging by https://pytorch.org/cppdocs/ it looks like packaging the C++ frontend (= torch/torch.h) and its dependencies (which I presume include ATen and autograd, but maybe not TorchScript or the C++ extensions?) could make sense.

I would hope that wherever they end up with their c10 tensor library, it will be backwards compatible in the sense that 1. all relevant operations are preserved, and 2. no operation subtly changes behavior while retaining the same overall signature. So at some upgrade point there will be a massive number of compile errors as at::Tensor changes to c10::Tensor (or whatever) and the names of some methods change, but it's more of a "replace all" situation, not a "for every single call site, figure out precisely what maps to what" situation.

I recognize that the correct approach here may very well be "wait another year".


doofin commented Jan 22, 2021

Hi, as of 2021 the C++ frontend seems stable enough. I think we can just look at the files under libtorch/include/torch/csrc/api/include/torch/ and skip things like c10, ATen, etc. I desperately want to have these bindings!

There are projects like https://github.com/nazarblch/torch-scala, but I haven't really understood that project. It uses CMake to generate something, which is confusing.


saudet commented Jan 22, 2021

There's been some work on this recently. @wmeddie Would you be able to make public what you have?


wmeddie commented Jan 22, 2021

My fork is available at master...wmeddie:master, but it needs a bit of a cleanup. If other people are interested in helping get it over the finish line, I think I can create a PR over the weekend.


doofin commented Jan 22, 2021

@wmeddie Thanks for your nice work! I would be glad to help, since I have wanted to do this for 2 years. I have forked your branch and will start an initial test build.


saudet commented Mar 30, 2021

Hey everyone, I've finally taken some time to start working on this myself, and although there is still much to be done, the basics are there to get something like a simple training example for MNIST working on Linux, Mac, and Windows, with MKL and everything.

Now, some important parts of the C++ API are probably missing and/or not functional yet, and since it has a pretty large surface, it would be a good idea to prioritize needs. Please let me know what you think we should work on next:

  • Enable more build options? Which ones?
  • Add integration with the Python API as well?
  • Map/fix more of the C++ API? Which classes, functions, variables, etc?
  • Port more examples from C++ to Java?
  • Do something else entirely?

In any case, please give it a try with the snapshots (instructions at http://bytedeco.org/builds/) and let me know what you think!
Thank you for your interest!

/cc @HGuillemet @stu1130

@HGuillemet

I gave it a first try. It seems quite promising.
I tried to port a CNN model from Python and ran into the following problem: many standard modules use an Options class built with the TORCH_ARG macro. This macro takes a type, a name, and a default value, and defines a private instance variable initialized with the value, a setter returning *this, and a getter (see the sketch below).
This macro should be interpreted by the parser to generate a matching builder-style interface. Currently we cannot set these options.
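For illustration, here is a minimal plain-Java sketch of the builder pattern that TORCH_ARG encodes, based on the description above; the class and option names are hypothetical, not the presets' actual API:

// Hypothetical sketch: what TORCH_ARG(long, stride) amounts to, in Java terms.
class ConvOptionsSketch {
    private long stride = 1;                                            // private field with default value
    ConvOptionsSketch stride(long s) { this.stride = s; return this; }  // chainable setter returning this
    long stride() { return stride; }                                    // matching getter
}

// Usage, chaining setters builder-style:
// ConvOptionsSketch opts = new ConvOptionsSketch().stride(3);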

Also, it would be interesting to have some Java-ish replacement for the ExpandingArray template type. Can we use generics here?


saudet commented Apr 6, 2021

> I gave it a first try. It seems quite promising.

Thanks for trying it out!

> I tried to port a CNN model from Python and ran into the following problem: many standard modules use an Options class built with the TORCH_ARG macro. This macro takes a type, a name, and a default value, and defines a private instance variable initialized with the value, a setter returning *this, and a getter.

TORCH_ARG() only seems to be used to define new options, which is not something users typically need to do. Could you show an example in C++ or in Python of what you're trying to do?

> This macro should be interpreted by the parser to generate a matching builder-style interface. Currently we cannot set these options.

> Also, it would be interesting to have some Java-ish replacement for the ExpandingArray template type. Can we use generics here?

We can add/generate as many "helper" methods and classes as we want, sure, but there's no way to automate this kind of translation process from C++ to Java in a generic fashion, so that kind of thing is probably not going to be part of JavaCPP itself or of the presets for PyTorch, but of some higher-level project, such as JavaCV in the case of OpenCV, FFmpeg, etc, or TF Java in the case of TensorFlow.

@HGuillemet

HGuillemet commented Apr 6, 2021

> Could you show an example in C++ or in Python of what you're trying to do?

The configuration of a Convolution module should look like this:

conv = new Conv2dImpl(
         new Conv2dOptions(numChannelsIn, numChannelsOut, new int[] { 3 })
             .stride(3)
             .padding(1)
             .bias(false)
           );

But the stride, padding, bias, etc. setters are not defined yet, because they are declared in torch/nn/options/conv.h with the TORCH_ARG macro. Also, the new int[] { 3 } is the kernel size, and the Java presets only take a LongPointer as a replacement for an ExpandingArray<2>.


saudet commented Apr 6, 2021

I see. I'm not sure at the moment how we could make this work well, since it uses the auto keyword, making it pretty hard to parse (see bytedeco/javacpp#407), but something like this should work for now:

convOptions = new Conv2dOptions(numChannelsIn, numChannelsOut, new LongPointer(3, 3));
convOptions.stride().put(new long[] {3, 3});   // write values through the getter's pointer
convOptions.padding().put(new long[] {1, 1});
convOptions.bias().put(false);
conv = new Conv2dImpl(convOptions);

Is this the only thing that's not working for you, though? Rather than small usability issues like that, I would like to hear first about important issues preventing someone from getting anything working at all, and fix those first. However, if everything else is working for you, please let me know, and we can start looking at what we can do to polish that.

@HGuillemet

OK, I didn't realize we could set options this way.
Continuing my porting attempt...
There seems to be no direct way to create a TensorArrayRef from a Java array of tensors, but I guess this could work:

new TensorArrayRef(new PointerPointer<Tensor>(x.length).put(x));

@HGuillemet

There seems to be an issue with Sequential and AnyModule.

It's common in PyTorch to define a module as a sequence of other modules:

model = nn.Sequential(
          nn.Conv2d(1,20,5),
          nn.ReLU()
        )

which maps with the C++ API to:

torch::nn::Sequential model(
  torch::nn::Conv2d(1,20,5),
  torch::nn::ReLU()
);

A Sequential is basically a list of AnyModule, AnyModule being an abstract type representing any module. But because the forward method of modules may have any signature, usual inheritance is not an option. Calling the forward method of an AnyModule delegates the call to the underlying concrete module with dynamic checking of the signature.

In the presets we cannot rely on the C++ library and its implementation of AnyModule to dynamically call a forward method defined in Java (can we?).
I believe there are other cases where the C++ lib must call the forward method of a module.
How should we handle this?


saudet commented Apr 7, 2021

> OK, I didn't realize we could set options this way.
> Continuing my porting attempt...
> There seems to be no direct way to create a TensorArrayRef from a Java array of tensors, but I guess this could work:
>
> new TensorArrayRef(new PointerPointer<Tensor>(x.length).put(x));

It looks like that needs an actual array of tensors, and not an array of pointers to tensors, so if we don't have that, we need to first create one, possibly like this:

Tensor array = new Tensor(x.length);      // allocate a contiguous native array of tensors
for (int i = 0; i < x.length; i++) {
    array.getPointer(i).put(x[i]);        // copy each tensor into its slot
}
new TensorArrayRef(array, x.length);

> A Sequential is basically a list of AnyModule, AnyModule being an abstract type representing any module. But because the forward method of modules may have any signature, usual inheritance is not an option. Calling the forward method of an AnyModule delegates the call to the underlying concrete module with dynamic checking of the signature.
>
> In the presets we cannot rely on the C++ library and its implementation of AnyModule to dynamically call a forward method defined in Java (can we?). I believe there are other cases where the C++ lib must call the forward method of a module. How should we handle this?

Yeah, that looks really complicated. I don't see at the moment how we could make this work nicely. It doesn't sound like the Sequential module is required for anything though. According to the docs, it's only there for convenience:

/// Why should you use `Sequential` instead of a simple `std::vector`? The value
/// a `Sequential` provides over manually calling a sequence of modules is that
/// it allows treating the whole container *as a single module*, such that
/// performing a transformation on the `Sequential` applies to each of the
/// modules it stores (which are each a registered submodule of the
/// `Sequential`). For example, calling
/// `.to(torch::kCUDA)` on a `Sequential` will move each module in the list to
/// CUDA memory. For example:

@HGuillemet

New issue: the C++ lib handles modules either as normal structs/objects (the Module class) or as shared references wrapped in instances of a generic class, ModuleHolder<ModuleType>. See https://pytorch.org/tutorials/advanced/cpp_frontend.html#module-ownership.
I don't think the module holder can work with the presets, at least for user-defined Java modules, for the same reason as Sequential. So you might as well strip all related classes from the parsing, if possible.

The documentation says the use of module holders is needed for some features like the serialization API. I don't know why, and I haven't tested it (I haven't found the entry point in the presets).

Another issue: the register_module method has been specialized for every concrete Module class in the lib. But it should accept any shared_ptr to a Module subclass, so a single definition register_module(String, Module) should work and would allow registering a custom module. More generally, std::shared_ptr<ModuleType> could be parsed into Module everywhere.


saudet commented May 31, 2021

register_module() is a function template, and looking at its definition, it's going to need to know the type of your Module at compile time to be able to "move" it correctly: https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/module.h#L650


jbaron commented May 31, 2021

Thanks, I will have to look into this in a bit more detail, I guess (especially for my own defined Modules and the implications on how to use/register them). Time to brush up on my C/C++ skills ;)

@HGuillemet

In my case, I haven't faced any error when calling register_module, passing an argument whose type is:

  • the concrete class of a predefined library module,
  • the concrete class of a Java custom module extending Module,
  • Module, when the module is a custom module.


jbaron commented May 31, 2021

Indeed, it seems that if I play by those rules there is no issue, thanks.

P.S. From my understanding, the main reason to register a module is to make it easier to retrieve the trainable parameters. I guess some dictionary traversal is performed when you call net.parameters(). But if you could get the parameters another way, it is not really required. But I could be very wrong here; a sketch of that understanding follows.
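To make that concrete, here is a rough sketch of the registration/parameters flow under that understanding; the signatures are assumed from the other snippets in this thread, so treat them as assumptions:

import org.bytedeco.pytorch.*;

public class RegisterExample {
    static class Net extends Module {
        // register_module stores the submodule in the module's internal dictionary
        final LinearImpl fc = register_module("fc", new LinearImpl(10, 1));
    }

    public static void main(String[] args) {
        Net net = new Net();
        // parameters() traverses that dictionary, so the registered submodule's
        // weight and bias end up in the optimizer; unregistered ones would not.
        SGD optimizer = new SGD(net.parameters(), new SGDOptions(/*lr=*/0.1));
    }
}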


saudet commented Jun 1, 2021

Actually, I think the problem is with the dynamic cast here below:
https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/module.h#L651
Usually the input type is not a problem, but I didn't realize there was virtual inheritance involved; then it matters:
https://github.com/pytorch/pytorch/blob/master/torch/csrc/api/include/torch/nn/cloneable.h#L24
We can probably make this work by providing calls to dynamic_cast<Module>(LSTMImpl*), etc., which would return correct pointers to the base class Module. @jbaron Is this something you need?


jbaron commented Jun 1, 2021

@saudet Thanks for looking into this. Right now I'm only doing some simple prototyping to see what could be the best approach for my use cases, so nothing is really vital at this moment, especially if there is a workaround like in this case.


saudet commented Jun 18, 2021

@jbaron I've updated to PyTorch 1.9.0 in commit f1ea76f, and I also added Module.asModule() methods to upcast correctly and do the kind of thing you wanted, like this:

var lstm = new LSTMImpl(50, 20);
register_module("lstm1", lstm.asModule());

Let me know if there's anything else missing that you would like to have! And thanks for testing!

@HGuillemet

With the update, the data type for padding has changed from an expanding array of longs to some variant type that can be either an array, kSame, or kValid. Code like:

convOpt.padding().put(new long[] { 3, 3 });

doesn't work anymore.

How can we handle this variant type?


jbaron commented Jun 21, 2021

Would something like this not work (Kotlin sample)?

val lp =  LongPointer()
lp.put(3L, 3L)
convOpt.padding().put<LongPointer>(lp)

@HGuillemet

There must be some header value discriminating the actual type of the value stored in the variant.

Also beware of put(long, long): it's a call to the method that puts a single value at a specific index. It will most probably give you a segmentation fault.
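Spelled out as a snippet, in the same bare style as the other snippets here (only JavaCPP's LongPointer is involved):

LongPointer p = new LongPointer(2);  // allocates room for two long values
p.put(0, 3);                         // put(index, value): element 0 = 3
p.put(1, 3);                         // element 1 = 3
// p.put(new long[] {3, 3});         // the varargs overload writes both values at once
// new LongPointer().put(3, 3);      // writes value 3 at index 3 of an empty pointer: crash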


jbaron commented Jun 21, 2021

Thanks, indeed I misread the LongPointer API. The following code runs without a segmentation fault:


import org.bytedeco.javacpp.LongPointer
import org.bytedeco.pytorch.*
import org.bytedeco.pytorch.global.torch

class Net : Module() {

    private val conv = register_module("conv", getConv())

    private fun getConv(): Conv2dImpl {
        // 3x3 kernel, written element by element with put(index, value)
        val kernel = LongPointer(2)
        kernel.put(0L, 3L)
        kernel.put(1L, 3L)

        val convOpt = Conv2dOptions(3L, 16L, kernel)
        // explicit 3x3 padding, stored into the padding variant as a LongPointer
        val padding = LongPointer(2)
        padding.put(0L, 3L)
        padding.put(1L, 3L)
        convOpt.padding().put<LongPointer>(padding)
        return Conv2dImpl(convOpt)
    }

    fun forward(t: Tensor): Tensor {
        return conv.forward(t)
    }
}

fun main() {
    val net = Net()
    val x = torch.rand(16, 3, 28, 28)
    net.forward(x)
}



saudet commented Jun 22, 2021

Right, we probably need to map those variants to set the options correctly, so I've done that in commit 7b9ccf3. Please give it a try!

@HGuillemet

It now works as with version 1.8 for my cases with CNNs and explicit paddings. Thanks!


jbaron commented Jun 23, 2021

Short question: when a method expects a pointer to std::ostream, how do I provide one? Are these C++ std lib classes also mapped to Java equivalents?


saudet commented Jun 23, 2021

> Short question: when a method expects a pointer to std::ostream, how do I provide one? Are these C++ std lib classes also mapped to Java equivalents?

It's not currently mapped to anything. It's a pretty ugly API, so usually we get overloads using std::string, but PyTorch doesn't do that for some reason. It would require some serious consideration to get something usable in Java, but anyway we can get pointers to std::cout, std::cerr, and std::clog if that's all you need, see above #623 (comment).


saudet commented Aug 3, 2021

The presets for the C++ API of PyTorch 1.9.0 have been released with version 1.5.6! Thanks to everyone for helping out with this effort. I'll close this issue, but please do open new separate threads to continue the discussion about various remaining issues and potential enhancements.

/cc @stu1130

@saudet saudet closed this as completed Aug 3, 2021

saudet commented Aug 27, 2021

@jbaron BTW, it looks like we can train TorchScript models using the C++ API without CPython, see pytorch/pytorch#17614. That's one way you could go about it with your models: Define them in Python, but train them using the C++ API. This is unlike SavedModel from TensorFlow, which AFAIK cannot be trained from outside CPython.


jbaron commented Aug 27, 2021

Sounds great and would be very nice: not only will we have a lot of (proven) models to pick from, I assume it will also be faster, since there are fewer calls between Java and C++ during training.

@HGuillemet

A couple of issues with the *Dict classes, e.g. StringTensorDict:

  • the get method mapping the [] operator should take a long, not a BytePointer. Isn't size_t always mapped to long?
  • when the size is 0, begin() returns something that is neither null nor isNull(), and throws a SIGSEGV when first() is called on it. The same probably goes for increment() when the size is not 0 but we reach the end of the list.


saudet commented Oct 9, 2021

AFAIK, torch::OrderedDict<std::string,at::Tensor> is pretty much the same as std::map<std::string,at::Tensor>, so the key is a string; that sounds alright, and it's possible that it can't be safely iterated, yes. Do you have counterexamples, in C++?

@HGuillemet

Back to this issue.
I realized that ordered_dict.h is not parsed, and OrderedDict is declared as a basic container instead.
Thus OrderedDict-specific features, for instance an overload of [] taking an integer, are not available.
I guess there is no way to have ordered_dict.h parsed while keeping the generic dictionary features like get(key) and iterators? Is there?

About iterators: they seem to work as expected using this pattern:

for (Iterator it = dict.begin(); !it.equals(dict.end()); it = it.increment()) {
  ...
}


saudet commented Nov 21, 2021

It should be possible to parse ordered_dict.h; it doesn't look like the kind of really complicated template that you'd find in the STL, which is essentially unparsable by anything other than a full C++ compiler. But is there anything that you're not able to do with the current mapping?

@HGuillemet

Not at the moment, nothing high-priority.


saudet commented Mar 30, 2022

@HGuillemet I've included ordered_dict.h in commit 78540be to map torch::OrderedDict<std::string,at::Tensor> like you wanted. Please give it a try with the snapshots: http://bytedeco.org/builds/

Let me know if there is anything else missing!

@HGuillemet

I had to remove:

.valueTypes("@Cast({\"\", \"torch::OrderedDict<std::string,torch::nn::AnyModule>&&\"}) @StdMove StringAnyModuleDict")

at line 1929 of torch.java to get it to compile.
Once compiled, it works, thank you! The resulting code is much nicer.
