Batch incoming DOCA raw packet data (#1731)
* Fix bug in `DocaSourceStage` ring buffer initialization (#1820)
* Incoming packet data is buffered into a `mrc::BufferedChannel`, limiting the number of times the convert stage needs to acquire the GIL.
* Launch the `_packet_gather_payload_kernel` kernel with a 2D grid, treating the byte offset within each packet as the second axis.
* No longer convert 32-bit integer IP data into strings in the gather kernel; instead, buffer the 32-bit integer data and let cuDF convert it all in one pass prior to constructing the cuDF DataFrame.
* No longer output fixed data sizes.
* Perform the sizes-to-offsets calculation with MatX rather than a deprecated cuDF method (an illustrative sketch of this and the 2D-grid gather follows this list).
* Add a `PacketDataBuffer` struct to wrap the three pieces of data that need to be buffered: header, payload and payload sizes.
* Split `doca_stages.hpp` into `doca_source_stage.hpp` and `doca_convert_stage.hpp`
* Rename `doca_convert.cpp`->`doca_convert_stage.cpp` and `doca_source.cpp`->`doca_source_stage.cpp`
* Add command line flags to `examples/doca/run_udp_convert.py` and `examples/doca/vdb_realtime/sender/send.py`
* Move constants in `morpheus/_lib/doca/include/morpheus/doca/common.hpp` into the `morpheus::doca` namespace (previously in the global namespace).
* Fix CI checks for DOCA based code
* Remove unused code
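
The 2D-grid gather and the sizes-to-offsets change are easiest to see with a toy example. The sketch below is illustrative only: it uses Numba and NumPy in place of the actual CUDA C++/MatX implementation in `doca_convert_kernel.cu`, and it stages packets in a fixed-stride array rather than reading from the real scattered packet addresses. The indexing scheme is the same, though: the grid's x axis selects the packet, the y axis selects the byte offset within that packet, and the output offsets are an exclusive prefix-sum of the payload sizes.

```python
# Illustrative only -- a simplified Numba analogue of the 2D-grid
# _packet_gather_payload_kernel launch and the sizes -> offsets prefix-sum.
# The real implementation is CUDA C++ using MatX.
import math

import numpy as np
from numba import cuda


@cuda.jit
def gather_payload_2d(src, sizes, offsets, out):
    # Grid x axis -> packet index, grid y axis -> byte offset within that packet.
    pkt, byte = cuda.grid(2)
    if pkt < sizes.shape[0] and byte < sizes[pkt]:
        out[offsets[pkt] + byte] = src[pkt, byte]


# Toy data: three packets staged with a fixed stride of 8 bytes, but with
# actual payload sizes of 5, 3 and 8 bytes.
src = np.arange(24, dtype=np.uint8).reshape(3, 8)
sizes = np.array([5, 3, 8], dtype=np.int64)

# Sizes -> offsets is an exclusive prefix-sum: packet i's payload lands at
# out[offsets[i]:offsets[i] + sizes[i]], and offsets[-1] is the total byte count.
offsets = np.concatenate(([0], np.cumsum(sizes)))

d_out = cuda.to_device(np.zeros(int(offsets[-1]), dtype=np.uint8))
threads = (16, 16)
blocks = (math.ceil(src.shape[0] / threads[0]), math.ceil(src.shape[1] / threads[1]))
gather_payload_2d[blocks, threads](src, sizes, offsets, d_out)

print(d_out.copy_to_host())  # the 16 payload bytes, packed contiguously
```

Each byte is handled by its own thread, so variable-length payloads are packed into one contiguous output buffer without any per-packet loops.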

Closes #1820

This PR punts on fixing #1827

## By Submitting this PR I confirm:
- I am familiar with the [Contributing Guidelines](https://github.com/nv-morpheus/Morpheus/blob/main/docs/source/developer_guide/contributing.md).
- When the PR is ready for review, new or existing tests cover these changes.
- When the PR is ready for review, the documentation is up to date with these changes.

Authors:
  - David Gardner (https://github.com/dagardner-nv)

Approvers:
  - Michael Demoret (https://github.com/mdemoret-nv)

URL: #1731
dagardner-nv authored Aug 1, 2024
1 parent 5e0d920 commit 68bbaa9
Showing 22 changed files with 1,042 additions and 866 deletions.
8 changes: 4 additions & 4 deletions examples/doca/README.md
@@ -78,7 +78,7 @@ We can see the GPU's PCIe address is `cf:00.0`, and we can infer from the above
For UDP traffic, the sample launches a simple pipeline with the DOCA Source Stage followed by a Monitor Stage that reports the number of received packets.

```
python3 ./examples/doca/run_udp_raw.py --nic_addr 17:00.1 --gpu_addr ca:00.0 --traffic_type udp
python ./examples/doca/run_udp_raw.py --nic_addr 17:00.1 --gpu_addr ca:00.0
```
UDP traffic can be easily sent with nping to the interface where Morpheus is listening:
```
@@ -118,15 +118,15 @@ DOCA GPUNetIO rate: 100000 pkts [00:12, 10963.39 pkts/s]
Since the DOCA Source Stage outputs packets in the new RawPacketMessage format, which not all Morpheus stages support, an additional stage, the DOCA Convert Stage, transforms the RawPacketMessage data into the MessageMeta format.

```
python3 ./examples/doca/run_udp_convert.py --nic_addr 17:00.1 --gpu_addr ca:00.0 --traffic_type udp
python ./examples/doca/run_udp_convert.py --nic_addr 17:00.1 --gpu_addr ca:00.0
```
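
The pairing described above takes only a few lines to set up programmatically. The following is a minimal sketch using the same imports and constructor arguments as `run_udp_convert.py`; the PCIe addresses are placeholders, and the `run_udp_convert.py` diff later in this commit shows the full version with CLI options and monitor stages.

```python
# Minimal sketch of the DocaSourceStage -> DocaConvertStage pairing.
# The NIC and GPU PCIe addresses are placeholders; substitute your own.
from morpheus.config import Config
from morpheus.config import CppConfig
from morpheus.config import PipelineModes
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.stages.doca.doca_convert_stage import DocaConvertStage
from morpheus.stages.doca.doca_source_stage import DocaSourceStage

CppConfig.set_should_use_cpp(True)  # the DOCA stages are C++-only

config = Config()
config.mode = PipelineModes.NLP

pipeline = LinearPipeline(config)
pipeline.set_source(DocaSourceStage(config, "17:00.1", "ca:00.0", 'udp'))  # emits RawPacketMessage
pipeline.add_stage(DocaConvertStage(config))  # RawPacketMessage -> MessageMeta
pipeline.run()
```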

## DOCA Sensitive Information Detection example for TCP traffic

The DOCA example is similar to the Sensitive Information Detection (SID) example in that it uses the `sid-minibert` model in conjunction with the `TritonInferenceStage` to detect sensitive information. The difference is that the sensitive information we will be detecting is obtained from a live TCP packet stream provided by a `DocaSourceStage`.
To run the example from the Morpheus root directory and capture all TCP network traffic from the given NIC, use the following command and replace the `nic_addr` and `gpu_addr` arguments with your NIC and GPU PCIe addresses.
```
# python examples/doca/run_tcp.py --nic_addr cc:00.1 --gpu_addr cf:00.0 --traffic_type tcp
# python examples/doca/run_tcp.py --nic_addr cc:00.1 --gpu_addr cf:00.0
```
```
====Registering Pipeline====
@@ -146,7 +146,7 @@ DOCA GPUNetIO rate: 0 pkts [00:03, ? pkts/s]====Registering Pipeline Complete!==
====Starting Pipeline====[00:02, ? pkts/s]
====Pipeline Started====0:02, ? pkts/s]
====Building Segment: linear_segment_0====
Added source: <from-doca-0; DocaSourceStage(nic_pci_address=cc:00.1, gpu_pci_address=cf:00.0)>
Added source: <from-doca-0; DocaSourceStage(nic_pci_address=cc:00.1, gpu_pci_address=cf:00.0, traffic_type=tcp)>
└─> morpheus.MessageMeta
Added stage: <monitor-1; MonitorStage(description=DOCA GPUNetIO rate, smoothing=0.05, unit=pkts, delayed_start=False, determine_count_fn=None, log_level=LogLevels.INFO)>
└─ morpheus.MessageMeta -> morpheus.MessageMeta
79 changes: 72 additions & 7 deletions examples/doca/run_udp_convert.py
@@ -12,17 +12,21 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import logging
import os

import click

from morpheus.cli.utils import get_log_levels
from morpheus.cli.utils import parse_log_level
from morpheus.config import Config
from morpheus.config import CppConfig
from morpheus.config import PipelineModes
from morpheus.messages import RawPacketMessage
from morpheus.pipeline.linear_pipeline import LinearPipeline
from morpheus.stages.doca.doca_convert_stage import DocaConvertStage
from morpheus.stages.doca.doca_source_stage import DocaSourceStage
from morpheus.stages.general.monitor_stage import MonitorStage
from morpheus.stages.output.write_to_file_stage import WriteToFileStage
from morpheus.utils.logger import configure_logging


@@ -37,25 +41,86 @@
help="GPU PCI Address",
required=True,
)
def run_pipeline(nic_addr, gpu_addr):
@click.option(
"--num_threads",
default=os.cpu_count(),
type=click.IntRange(min=1),
show_default=True,
help="Number of internal pipeline threads to use.",
)
@click.option(
"--edge_buffer_size",
default=1024 * 16,
type=click.IntRange(min=1),
show_default=True,
help="Size of edge buffers.",
)
@click.option(
"--max_batch_delay_sec",
default=3.0,
type=float,
show_default=True,
help="Maximum amount of time in seconds to buffer incoming packets.",
)
@click.option(
"--buffer_channel_size",
default=None,
type=click.IntRange(min=2),
show_default=True,
help=("Size of the internal buffer channel used by the DocaConvertStage, if None, the value of `--edge_buffer_size`"
" will be used."),
)
@click.option("--log_level",
default="INFO",
type=click.Choice(get_log_levels(), case_sensitive=False),
callback=parse_log_level,
show_default=True,
help="Specify the logging level to use.")
@click.option("--output_file",
default=None,
help="File to output to, if not supplied, the to-file sink will be omitted.")
def run_pipeline(nic_addr: str,
gpu_addr: str,
num_threads: int,
edge_buffer_size: int,
max_batch_delay_sec: float,
buffer_channel_size: int,
log_level: int,
output_file: str | None):
# Enable the default logger
configure_logging(log_level=logging.DEBUG)
configure_logging(log_level=log_level)

CppConfig.set_should_use_cpp(True)

config = Config()
config.mode = PipelineModes.NLP

# Below properties are specified by the command line
config.num_threads = 10
config.edge_buffer_size = 1024
config.num_threads = num_threads
config.edge_buffer_size = edge_buffer_size

pipeline = LinearPipeline(config)

# add doca source stage
pipeline.set_source(DocaSourceStage(config, nic_addr, gpu_addr, 'udp'))
pipeline.add_stage(DocaConvertStage(config))
pipeline.add_stage(MonitorStage(config, description="DOCA GPUNetIO rate", unit='pkts'))

def count_raw_packets(message: RawPacketMessage):
return message.num

pipeline.add_stage(
MonitorStage(config,
description="DOCA GPUNetIO Raw rate",
unit='pkts',
determine_count_fn=count_raw_packets,
delayed_start=True))

pipeline.add_stage(
DocaConvertStage(config, max_batch_delay_sec=max_batch_delay_sec, buffer_channel_size=buffer_channel_size))

pipeline.add_stage(MonitorStage(config, description="Convert rate", unit='pkts', delayed_start=True))

if output_file is not None:
pipeline.add_stage(WriteToFileStage(config, filename=output_file, overwrite=True))

# Build the pipeline here to see types in the visualization
pipeline.build()
6 changes: 5 additions & 1 deletion examples/doca/run_udp_raw.py
@@ -58,7 +58,11 @@ def count_raw_packets(message: RawPacketMessage):
# add doca source stage
pipeline.set_source(DocaSourceStage(config, nic_addr, gpu_addr, 'udp'))
pipeline.add_stage(
MonitorStage(config, description="DOCA GPUNetIO rate", unit='pkts', determine_count_fn=count_raw_packets))
MonitorStage(config,
description="DOCA GPUNetIO rate",
unit='pkts',
determine_count_fn=count_raw_packets,
delayed_start=True))

# Build the pipeline here to see types in the visualization
pipeline.build()
79 changes: 70 additions & 9 deletions examples/doca/vdb_realtime/sender/send.py
@@ -17,25 +17,86 @@
import glob
import os

import click
from scapy.all import IP # pylint: disable=no-name-in-module
from scapy.all import TCP
from scapy.all import UDP # pylint: disable=no-name-in-module
from scapy.all import RandShort
from scapy.all import Raw
from scapy.all import send

DEFAULT_DPORT = 5001
MORPHEUS_ROOT = os.environ['MORPHEUS_ROOT']

def main():
os.chdir("dataset")
for file in glob.glob("*.txt"):
with open(file, 'r', encoding='utf-8') as fp:

def get_data(input_glob: str) -> list[str]:
data = []
for file in glob.glob(input_glob):
with open(file, 'r', encoding='utf-8') as fh:
while True:
content = fp.read(1024)
content = fh.read(1024)
if not content:
break
pkt = IP(src="192.168.2.28", dst="192.168.2.27") / UDP(sport=RandShort(),
dport=5001) / Raw(load=content.encode('utf-8'))
print(pkt)
send(pkt, iface="enp202s0f0np0")

data.append(content)

return data


def send_data(data: list[str],
dst_ip: str,
dport: int = DEFAULT_DPORT,
iface: str | None = None,
src_ip: str | None = None,
sport: int | None = None,
net_type: str = 'UDP'):
if net_type == 'UDP':
net_type_cls = UDP
else:
net_type_cls = TCP

if sport is None:
sport = RandShort()

ip_kwargs = {"dst": dst_ip}
if src_ip is not None:
ip_kwargs["src"] = src_ip

packets = [
IP(**ip_kwargs) / net_type_cls(sport=sport, dport=dport) / Raw(load=content.encode('utf-8')) for content in data
]

send_kwargs = {}
if iface is not None:
send_kwargs["iface"] = iface

send(packets, **send_kwargs)
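
The `send_data` helper above layers IP and UDP (or TCP) over each 1024-character chunk read by `get_data`. A toy usage example of the same scapy layering is shown below; the addresses are placeholders and the destination port matches `DEFAULT_DPORT`. It is not part of the script itself.

```python
# Toy illustration of the scapy layering used by send_data() above.
# The addresses are placeholders; dport matches DEFAULT_DPORT (5001).
from scapy.all import IP  # pylint: disable=no-name-in-module
from scapy.all import UDP  # pylint: disable=no-name-in-module
from scapy.all import Raw

pkt = IP(dst="192.168.2.27") / UDP(sport=40000, dport=5001) / Raw(load=b"hello DOCA")
pkt.show()  # print the IP / UDP / Raw layer breakdown
```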


@click.command()
@click.option("--iface", help="Ethernet device to use, useful for systems with multiple NICs", required=False)
@click.option("--src_ip", help="Source IP to send from, useful for systems with multiple IPs", required=False)
@click.option("--dst_ip", help="Destination IP to send to", required=True)
@click.option("--dport", help="Destination port", type=int, default=DEFAULT_DPORT)
@click.option("--sport",
help="Source port, if undefined a random port will be used",
type=int,
default=None,
required=False)
@click.option("--net_type", type=click.Choice(['TCP', 'UDP'], case_sensitive=False), default='UDP')
@click.option("--input_data_glob",
type=str,
default=os.path.join(MORPHEUS_ROOT, 'examples/doca/vdb_realtime/sender/dataset/*.txt'),
help="Input filepath glob pattenr matching the data to send.")
def main(iface: str | None,
src_ip: str | None,
dst_ip: str,
dport: int,
sport: int | None,
net_type: str,
input_data_glob: str):
data = get_data(input_data_glob)
send_data(data=data, dst_ip=dst_ip, dport=dport, iface=iface, src_ip=src_ip, sport=sport, net_type=net_type)


if __name__ == "__main__":
43 changes: 38 additions & 5 deletions examples/doca/vdb_realtime/vdb.py
@@ -82,7 +82,40 @@ def build_milvus_service(embedding_size):
help="GPU PCI Address",
required=True,
)
def run_pipeline(nic_addr, gpu_addr):
@click.option(
"--triton_server_url",
type=str,
default="localhost:8001",
show_default=True,
help="Triton server URL.",
)
@click.option(
"--embedding_model_name",
required=True,
default='all-MiniLM-L6-v2',
show_default=True,
help="The name of the model that is deployed on Triton server",
)
@click.option(
"--vector_db_uri",
type=str,
default="http://localhost:19530",
show_default=True,
help="URI for connecting to Vector Database server.",
)
@click.option(
"--vector_db_resource_name",
type=str,
default="vdb_doca",
show_default=True,
help="The identifier of the resource on which operations are to be performed in the vector database.",
)
def run_pipeline(nic_addr: str,
gpu_addr: str,
triton_server_url: str,
embedding_model_name: str,
vector_db_uri: str,
vector_db_resource_name: str):
# Enable the default logger
configure_logging(log_level=logging.DEBUG)

@@ -110,18 +143,18 @@ def run_pipeline(nic_addr, gpu_addr):
pipeline.add_stage(
TritonInferenceStage(config,
force_convert_inputs=True,
model_name="all-MiniLM-L6-v2",
server_url="localhost:8001",
model_name=embedding_model_name,
server_url=triton_server_url,
use_shared_memory=True))
pipeline.add_stage(MonitorStage(config, description="Embedding rate", unit='pkts'))

pipeline.add_stage(
WriteToVectorDBStage(config,
resource_name="vdb_doca",
resource_name=vector_db_resource_name,
batch_size=16896,
recreate=True,
service="milvus",
uri="http://localhost:19530",
uri=vector_db_uri,
resource_schemas={"vdb_doca": build_milvus_service(384)}))
pipeline.add_stage(MonitorStage(config, description="Upload rate", unit='docs'))

11 changes: 7 additions & 4 deletions morpheus/_lib/doca/CMakeLists.txt
@@ -20,15 +20,17 @@ set(doca_ROOT "/opt/mellanox/doca")
find_package(doca REQUIRED)

add_library(morpheus_doca

# Keep these sorted!
src/doca_context.cpp
src/doca_convert_kernel.cu
src/doca_convert.cpp
src/doca_convert_stage.cpp
src/doca_rx_pipe.cpp
src/doca_rx_queue.cpp
src/doca_semaphore.cpp
src/doca_source_kernel.cu
src/doca_source.cpp
src/doca_source_stage.cpp
src/packet_data_buffer.cpp
src/rte_context.cpp
)

@@ -38,14 +40,15 @@ target_include_directories(morpheus_doca
PUBLIC
$<BUILD_INTERFACE:${CMAKE_CURRENT_SOURCE_DIR}/include>
$<INSTALL_INTERFACE:include>
$<BUILD_INTERFACE:${doca_ROOT}/include>
)

target_link_libraries(morpheus_doca
PRIVATE
doca::doca
matx::matx
PUBLIC
${PROJECT_NAME}::morpheus

)

# Ideally, we don't use glob here. But there is no good way to guarantee you don't miss anything like *.cpp
@@ -79,7 +82,7 @@ set_target_properties(morpheus_doca
CUDA_SEPARABLE_COMPILATION ON
)

if (MORPHEUS_PYTHON_INPLACE_BUILD)
if(MORPHEUS_PYTHON_INPLACE_BUILD)
morpheus_utils_inplace_build_copy(morpheus_doca ${CMAKE_CURRENT_SOURCE_DIR})
endif()

3 changes: 2 additions & 1 deletion morpheus/_lib/doca/__init__.pyi
@@ -1,6 +1,7 @@
from __future__ import annotations
import morpheus._lib.doca
import typing
import datetime
import morpheus._lib.messages
import mrc.core.segment

@@ -11,7 +12,7 @@ __all__ = [


class DocaConvertStage(mrc.core.segment.SegmentObject):
def __init__(self, builder: mrc.core.segment.Builder, name: str) -> None: ...
def __init__(self, builder: mrc.core.segment.Builder, name: str, max_batch_delay: datetime.timedelta = datetime.timedelta(microseconds=500000), max_batch_size: int = 40960, buffer_channel_size: int = 1024) -> None: ...
pass
class DocaSourceStage(mrc.core.segment.SegmentObject):
def __init__(self, builder: mrc.core.segment.Builder, name: str, nic_pci_address: str, gpu_pci_address: str, traffic_type: str) -> None: ...
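
The `DocaConvertStage` binding above takes its batching delay as a `datetime.timedelta` (defaulting to 500 ms), while `examples/doca/run_udp_convert.py` exposes `--max_batch_delay_sec` as a float. A plausible mapping between the two is sketched below purely for illustration; the actual conversion is performed by the Python `DocaConvertStage` wrapper and may differ.

```python
# Purely illustrative: turning a float "seconds" option into the
# datetime.timedelta accepted by the C++ DocaConvertStage binding above.
from datetime import timedelta

max_batch_delay_sec = 3.0
max_batch_delay = timedelta(seconds=max_batch_delay_sec)
print(max_batch_delay)  # 0:00:03
```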
