Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pip prod(deps): bump marimo from 0.9.34 to 0.10.0 #3634

Merged
merged 8 commits into from
Jan 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
60 changes: 30 additions & 30 deletions docs/marimo/linalg_snitch.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,17 @@
import marimo

__generated_with = "0.8.20"
__generated_with = "0.10.0"
app = marimo.App(width="medium")


@app.cell
def __():
def _():
import marimo as mo
return (mo,)


@app.cell
def __(mo):
def _(mo):
mo.md(
"""
\
Expand All @@ -26,7 +26,7 @@ def __(mo):


@app.cell
def __():
def _():
# Import all the necessary functionality from xDSL for this notebook
# If you see an error about xdsl not being defined run this cell manually

Expand Down Expand Up @@ -118,7 +118,7 @@ def __():


@app.cell
def __(
def _(
AffineMap,
AffineMapAttr,
Block,
Expand Down Expand Up @@ -201,7 +201,7 @@ def __(


@app.cell
def __(mo):
def _(mo):
min_val = 1
max_val = 10
m = mo.ui.slider(min_val, max_val, value=2, label="M")
Expand All @@ -211,7 +211,7 @@ def __(mo):


@app.cell
def __(k, m, mo, n):
def _(k, m, mo, n):
mo.md(
f"""
We can parametrize the shapes of the matrices operated on:
Expand All @@ -227,7 +227,7 @@ def __(k, m, mo, n):


@app.cell
def __(k, m, mo, n):
def _(k, m, mo, n):
a_shape = (m.value, k.value)
b_shape = (k.value, n.value)
c_shape = (m.value, n.value)
Expand All @@ -244,13 +244,13 @@ def __(k, m, mo, n):


@app.cell
def __(mo):
def _(mo):
mo.md("""### Compiling to RISC-V""")
return


@app.cell
def __(MLContext, get_all_dialects):
def _(MLContext, get_all_dialects):
ctx = MLContext()

for dialect_name, dialect_factory in get_all_dialects().items():
Expand All @@ -259,13 +259,13 @@ def __(MLContext, get_all_dialects):


@app.cell
def __(mo):
def _(mo):
mo.md("""We can take this representation, and lower to RISC-V-specific dialects:""")
return


@app.cell
def __(
def _(
PipelinePass,
convert_arith_to_riscv,
convert_func_to_riscv_func,
Expand Down Expand Up @@ -296,7 +296,7 @@ def __(


@app.cell
def __(mo):
def _(mo):
mo.md(
"""
#### Register allocation
Expand All @@ -308,7 +308,7 @@ def __(mo):


@app.cell
def __(
def _(
CanonicalizePass,
PipelinePass,
RISCVRegisterAllocation,
Expand All @@ -331,7 +331,7 @@ def __(


@app.cell
def __(
def _(
CanonicalizePass,
ConvertRiscvScfToRiscvCfPass,
PipelinePass,
Expand All @@ -354,13 +354,13 @@ def __(


@app.cell
def __(mo):
def _(mo):
mo.md("""This representation of the program in xDSL corresponds ~1:1 to RISC-V assembly, and we can use a helper function to print that out.""")
return


@app.cell
def __(asm_html, mo, riscv_asm_module, riscv_code):
def _(asm_html, mo, riscv_asm_module, riscv_code):
riscv_asm = riscv_code(riscv_asm_module)

mo.md(f"""\
Expand All @@ -373,7 +373,7 @@ def __(asm_html, mo, riscv_asm_module, riscv_code):


@app.cell
def __(mo):
def _(mo):
mo.md(
"""
### Compiling to Snitch
Expand All @@ -385,7 +385,7 @@ def __(mo):


@app.cell
def __(
def _(
PipelinePass,
arith_add_fastmath,
convert_linalg_to_memref_stream,
Expand Down Expand Up @@ -420,13 +420,13 @@ def __(


@app.cell
def __(mo):
def _(mo):
mo.md("""We can then lower this to assembly that includes assembly instructions from the Snitch-extended ISA:""")
return


@app.cell
def __(pipeline_accordion, snitch_stream_module):
def _(pipeline_accordion, snitch_stream_module):
from xdsl.transforms.test_lower_linalg_to_snitch import LOWER_SNITCH_STREAM_TO_ASM_PASSES

snitch_asm_module, snitch_asm_accordion = pipeline_accordion(
Expand All @@ -442,7 +442,7 @@ def __(pipeline_accordion, snitch_stream_module):


@app.cell
def __(k, m, mo, n):
def _(k, m, mo, n):
mo.md(
f"""
We can see how changing our input sizes affects the assembly produced:
Expand All @@ -458,7 +458,7 @@ def __(k, m, mo, n):


@app.cell
def __(asm_html, mo, riscv_code, snitch_asm_module):
def _(asm_html, mo, riscv_code, snitch_asm_module):
snitch_asm = riscv_code(snitch_asm_module)

mo.md(f"""\
Expand All @@ -471,7 +471,7 @@ def __(asm_html, mo, riscv_code, snitch_asm_module):


@app.cell
def __(mo):
def _(mo):
mo.md(
"""
### Interpreting the assembly using xDSL
Expand All @@ -483,7 +483,7 @@ def __(mo):


@app.cell
def __(TypedPtr, a_shape, b_shape, c_shape, ctx, mo, riscv_module):
def _(TypedPtr, a_shape, b_shape, c_shape, ctx, mo, riscv_module):
from math import prod

from xdsl.interpreter import Interpreter, OpCounter
Expand Down Expand Up @@ -532,7 +532,7 @@ def __(TypedPtr, a_shape, b_shape, c_shape, ctx, mo, riscv_module):


@app.cell
def __(
def _(
Interpreter,
OpCounter,
ShapedArray,
Expand Down Expand Up @@ -579,7 +579,7 @@ def __(


@app.cell
def __(k, m, mo, n, riscv_op_counter, snitch_op_counter):
def _(k, m, mo, n, riscv_op_counter, snitch_op_counter):
rv_dict = dict(riscv_op_counter.ops)
sn_dict = dict(snitch_op_counter.ops)

Expand Down Expand Up @@ -666,7 +666,7 @@ def format_row(key: str, *values: str):


@app.cell
def __(ModuleOp, mo):
def _(ModuleOp, mo):
import html as htmllib

def module_html(module: ModuleOp) -> str:
Expand All @@ -684,13 +684,13 @@ def asm_html(asm: str) -> str:


@app.cell
def __():
def _():
from collections import Counter
return (Counter,)


@app.cell
def __(Counter, ModuleOp, ModulePass, PipelinePass, ctx, mo, module_html):
def _(Counter, ModuleOp, ModulePass, PipelinePass, ctx, mo, module_html):
def spec_str(p: ModulePass) -> str:
if isinstance(p, PipelinePass):
return ",".join(str(c.pipeline_pass_spec()) for c in p.passes)
Expand Down
Loading
Loading