Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[ESI][BSP] Simple ToHost DMA engine #8207

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
87 changes: 77 additions & 10 deletions frontends/PyCDE/integration_test/esitester.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,14 @@
# RUN: %PYTHON% %s %t cosim 2>&1
# RUN: esi-cosim.py --source %t -- esitester cosim env wait | FileCheck %s
# RUN: ESI_COSIM_MANIFEST_MMIO=1 esi-cosim.py --source %t -- esiquery cosim env info
# RUN: esi-cosim.py --source %t -- esitester cosim env dmatest
# RUN: esi-cosim.py --source %t -- esitester cosim env hostmemtest
# RUN: esi-cosim.py --source %t -- esitester cosim env dmawritetest

import pycde
from pycde import AppID, Clock, Module, Reset, generator, modparams
from pycde.bsp import get_bsp
from pycde.constructs import Counter, Reg, Wire
from pycde.esi import CallService
from pycde.constructs import ControlReg, Counter, Reg, Wire
from pycde.esi import CallService, ChannelService
import pycde.esi as esi
from pycde.types import Bits, Channel, UInt

Expand Down Expand Up @@ -181,19 +182,85 @@ def construct(ports):
return WriteMem


def ToHostDMATest(width: int):
"""Construct a module that sends a cycle count over a channel to the host the
specified number of times."""

class ToHostDMATest(Module):
clk = Clock()
rst = Reset()

@generator
def construct(ports):
count_reached = Wire(Bits(1))
count_valid = Wire(Bits(1))
out_xact = Wire(Bits(1))
cycle_counter = Counter(width)(clk=ports.clk,
rst=ports.rst,
clear=Bits(1)(0),
increment=Bits(1)(1))

write_cntr_incr = ~count_reached & count_valid & out_xact
write_counter = Counter(32)(clk=ports.clk,
rst=ports.rst,
clear=count_reached,
increment=write_cntr_incr)
num_writes = write_counter.out

# Get the MMIO space for commands.
cmd_chan_wire = Wire(Channel(esi.MMIOReadWriteCmdType))
resp_ready_wire = Wire(Bits(1))
cmd, cmd_valid = cmd_chan_wire.unwrap(resp_ready_wire)
mmio_xact = cmd_valid & resp_ready_wire
response_data = Bits(64)(0)
response_chan, response_ready = Channel(response_data.type).wrap(
response_data, cmd_valid)
resp_ready_wire.assign(response_ready)

# write_count is the specified number of times to send the cycle count.
write_count_ce = mmio_xact & cmd.write & (cmd.offset == UInt(32)(0))
write_count = cmd.data.as_uint().reg(clk=ports.clk,
rst=ports.rst,
rst_value=0,
ce=write_count_ce)
count_reached.assign(num_writes == write_count)
count_valid.assign(
ControlReg(clk=ports.clk,
rst=ports.rst,
asserts=[write_count_ce],
resets=[count_reached]))

mmio_rw = esi.MMIO.read_write(appid=AppID("ToHostDMATest"))
mmio_rw_cmd_chan = mmio_rw.unpack(data=response_chan)['cmd']
cmd_chan_wire.assign(mmio_rw_cmd_chan)

# Output channel.
out_channel, out_channel_ready = Channel(UInt(width)).wrap(
cycle_counter.out, count_valid)
out_xact.assign(out_channel_ready & count_valid)
ChannelService.to_host(name=AppID("out"), chan=out_channel)

return ToHostDMATest


class EsiTesterTop(Module):
clk = Clock()
rst = Reset()

@generator
def construct(ports):
PrintfExample(clk=ports.clk, rst=ports.rst)
ReadMem(32)(appid=esi.AppID("readmem", 32), clk=ports.clk, rst=ports.rst)
ReadMem(64)(appid=esi.AppID("readmem", 64), clk=ports.clk, rst=ports.rst)
ReadMem(96)(appid=esi.AppID("readmem", 96), clk=ports.clk, rst=ports.rst)
WriteMem(32)(appid=esi.AppID("writemem", 32), clk=ports.clk, rst=ports.rst)
WriteMem(64)(appid=esi.AppID("writemem", 64), clk=ports.clk, rst=ports.rst)
WriteMem(96)(appid=esi.AppID("writemem", 96), clk=ports.clk, rst=ports.rst)
# PrintfExample(clk=ports.clk, rst=ports.rst)
for width in [32, 64, 96, 128, 256, 384, 504, 512]:
# for width in [504, 512]:
ReadMem(width)(appid=esi.AppID("readmem", width),
clk=ports.clk,
rst=ports.rst)
WriteMem(width)(appid=esi.AppID("writemem", width),
clk=ports.clk,
rst=ports.rst)
ToHostDMATest(width)(appid=esi.AppID("tohostdmatest", width),
clk=ports.clk,
rst=ports.rst)


