Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PROTON] Add proton.state utility #5110

Merged
merged 17 commits into from
Nov 20, 2024
23 changes: 23 additions & 0 deletions third_party/proton/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,29 @@ proton-viewer -h

## Advanced features

### State annotation

In addition to `proton.scope`, we can also customize the call path of each GPU operation using `proton.state`.

`state` is different from `scope` in several ways:

1. State is not recursive; each operation can have only a single state. The innermost state overwrites any outer state.
2. A state is a suffix: it is appended to the end of the original call path, directly above the name of each kernel.
3. State is compatible with both Python and shadow contexts.

The following example demonstrates a basic use of state:

```python
with proton.scope("test"):
    with proton.state("state0"):
        with proton.scope("test0"):
            foo0[1,](x, y)
        with proton.scope("test1"):
            foo1[1,](x, y)
```

The call path of `foo1` will be `test->test1->state0`.

### Instrumentation (experimental)

In addition to profiling, Proton also incorporates MLIR/LLVM based compiler instrumentation passes to get Triton level analysis
Expand Down
7 changes: 7 additions & 0 deletions third_party/proton/csrc/Proton.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,13 @@ void initProton(pybind11::module &&m) {
SessionManager::instance().exitOp(Scope(scopeId, name));
});

// Python binding `enter_state(state)`: forwards the given state name to
// SessionManager::setState.
m.def("enter_state", [](const std::string &state) {
SessionManager::instance().setState(state);
});

// Python binding `exit_state()`: takes no arguments and passes std::nullopt
// to SessionManager::setState.
m.def("exit_state",
[]() { SessionManager::instance().setState(std::nullopt); });

