
Multiple Entrypoints and Shared State implementation; Request for comments/feedback. #8030

Open
cptspacemanspiff opened this issue Jan 29, 2025 · 5 comments
cptspacemanspiff commented Jan 29, 2025

🚀 The feature, motivation and pitch

Sorry, this is kind of long...

Overview

Hi, I am not really sure what the process for this is, but I made some changes to, and a wrapper around, executorch to support shared state between multiple exported methods of a torch.nn.Module.

For now I have implemented this in my own repo: cptspacemanspiff/execu-tools.

My initial motivation is that when exporting encoder/decoder models one would really like to use a precalculated cross-attention cache, but this is hard to do without conditional branching (via torch.cond), rewriting, or extraneous copies/calculations.

  • Aside from this somewhat specific use case, I think it can be more generally useful, as it allows exporting more of the whole model's generate graph without rewriting as much of it in C++, which would add more surface area for bugs.
    • For example, you can place just the branching in C++ and call methods in sequence, without manually writing C++ to pass data between methods, since data gets passed under the hood via the shared state.
  • Additionally, there may be KV-cache management strategies that would benefit from access to the cache from outside a particular model's decode function.

My high-level thought process is that, for the most part, for a given model class (sequence-to-sequence, decoder-only, etc.), the forced graph breaks from the torch.export process (due to data-dependent branching and the like) land in mostly the same places. If you export the individual parts in Python and then glue them together with minimal code on the C++ side, the C++ runtime implementation should generalize to other models of the same class. From what I can tell this is already the case for decoder-only models, but for other architectures that still carry state it is less well defined.

Currently, I have 2 examples exported and working:

  1. A dummy module that just sets/reads from a cache.
  2. A Hugging Face BART model (OPUS translation):
    • It uses the Hugging Face static cache, logit post-processing, and stopping conditions from Hugging Face's generate implementation.
    • It uses a C++-wrapped version of Hugging Face's tokenizers library, loading the configuration JSON blob from the model.
    • However, I am manually using argmax for the greedy search, and the changes to the model's attention needed to support non-fixed batch sizes and encoder sequence lengths are not yet merged into transformers.

What I am asking for:

I have been working on this on my own, but it seems like something that would be more generally useful. I only started really digging into the torch.export process and executorch in the past month, so I may have some fundamental misunderstandings.

Aside from the changes required in executorch (of which there are surprisingly few), most of my changes live in my own repository, mainly because it was easier to set up and let me avoid futzing around with the executorch build system. That being said, if this functionality aligns with executorch's goals I would like to get it merged in, at least the exporter class that manages everything alongside the C++ memory manager that handles things on the runtime side, though I would need advice on that.

Also, some specific things I have done to get this working fall somewhere on the scale between somewhat janky and very janky. I would appreciate feedback, thoughts on alternatives, and advice.

This issue also serves as a pointer/context for the executorch changes and pull requests (see the bottom).

I am making a few pull requests that are not fully fleshed out with regard to testing/linting/etc., but they illustrate the changes I needed to make to the executorch repository.
 

Alternatives

I looked at using torch.cond to explicitly implement the required branching logic. The issue I ran into was that each branch must be fully functional, with the same graph signature for each submodule on each branch. This forces the overall method signature to be the same in all cases, which leads to extra copy-backs even when data is not modified in a branch.
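
For concreteness, here is a minimal sketch of that constraint (not code from my repo; the module and shapes are made up). Both torch.cond branches must take the same operands and return outputs with matching metadata, so even the branch that leaves the cache untouched still produces a tensor that has to be copied back:

import torch

class CondCache(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.register_buffer("cache", torch.zeros(4, 8))

    def forward(self, update: torch.Tensor, do_update: torch.Tensor):
        def write_branch(cache, update):
            return cache + update   # branch that actually changes the cache

        def skip_branch(cache, update):
            return cache.clone()    # must still return a same-shaped tensor

        # Both branches share one signature and output spec, so the result has to be
        # copied back into the buffer even when skip_branch ran and nothing changed.
        new_cache = torch.cond(do_update, write_branch, skip_branch, (self.cache, update))
        self.cache.copy_(new_cache)
        return self.cache

ep = torch.export.export(CondCache(), (torch.ones(4, 8), torch.tensor(True)))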

The other alternative is writing a lot of C++ code, which adds complexity and bugs when it inevitably does not exactly match the Python implementation.

RFC (Optional)

More details are here:

README: https://github.com/cptspacemanspiff/execu-tools
And a more in-depth rationale/build log here: https://github.com/cptspacemanspiff/execu-tools/blob/main/DesignReasoning.md

What this looks like:

For the python export:

import torch
from torch.export import Dim

# MultiEntryPointExporter and MethodArg come from the execu-tools package
# (cptspacemanspiff/execu-tools); the exact import path is omitted here.

class StatefulModel(torch.nn.Module):
    def __init__(
        self,
        max_batch_size: int,
        max_seq_len: int,
    ):
        super().__init__()
        self.register_buffer(
            "cache",
            torch.zeros((max_batch_size, max_seq_len), dtype=torch.float32),
            persistent=True,
        )    

    # need slicing here:
    def set_cache(self, data: torch.Tensor):
        self.cache[0 : data.shape[0], 0 : data.shape[1]] = data
        return None

    # need narrow here:
    def get_cache(self, data: torch.Tensor):
        narrowed_cache = self.cache.narrow(0, 0, data.size(0)).narrow(1, 0, data.size(1))
        data.copy_(narrowed_cache)
        return None

def test_stateful_export():
    max_batch_size = 10
    max_seq_len = 20

    # Wrap/Create a model to export. The model CANNOT have a forward method.
    model = StatefulModel(max_batch_size=max_batch_size, max_seq_len=max_seq_len)

    # Exporter class that manages the legwork
    exporter = MultiEntryPointExporter(model)

    # Register the buffer by fqn 
    # Alternatively pass a module fqn, which will register every registered buffer inside it.
    exporter.register_shared_buffer("cache")

    # Define dynamic dimensions as normal
    batch_size = Dim("batch_size_dim", min=1, max=max_batch_size)
    seq_len = Dim("seq_len_dim", min=1, max=max_seq_len)

    # Register methods for export, with examples for tracing.
    exporter.register(
        model.set_cache,
        data=MethodArg(
            torch.ones(max_batch_size-1, max_seq_len-1),
            dynamic_dims={0: batch_size, 1: seq_len},
        ),
    )

    exporter.register(
        model.get_cache,
        data=MethodArg(
            torch.ones(max_batch_size-1, max_seq_len-1),
            dynamic_dims={0: batch_size, 1: seq_len},
        ),
    )

    # Pass additional data to the runtime.
    constant_methods = {'my_const_function':torch.zeros(3,3)}

    # Export process 
    # I have not yet played with quantization or backends, there should not be issues.
    # I hope...
    exporter.export()
    exporter.to_edge(constant_methods=constant_methods)
    exporter.to_executorch()
    exporter.save(output_dir, "stateful_model") # Also saves a ton of diagnostic info.

C++

#include "ExecuTools/shared_memory_manager.h"
#include "ExecuToolsTestDirs.h"
#include <executorch/extension/module/module.h>
#include <executorch/extension/tensor/tensor_ptr_maker.h>
#include <executorch/runtime/platform/log.h>
using executorch::extension::Module;
using executorch::runtime::Error;
using executorch::runtime::Program;

ET_NODISCARD Error run_program() {
  // create a module:
  Module MultiEntryModule(EXECUTOOLS_PYTHON_ARTIFACT_DIR
                          "/StatefulModel/stateful_model.pte",
                          Module::LoadMode::MmapUseMlock, nullptr);
  // force load the program:
  ET_CHECK_OK_OR_RETURN_ERROR(MultiEntryModule.load(), "Failed to load module");
  auto program = MultiEntryModule.program();
  // validate that the program is loaded:
  ET_CHECK_OR_RETURN_ERROR(program != nullptr, InvalidProgram,
                           "Program is not loaded");

  // use the shared_ptr program to construct a shared memory manager:
  executools::SharedMemoryManager shared_memory_manager(program);

  ET_CHECK_OK_OR_RETURN_ERROR(
      MultiEntryModule.load_method(
          "set_cache", nullptr,
          shared_memory_manager.get_allocator("set_cache").get()),
      "Failed to load set_cache");
  ET_CHECK_OK_OR_RETURN_ERROR(
      MultiEntryModule.load_method(
          "get_cache", nullptr,
          shared_memory_manager.get_allocator("get_cache").get()),
      "Failed to load get_cache");

  const int batch_size = 10;
  const int seq_len = 20;

  // lambda function to check if the two tensors are the same:
  auto tensors_equal = [](const executorch::extension::TensorPtr &t1,
                          const executorch::extension::TensorPtr &t2,
                          size_t size) {
    auto ptr1 = t1->const_data_ptr<float>();
    auto ptr2 = t2->const_data_ptr<float>();
    for (size_t i = 0; i < size; i++) {
      if (ptr1[i] != ptr2[i]) {
        return false;
      }
    }
    return true;
  };

  auto set_input = executorch::extension::ones({batch_size, seq_len});
  auto get_input = executorch::extension::zeros({batch_size, seq_len});
  // sanity check: the two inputs should start out different.
  if (tensors_equal(set_input, get_input, batch_size * seq_len)) {
    return Error::InvalidState;
  }
  // Run the model, set the cache with the value from the input.
  auto none_result_1 =
      ET_UNWRAP(MultiEntryModule.execute("set_cache", set_input));
  // Run the model, get the cache; the result is copied back into get_input.
  auto none_result_2 =
      ET_UNWRAP(MultiEntryModule.execute("get_cache", get_input));

  // Get input has now been filled with ones that were set into the cache.
  if (!tensors_equal(set_input, get_input, batch_size * seq_len)) {
    return Error::InvalidState;
  }
  return Error::Ok;
}

int main() {
  if (run_program() != Error::Ok) {
    ET_LOG(Error, "Test failed");
    return 1;
  }
  ET_LOG(Info, "Test passed");
  return 0;
}

High level:

Currently most of the wrapping logic sits in my own custom repo (cptspacemanspiff/execu-tools).

This is somewhat based on comments from #7458. While there were some changes to executorch itself, it is mostly a wrapper that runs a custom memory pass to place the shared-state buffers at a common memory address.

Specifically, the main steps in the process are:

  1. MultiEntryPointExporter gets configured by the user with what to export; this includes shared buffers and module methods.
  2. MultiEntryPointExporter monkey-patches the forward method (this seems to be an undocumented capability, but it has worked reliably so far...).
  3. Create a synthetic method that just mutates the values of all the shared buffers; this gives us a method to base our shared memory plan on.
  4. Run torch.export to export the model.
  5. In each resulting graph, for every shared buffer used by the method, we inject a self.shared_buffer.copy_(self.shared_buffer) right after the placeholders. This forces the buffer to be treated as mutable, even if the method only reads it (see the sketch after this list).
  6. We run to_edge.
  7. Since at this point all shared buffers are registered as mutable in the graph signature, we remove any copy operation where the source and target are the same and point to a shared buffer. (This removes the op we added earlier; it may already have been removed during to_edge, but when there is a later mutation it would not have been.)
  8. We run memory planning via to_executorch on the synthetic buffer-init method, placing all shared buffers in mem_id 2. This mem_id 2 plan is saved for future reference.
  9. We rerun memory planning on all methods, but after the planner has run we overwrite the memory location of every object in mem_id 2 with the reference memory plan. This ensures that all objects are placed at the same offsets in the buffer.
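
To make steps 5, 8, and 9 more concrete, here is a rough sketch of the kind of graph and plan manipulation involved. This is not the exporter's actual code: it assumes ExecuTorch-style tensor specs reachable via node.meta["spec"] with mem_id/mem_offset fields, and the helper names (inject_self_copy, pin_shared_offsets) are made up for illustration.

import torch

def inject_self_copy(gm: torch.fx.GraphModule, shared_buffer_nodes):
    # Step 5: right after the placeholders, insert buf.copy_(buf) for each shared
    # buffer so the export signature records it as mutated even in read-only methods.
    last_placeholder = [n for n in gm.graph.nodes if n.op == "placeholder"][-1]
    with gm.graph.inserting_after(last_placeholder):
        for buf in shared_buffer_nodes:
            gm.graph.call_function(torch.ops.aten.copy_.default, (buf, buf))
    gm.recompile()

def pin_shared_offsets(methods, reference_method, shared_mem_id=2):
    # Steps 8-9: read the offsets the planner chose for the synthetic init method,
    # then overwrite the same buffers' offsets in every other method's plan so all
    # methods agree on where each shared buffer lives inside mem_id 2.
    reference = {}
    for node in methods[reference_method].graph.nodes:
        spec = node.meta.get("spec")
        if spec is not None and getattr(spec, "mem_id", None) == shared_mem_id:
            reference[str(node.target)] = spec.mem_offset

    for name, gm in methods.items():
        if name == reference_method:
            continue
        for node in gm.graph.nodes:
            spec = node.meta.get("spec")
            if spec is not None and getattr(spec, "mem_id", None) == shared_mem_id:
                spec.mem_offset = reference[str(node.target)]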

Related issues/pull requests (I am still in the process of uploading these and will add issue numbers as I get them):

  • #7515 -> #7810: Fix copy back buffers are not the same
  • Export of function with no user input fails: #8031
  • Allow for multi-method exports to use the ETRecord and the inspector-cli: #8336
  • Add ability to pass hierarchical allocator to runtime Module method: #8032
  • Allow to override default event tracing names: #8033

cc @JacobSzwejbka @angelayi

@cptspacemanspiff cptspacemanspiff changed the title Multi Entry points and Shared State implementation Request for comments/feedback. Multiple Entrypoints and Shared State implementation; Request for comments/feedback. Jan 29, 2025
@cptspacemanspiff
Contributor Author

@JacobSzwejbka @kimishpatel Thoughts?

@JacobSzwejbka
Contributor

MultiEntryPointExporter monkey-patches the forward method (this seems to be an undocumented capability, but it has worked reliably so far...)

Ah yeah, this is the recommended way to do this for now. I've brought it up to export before (cc @angelayi) because it has an obvious failing: what happens if one of the functions you are monkey-patching internally calls forward?

@cptspacemanspiff
Contributor Author

MultiEntryPointExporter monkey-patches the forward method (this seems to be an undocumented capability, but it has worked reliably so far...)

Ah yeah, this is the recommended way to do this for now. I've brought it up to export before (cc @angelayi) because it has an obvious failing: what happens if one of the functions you are monkey-patching internally calls forward?

Yeah, I was thinking about this from your earlier comment. At least for the use case of exporting Hugging Face models, there needs to be a wrapper anyway in order to make the static KV-cache implementation part of the exported model state. If you go with the assumption that we will be writing a 'lightweight wrapper' for the model class, we just make sure that wrapper does not have a forward method, and then I am pretty sure it should be safe. Alternatively, one could auto-generate a wrapper class, assigning whatever model you want to self.model, and then have methods call into it. A rough sketch of this wrapper idea follows below.
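
A minimal sketch of what such a wrapper could look like (all names, shapes, and the inner model here are illustrative, not code from the repo):

import torch

class TinyInner(torch.nn.Module):
    # Stand-in for the real encoder/decoder model.
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.encoder = torch.nn.Linear(hidden, hidden)
        self.decoder = torch.nn.Linear(hidden, hidden)

class GenerateWrapper(torch.nn.Module):
    # Lightweight wrapper with no forward(): only named entry points get exported,
    # so monkey-patching forward never collides with a real forward implementation.
    def __init__(self, inner: torch.nn.Module, max_batch: int, max_seq: int, hidden: int = 16):
        super().__init__()
        self.model = inner
        # A static cache registered here becomes part of the exported (shared) state.
        self.register_buffer("encoder_cache", torch.zeros(max_batch, max_seq, hidden))

    def encode(self, x: torch.Tensor):
        out = self.model.encoder(x)
        self.encoder_cache[: out.shape[0], : out.shape[1]] = out
        return None

    def decode_step(self, x: torch.Tensor):
        ctx = self.encoder_cache[: x.shape[0], : x.shape[1]]
        return self.model.decoder(x) + ctx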

@JacobSzwejbka
Contributor

My initial impression is that the individual PRs all seem good. The C++ pattern looks good (SharedMemoryManager). I'll take a deeper look at the Python exporter tomorrow. I'm a little hesitant to introduce a new wrapper on top of things in Python; I'd like to see if there's a way we can take the solutions in your exporter and just integrate them into EdgeProgramManager.

@angelayi
Contributor

This is very cool, and thank you so much for writing up your process and the hurdles you had to work through with export -- it was a great read!

Ah yeah, this is the recommended way to do this for now. I've brought it up to export before (cc @angelayi) because it has an obvious failing: what happens if one of the functions you are monkey-patching internally calls forward?

I agree that export does not have a good model for multiple entrypoints, but for now, having a wrapper class generally works for exporting methods.