if __name__ == "__main__":
Expand Down
1 change: 1 addition & 0 deletions frontends/PyCDE/src/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -43,6 +43,7 @@ declare_mlir_python_sources(PyCDESources
pycde/bsp/__init__.py
pycde/bsp/common.py
pycde/bsp/cosim.py
pycde/bsp/dma.py
pycde/bsp/xrt.py
pycde/bsp/Makefile.xrt.mk
pycde/bsp/xrt_package.tcl
Expand Down
3 changes: 0 additions & 3 deletions frontends/PyCDE/src/pycde/bsp/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,8 +15,5 @@ def get_bsp(name: Optional[str] = None):
return CosimBSP_DMA
elif name == "xrt":
return XrtBSP
elif name == "xrt_cosim":
from .xrt import XrtCosimBSP
return XrtCosimBSP
else:
raise ValueError(f"Unknown bsp type: {name}")
8 changes: 6 additions & 2 deletions frontends/PyCDE/src/pycde/bsp/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -208,7 +208,8 @@ def build_read(ports, manifest_loc: int, table: Dict[int, AssignableSignal]):
client_cmd_channels = esi.ChannelDemux(
sel=sel_bits.pad_or_truncate(read_clients_clog2),
input=client_cmd_chan,
num_outs=len(table))
num_outs=len(table),
instance_name="client_cmd_demux")
client_data_channels = []
for (idx, offset) in enumerate(sorted(table.keys())):
bundle_wire = table[offset]
Expand Down Expand Up @@ -512,7 +513,10 @@ def build(ports):
# to complete the transmission.
num_chunks = TaggedWriteGearboxImpl.num_chunks
num_chunks_idx_bitwidth = clog2(num_chunks)
padding_numbits = output_bitwidth - (input_bitwidth % output_bitwidth)
if input_bitwidth % output_bitwidth == 0:
padding_numbits = 0
else:
padding_numbits = output_bitwidth - (input_bitwidth % output_bitwidth)
assert padding_numbits % 8 == 0, "Padding must be a multiple of 8."
client_data_padded = BitsSignal.concat(
[Bits(padding_numbits)(0), client_data])
Expand Down
4 changes: 2 additions & 2 deletions frontends/PyCDE/src/pycde/bsp/cosim.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from os import write
from typing import Dict, Tuple, Type

from ..signals import BundleSignal
Expand All @@ -16,6 +15,7 @@

from .common import (ChannelEngineService, ChannelHostMem, ChannelMMIO,
DummyFromHostEngine, DummyToHostEngine)
from .dma import OneItemBuffersToHost

from ..circt import ir
from ..circt.dialects import esi as raw_esi
Expand Down Expand Up @@ -52,7 +52,7 @@ class ESI_Cosim_UserTopWrapper(Module):
def build(ports):
user_module(clk=ports.clk, rst=ports.rst)
if emulate_dma:
ChannelEngineService(DummyToHostEngine, DummyFromHostEngine)(
ChannelEngineService(OneItemBuffersToHost, DummyFromHostEngine)(
None,
appid=esi.AppID("__channel_engines"),
clk=ports.clk,
Expand Down
90 changes: 90 additions & 0 deletions frontends/PyCDE/src/pycde/bsp/dma.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
# Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
# See https://llvm.org/LICENSE.txt for license information.
# SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception

from __future__ import annotations

from ..common import AppID, Clock, Input, InputChannel, Reset
from ..constructs import Mux, NamedWire, Reg, Wire
from ..module import modparams, generator
from ..types import Bits, Channel, StructType, Type, UInt
from ..support import clog2
from .. import esi


@modparams
def OneItemBuffersToHost(client_type: Type):

class OneItemBuffersToHost(esi.EngineModule):

@property
def TypeName(self):
return "OneItemBuffersToHost"

clk = Clock()
rst = Reset()
input_channel = InputChannel(client_type)

mmio = Input(esi.MMIO.read_write.type)

xfer_data_type = StructType([("valid", Bits(8)),
("client_data", client_type)])
hostmem = Input(esi.HostMem.write_req_bundle_type(xfer_data_type))

@generator
def build(ports):
clk = ports.clk
rst = ports.rst

mmio_resp_chan = Wire(Channel(Bits(64)))

mmio_rw = ports.mmio
mmio_cmd_chan_raw = mmio_rw.unpack(data=mmio_resp_chan)['cmd']
mmio_cmd_chan, mmio_cmd_fork_resp = mmio_cmd_chan_raw.fork(clk, rst)

mmio_resp_data = Wire(Bits(64), "mmio_resp_data")
# Always respond 0.
mmio_resp_data.assign(0)
mmio_resp_chan.assign(
mmio_cmd_fork_resp.transform(lambda _: mmio_resp_data))

_, _, mmio_cmd = mmio_cmd_chan.snoop()
num_sinks = 2
mmio_offset_words = NamedWire((mmio_cmd.offset.as_bits()[3:]).as_uint(),
"mmio_offset_words")
addr_above = mmio_offset_words >= UInt(32)(num_sinks)
addr_is_zero = mmio_offset_words == UInt(32)(0)
force_to_null = NamedWire(addr_above | ~addr_is_zero | mmio_cmd.write,
"force_to_null")
cmd_sink_sel = Mux(force_to_null,
Bits(clog2(num_sinks))(0),
mmio_offset_words.as_bits()[:clog2(num_sinks)])
mmio_data_only_chan = mmio_cmd_chan.transform(lambda m: m.data)
mailbox_names = ["null", "buffer_loc"]
demuxed = esi.ChannelDemux(mmio_data_only_chan, cmd_sink_sel, num_sinks)
mailbox_mod = esi.Mailbox(Bits(64))
mailboxes = [
mailbox_mod(clk=clk,
rst=rst,
input=c,
instance_name="mailbox_" + name)
for name, c in zip(mailbox_names, demuxed)
]
[_, buffer_loc] = mailboxes

next_buffer_loc_chan = buffer_loc.output

hostwr_type = esi.HostMem.write_req_channel_type(
OneItemBuffersToHost.xfer_data_type)
hostwr_joined = Channel.join(next_buffer_loc_chan, ports.input_channel)
hostwr = hostwr_joined.transform(lambda joined: hostwr_type({
"address": joined.a.as_uint(),
"tag": 0,
"data": {
"valid": 1,
"client_data": joined.b
},
}))
ports.hostmem.unpack(req=hostwr)

return OneItemBuffersToHost
Loading
Loading