m.def("add_metrics",
[](size_t scopeId,
const std::map<std::string, MetricValueType> &metrics) {
Expand Down
16 changes: 15 additions & 1 deletion third_party/proton/csrc/include/Context/Context.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#include <limits>
#include <map>
#include <mutex>
#include <optional>
#include <string>
#include <vector>

Expand All @@ -31,7 +32,20 @@ class ContextSource {
public:
ContextSource() = default;
virtual ~ContextSource() = default;
virtual std::vector<Context> getContexts() = 0;

std::vector<Context> getContexts() {
auto contexts = getContextsImpl();
if (state.has_value()) {
contexts.push_back(state.value());
}
return contexts;
}

/// Stores `state` as the current state; passing std::nullopt leaves no state
/// set. The parameter shadows the static member of the same name, hence the
/// qualified `ContextSource::state` on the left-hand side.
void setState(std::optional<Context> state) { ContextSource::state = state; }

protected:
/// Supplies the contexts that getContexts() returns before the state is
/// appended. Pure virtual: every concrete source must implement it.
virtual std::vector<Context> getContextsImpl() = 0;
/// Declared static thread_local: there is one value per thread, shared by
/// every ContextSource instance used on that thread.
static thread_local std::optional<Context> state;
};

/// A scope is a context with a unique identifier.
Expand Down
5 changes: 4 additions & 1 deletion third_party/proton/csrc/include/Context/Python.h
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,10 @@ namespace proton {
/// Unwind the Python stack and early return a list of contexts.
class PythonContextSource : public ContextSource {
public:
std::vector<Context> getContexts() override;
PythonContextSource() = default;

private:
std::vector<Context> getContextsImpl() override;
};

} // namespace proton
Expand Down
3 changes: 1 addition & 2 deletions third_party/proton/csrc/include/Context/Shadow.h
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@ class ShadowContextSource : public ContextSource, public ScopeInterface {
public:
ShadowContextSource() = default;

std::vector<Context> getContexts() override { return contextStack; }

void enterScope(const Scope &scope) override;

void exitScope(const Scope &scope) override;

private:
std::vector<Context> getContextsImpl() override { return contextStack; }
std::vector<Context> contextStack;
};

Expand Down
4 changes: 4 additions & 0 deletions third_party/proton/csrc/include/Session/Session.h
Original file line number Diff line number Diff line change
Expand Up @@ -95,6 +95,8 @@ class SessionManager : public Singleton<SessionManager> {
const std::map<std::string, MetricValueType> &metrics,
bool aggregable);

void setState(std::optional<Context> context);

private:
std::unique_ptr<Session> makeSession(size_t id, const std::string &path,
const std::string &profilerName,
Expand Down Expand Up @@ -146,6 +148,8 @@ class SessionManager : public Singleton<SessionManager> {
std::map<ScopeInterface *, size_t> scopeInterfaceCounts;
// op -> active count
std::map<OpInterface *, size_t> opInterfaceCounts;
// context source -> active count
std::map<ContextSource *, size_t> contextSourceCounts;
};

} // namespace proton
Expand Down
3 changes: 3 additions & 0 deletions third_party/proton/csrc/lib/Context/Context.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,9 @@

namespace proton {

// Out-of-class definition of the static thread_local member
// ContextSource::state. It starts out empty (std::nullopt).
/*static*/ thread_local std::optional<Context> ContextSource::state =
std::nullopt;

std::atomic<size_t> Scope::scopeIdCounter{1};

/*static*/ thread_local std::map<ThreadLocalOpInterface *, bool>
Expand Down
2 changes: 1 addition & 1 deletion third_party/proton/csrc/lib/Context/Python.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ std::string unpackPyobject(PyObject *pyObject) {

} // namespace

std::vector<Context> PythonContextSource::getContexts() {
std::vector<Context> PythonContextSource::getContextsImpl() {
pybind11::gil_scoped_acquire gil;

PyFrameObject *frame = PyEval_GetFrame();
Expand Down
12 changes: 12 additions & 0 deletions third_party/proton/csrc/lib/Session/Session.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -111,6 +111,7 @@ void SessionManager::activateSessionImpl(size_t sessionId) {
sessions[sessionId]->activate();
registerInterface<ScopeInterface>(sessionId, scopeInterfaceCounts);
registerInterface<OpInterface>(sessionId, opInterfaceCounts);
registerInterface<ContextSource>(sessionId, contextSourceCounts);
}

void SessionManager::deActivateSessionImpl(size_t sessionId) {
Expand All @@ -122,6 +123,7 @@ void SessionManager::deActivateSessionImpl(size_t sessionId) {
sessions[sessionId]->deactivate();
unregisterInterface<ScopeInterface>(sessionId, scopeInterfaceCounts);
unregisterInterface<OpInterface>(sessionId, opInterfaceCounts);
unregisterInterface<ContextSource>(sessionId, contextSourceCounts);
}

void SessionManager::removeSession(size_t sessionId) {
Expand Down Expand Up @@ -226,4 +228,14 @@ void SessionManager::addMetrics(
}
}

void SessionManager::setState(std::optional<Context> context) {
std::shared_lock<std::shared_mutex> lock(mutex);
for (auto iter : contextSourceCounts) {
auto [contextSource, count] = iter;
if (count > 0) {
contextSource->setState(context);
}
}
}

} // namespace proton
1 change: 1 addition & 0 deletions third_party/proton/proton/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# flake8: noqa
from .scope import scope, enter_scope, exit_scope
from .state import state, enter_state, exit_state
from .profile import (
start,
activate,
Expand Down
5 changes: 3 additions & 2 deletions third_party/proton/proton/hook.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from .state import enter_state, exit_state
from .scope import enter_scope, exit_scope
from triton.compiler import CompiledKernel, LazyDict

Expand All @@ -10,9 +11,9 @@ class TritonHook:

@staticmethod
def enter(lazy_dict: LazyDict) -> None:
enter_scope(COMPUTE_METADATA_SCOPE_NAME)
enter_state(COMPUTE_METADATA_SCOPE_NAME)
metadata = lazy_dict.get()
exit_scope()
exit_state()
fn_metrics = {k: metadata[k] for k in TritonHook.metrics if k in metadata}
enter_scope(metadata["name"], triton_op=True, metrics=fn_metrics)

Expand Down
61 changes: 61 additions & 0 deletions third_party/proton/proton/state.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
from triton._C.libproton import proton as libproton
from .flags import get_profiling_on
from functools import wraps


class state:
    """
    A context manager and decorator for entering and exiting a state.

    Entering calls `libproton.enter_state(name)` and exiting calls
    `libproton.exit_state()`. Both calls are skipped whenever
    `get_profiling_on()` returns False at that moment.

    Usage:
        context manager:
        ```python
        with proton.state("test0"):
            foo[1,](x, y)
        ```

        decorator:
        ```python
        @proton.state("test0")
        def foo(x, y):
            ...
        ```

    Args:
        name (str): The name of the state.
    """

    def __init__(self, name: str) -> None:
        self.name = name

    def __enter__(self):
        # Nothing to record while profiling is off; still return self so
        # `with state(...) as s` keeps working.
        if not get_profiling_on():
            return self
        libproton.enter_state(self.name)
        return self

    def __exit__(self, exc_type, exc_value, traceback) -> None:
        # Returns None, so an exception raised inside the `with` body
        # propagates after the state has been exited.
        if not get_profiling_on():
            return
        libproton.exit_state()

    def __call__(self, func):
        """Wrap `func` so that every call runs inside this state."""

        @wraps(func)
        def wrapper(*args, **kwargs):
            if get_profiling_on():
                libproton.enter_state(self.name)
            try:
                return func(*args, **kwargs)
            finally:
                # Exit the state even when `func` raises, matching the
                # context-manager form; otherwise the state would stay set
                # for later operations.
                if get_profiling_on():
                    libproton.exit_state()

        return wrapper


def enter_state(name: str) -> None:
    """Enter the state `name` by calling `libproton.enter_state` directly.

    This function does not consult `get_profiling_on()`; the call is always
    forwarded.
    """
    libproton.enter_state(name)


def exit_state() -> None:
    """Exit the current state by calling `libproton.exit_state` directly.

    This function does not consult `get_profiling_on()`; the call is always
    forwarded.
    """
    libproton.exit_state()
34 changes: 30 additions & 4 deletions third_party/proton/proton/viewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,8 +28,37 @@ def match_available_metrics(metrics, raw_metrics):
return ret


def remove_metadata(database: json):
    """Strip metadata-computation frames from a loaded profile.

    A node whose frame name equals COMPUTE_METADATA_SCOPE_NAME is dropped
    together with its whole subtree. A node that had children, all of which
    were dropped, is dropped as well. Entries without a "frame" key are
    passed through untouched.

    Surviving nodes are modified in place (their "children" list is
    replaced). The return value is a new list of the surviving top-level
    entries.
    """

    def prune(node):
        # Entries that carry no frame are not tree nodes; keep them as-is.
        if "frame" not in node:
            return node
        if node["frame"]["name"] == COMPUTE_METADATA_SCOPE_NAME:
            return None
        original_children = node.get("children", [])
        kept = [child for child in map(prune, original_children) if child is not None]
        # A node that had children but lost every one of them goes too.
        if original_children and not kept:
            return None
        node["children"] = kept
        return node

    return [entry for entry in map(prune, database) if entry is not None]


def get_raw_metrics(file):
    """Load a JSON profile from `file` and build a graph frame from it.

    The parsed data is first passed through `remove_metadata`. The entry at
    index 1 is then taken out as the device info, and the remaining entries
    are handed to `ht.GraphFrame.from_literal`.

    Returns:
        A tuple `(graph_frame, metric_columns, device_info)`, where
        `metric_columns` is the result of `graph_frame.show_metric_columns()`.
    """
    entries = remove_metadata(json.load(file))
    device_info = entries.pop(1)
    graph_frame = ht.GraphFrame.from_literal(entries)
    return graph_frame, graph_frame.show_metric_columns(), device_info
Expand Down Expand Up @@ -180,9 +209,6 @@ def filter_frames(gf, include=None, exclude=None, threshold=None, metric=None):
"""
query = NegationQuery(inclusion_query)
gf = gf.filter(query, squash=True)
# filter out metadata computation
query = [{"name": f"^(?!{COMPUTE_METADATA_SCOPE_NAME}).*"}]
gf = gf.filter(query, squash=True)
if threshold:
query = ["*", {metric: f">= {threshold}"}]
gf = gf.filter(query, squash=True)
Expand Down Expand Up @@ -278,7 +304,7 @@ def main():
type=str,
default=None,
help="""Exclude frames that match the given regular expression and their children.
For example, the following command will exclude all paths starting from "test":
For example, the following command will exclude all paths starting from frames that contains "test":
```
proton-viewer -e ".*test.*" path/to/file.json
```
Expand Down
71 changes: 71 additions & 0 deletions third_party/proton/test/examples/triton.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
[
{
"children": [
{
"children": [
{
"children": [
{
"children": [],
"frame": {
"name": "cuda_kernel",
"type": "function"
},
"metrics": {
"count": 1,
"device_id": "0",
"device_type": "CUDA",
"time (ns)": 4064
}
}
],
"frame": {
"name": "__proton_launch_metadata",
"type": "function"
},
"metrics": {}
},
{
"children": [],
"frame": {
"name": "triton_kernel",
"type": "function"
},
"metrics": {
"bytes": 2.0,
"count": 1,
"device_id": "0",
"device_type": "CUDA",
"time (ns)": 1664
}
}
],
"frame": {
"name": "scope",
"type": "function"
},
"metrics": {}
}
],
"frame": {
"name": "ROOT",
"type": "function"
},
"metrics": {
"bytes": 0,
"count": 0,
"time (ns)": 0
}
},
{
"CUDA": {
"0": {
"arch": "86",
"bus_width": 128,
"clock_rate": 1140000,
"memory_clock_rate": 5501000,
"num_sms": 16
}
}
}
]
Loading
Loading