Apple Metal support in tract #1432

Merged
merged 77 commits into from
Sep 9, 2024
b52a512
Initial work on metal backend for tract using metal flash attention
hubertdelajonquieresonos Jun 7, 2024
28cd74a
add metal in cli
hubertdelajonquieresonos Jun 11, 2024
5749958
Element wise, Binary ops, integration in tract cli
hubertdelajonquieresonos Jun 14, 2024
7741033
Fix clippy
hubertdelajonquieresonos Jun 14, 2024
b378c55
Fix windows CI for tract-metal
hubertdelajonquieresonos Jun 14, 2024
18bf278
Plug runtime test suite and fix issues
hubertdelajonquieresonos Jun 17, 2024
8678338
Fix clippy
hubertdelajonquieresonos Jun 17, 2024
4acaf4b
Some work on broadcast and cast ops
hubertdelajonquieresonos Jun 18, 2024
62a3a1e
WIP on Metal op support
hubertdelajonquieresonos Jun 24, 2024
3aa95a7
Axis OP implementation WIP
hubertdelajonquieresonos Jun 24, 2024
0c394ef
Add tracing to completion of command buffer
hubertdelajonquieresonos Jun 25, 2024
3240167
Fix warning
hubertdelajonquieresonos Jun 25, 2024
13191b4
Try to fix CI
hubertdelajonquieresonos Jun 25, 2024
669d86a
Rewrite MetalTensor and add copy kernel for IntoShape op
hubertdelajonquieresonos Jun 26, 2024
7e70b3d
Extend MatMulOp support, add MultiBroadcastTo and MoveAxis to support…
hubertdelajonquieresonos Jun 28, 2024
fad64b6
Some Reduce ops, fixes and improvements
hubertdelajonquieresonos Jul 9, 2024
e939c64
Some code re-organisation
hubertdelajonquieresonos Jul 9, 2024
d1c663a
Add Softmax metal implementation
hubertdelajonquieresonos Jul 10, 2024
c5df01b
Fix warning
hubertdelajonquieresonos Jul 10, 2024
5d829da
Improve softmax kernel
hubertdelajonquieresonos Jul 10, 2024
3271a56
Add Slice implementation on metal
hubertdelajonquieresonos Jul 10, 2024
12d9096
Remove declutter implementation of MetalSlice
hubertdelajonquieresonos Jul 10, 2024
102f44e
Remove debug statements
hubertdelajonquieresonos Jul 10, 2024
8fa3cc5
Fix ios flash attention metal lib path
hubertdelajonquieresonos Jul 10, 2024
85764cc
Concat metal implementation, code improvement, preload array and nnop…
hubertdelajonquieresonos Jul 11, 2024
91a7219
Clippy fixes and code improvement
hubertdelajonquieresonos Jul 11, 2024
436292b
Improve GEMM ops
hubertdelajonquieresonos Jul 11, 2024
c479fe9
Remove commented code
hubertdelajonquieresonos Jul 11, 2024
ce39063
deactivate check in concat
hubertdelajonquieresonos Jul 11, 2024
5f0abd8
Improve concat
hubertdelajonquieresonos Jul 11, 2024
28f7933
Improve Metal tensor api
hubertdelajonquieresonos Jul 12, 2024
2280cb5
Retain Metal tensor inside metal kernels instead of no plan flush
hubertdelajonquieresonos Jul 12, 2024
f50c58b
Fixes following rebase
hubertdelajonquieresonos Jul 12, 2024
96c064b
Add BasicMatMul implementation on metal for iOS
hubertdelajonquieresonos Jul 16, 2024
835723c
exclude test-metal from linux test
hubertdelajonquieresonos Jul 16, 2024
858692c
Fix for CI
hubertdelajonquieresonos Jul 16, 2024
d6e7a78
Fix for CI
hubertdelajonquieresonos Jul 16, 2024
837fd1a
Fix for CI
hubertdelajonquieresonos Jul 16, 2024
fd92716
Fix for CI
hubertdelajonquieresonos Jul 16, 2024
4388090
Fix for CI
hubertdelajonquieresonos Jul 16, 2024
5c13ce0
Fix CI
hubertdelajonquieresonos Jul 16, 2024
daf8734
Fix clippy
hubertdelajonquieresonos Jul 16, 2024
9d6ae05
Fix CI
hubertdelajonquieresonos Jul 17, 2024
48de038
Rewrite Model with dedicated RMS norm op
hubertdelajonquieresonos Jul 17, 2024
20fb51c
Improve CI for macos
hubertdelajonquieresonos Jul 17, 2024
a40dd7a
Fix clippy
hubertdelajonquieresonos Jul 17, 2024
b1b08e1
Fix clippy
hubertdelajonquieresonos Jul 17, 2024
845da13
Improve test
hubertdelajonquieresonos Jul 17, 2024
eb9d387
Change macos instance for github action
hubertdelajonquieresonos Jul 18, 2024
84a9643
Integrate MPS MatMul and enable it for iOS
hubertdelajonquieresonos Jul 18, 2024
71c28c3
Fixes following rebase
hubertdelajonquieresonos Jul 18, 2024
c172388
Silu op rewrite and metal implementation
hubertdelajonquieresonos Jul 19, 2024
a608f63
Add Metal GPU trace capture in tract cli
hubertdelajonquieresonos Jul 19, 2024
bb21dfe
Add op reduce Min/Max
hubertdelajonquieresonos Jul 19, 2024
53584ff
Fix Min/Max reduce
hubertdelajonquieresonos Jul 19, 2024
64b2f6c
Fix CI
hubertdelajonquieresonos Jul 19, 2024
af2e223
Proptest Silu and RMS norm and fixes
hubertdelajonquieresonos Jul 19, 2024
4127f11
rewrite reduce ops
hubertdelajonquieresonos Jul 19, 2024
be775fd
Some reformatting and clippy
hubertdelajonquieresonos Jul 19, 2024
8c0fad4
Some code re-organisation
hubertdelajonquieresonos Jul 19, 2024
3d7062b
Improve CI
hubertdelajonquieresonos Jul 19, 2024
095543a
Refactor code for matmul implementations
hubertdelajonquieresonos Jul 19, 2024
3e92859
Fix clippy
hubertdelajonquieresonos Jul 19, 2024
8c08e33
Improve code
hubertdelajonquieresonos Jul 19, 2024
c80d1cb
Format metal code
hubertdelajonquieresonos Jul 19, 2024
ad6b3d1
Improve code base
hubertdelajonquieresonos Jul 25, 2024
6638445
Remove command buffer capacity
hubertdelajonquieresonos Jul 25, 2024
1c25905
Fix CI following release
hubertdelajonquieresonos Jul 25, 2024
987031c
Improve MetalSync operator for model output
hubertdelajonquieresonos Jul 25, 2024
1b119c6
Fix version
hubertdelajonquieresonos Jul 25, 2024
95c4f1e
Fix name of variable
hubertdelajonquieresonos Jul 26, 2024
30c7d9e
Add more context to metal capture error
hubertdelajonquieresonos Jul 30, 2024
7703cf7
Improve code
hubertdelajonquieresonos Aug 2, 2024
197bc4b
Format code
hubertdelajonquieresonos Aug 2, 2024
99abd5a
Add new gelu rewrite rule
hubertdelajonquieresonos Aug 2, 2024
7f429f0
Integrate MPSVectoreMatrixMultiplication
hubertdelajonquieresonos Sep 6, 2024
ac648eb
Fixes following rebase
hubertdelajonquieresonos Sep 6, 2024
30 changes: 21 additions & 9 deletions .github/workflows/beta-and-nightly.yml
@@ -52,32 +52,37 @@ jobs:
strategy:
matrix:
rust: [ 1.75.0, beta ]
os: [ ubuntu-latest, macOS-latest ]
fail-fast: false

runs-on: ubuntu-latest

steps:
- uses: actions/checkout@v3
- name: Rustup update
run: rustup update
- name: LLVM setup
run: sudo apt-get install -y llvm
- name: cargo check, -D warnings
if: runner.os == 'Linux'
env:
RUSTFLAGS: "-D warnings"
RUST_VERSION: ${{matrix.rust}}
run: |
rustc --version
cargo check --exclude tract-metal --exclude test-metal --workspace
- name: cargo check, -D warnings
if: runner.os == 'macOS'
env:
RUSTFLAGS: "-D warnings"
RUST_VERSION: ${{matrix.rust}}
run: |
rustc --version
cargo check
cargo check

cargo-clippy:
strategy:
fail-fast: false
matrix:
os: [ ubuntu-latest, macOS-latest ]
rust: [ stable, beta, nightly ]

runs-on: ubuntu-latest

fail-fast: false
steps:
- uses: actions/checkout@v3
- name: Rustup update, install clippy
@@ -87,7 +92,14 @@ jobs:
env:
RUST_VERSION: ${{matrix.rust}}
- name: Run cargo-clippy
run: cargo clippy
if: runner.os == 'Linux'
run: cargo clippy --exclude tract-metal --exclude test-metal --workspace
env:
RUSTFLAGS: "-D warnings --force-warn unknown_lints"
RUST_VERSION: ${{matrix.rust}}
- name: Run cargo-clippy
if: runner.os == 'macOS'
run: cargo clippy
env:
RUSTFLAGS: "-D warnings --force-warn unknown_lints"
RUST_VERSION: ${{matrix.rust}}
2 changes: 1 addition & 1 deletion .github/workflows/full.yml
@@ -137,7 +137,7 @@ jobs:
- name: Check all targets
run: |
ROOT=$(pwd) ./.travis/ci-system-setup.sh
cargo check --all-targets --workspace
cargo check --all-targets --exclude tract-metal --exclude test-metal --workspace

cli-tests:
runs-on: ubuntu-latest
2 changes: 1 addition & 1 deletion .github/workflows/windows.yml
@@ -52,7 +52,7 @@ jobs:
- name: debug bin
run: dir "C:\\Program Files\\LLVM\\bin"
- name: top level cargo check
run: cargo check --workspace --exclude test-blas
run: cargo check --workspace --exclude test-blas --exclude tract-metal --exclude test-metal
env:
LIBCLANG_PATH: "C:\\Program Files\\LLVM\\bin"
- name: data / linalg / core / nnef / onnx / onnx-opl
9 changes: 9 additions & 0 deletions .travis/test-published-crates.sh
@@ -31,6 +31,15 @@ do
cargo -q test $CARGO_EXTRA -q -p tract-$c
done

if [ `uname` = "Darwin" ]
then
echo
echo "$WHITE ### metal ### $NC"
echo
cargo -q test $CARGO_EXTRA -q -p tract-metal
fi


# doc test are not finding libtensorflow.so
if ! cargo -q test $CARGO_EXTRA -q -p tract-tensorflow --lib $ALL_FEATURES
then
7 changes: 6 additions & 1 deletion .travis/test-rt.sh
@@ -26,8 +26,13 @@ set +x
cd $ROOT
for c in test-rt/test*
do
if [ "$c" != "test-rt/test-tflite" ]
if [ "$c" = "test-rt/test-tflite" ]
then
echo "$WHITE ### $c ### IGNORED $NC"
elif [ "$c" = "test-rt/test-metal" -a `uname` != "Darwin" ]
then
echo "$WHITE ### $c ### IGNORED $NC"
else
echo
echo "$WHITE ### $c ### $NC"
echo
8 changes: 7 additions & 1 deletion Cargo.toml
@@ -14,6 +14,7 @@ members = [
"onnx",
"libcli",
"cli",
"metal",
"extra",

"tflite",
@@ -47,10 +48,12 @@ members = [
"test-rt/suite-onnx",
"test-rt/test-f16",
"test-rt/test-blas",
"test-rt/test-metal",
"test-rt/test-unit-core",
"test-rt/test-onnx-core",
"test-rt/test-nnef-cycle",
"test-rt/test-tflite",
"test-rt/test-tflite",
"metal"
]

[workspace.dependencies]
@@ -63,6 +66,7 @@ anyhow = "1.0.43"
approx = "0.5"
bit-set= "0.5.2"
blis-src = { version = "0.2", features = ["static", "pthreads"] }
block = "0.1.6"
boow = "0.1.3"
box_drawing = "0.1.2"
byteorder = "1.4.3"
@@ -80,6 +84,7 @@ dyn-hash = "0.2"
env_logger = "0.10"
flatbuffers = "23.1.21"
flate2 = "1.0.20"
foreign-types = "0.5"
fs-err = "2"
fs2 = "0.4.3"
getrandom = "0.2"
@@ -93,6 +98,7 @@ liquid-core = "0.26"
log = "0.4.14"
maplit = "1.0.2"
memmap2 = "0.9"
metal = { version = "0.27.0", features = ["mps"] }
ndarray = "0.15.3"
ndarray-npy = { version = "0.8.0", features = [ "compressed_npz" ] }
nom = "7.0.0"
2 changes: 2 additions & 0 deletions api/rs/src/lib.rs
@@ -15,6 +15,7 @@ use tract_nnef::prelude::{
use tract_onnx::prelude::InferenceModelExt;
use tract_onnx_opl::WithOnnx;
use tract_pulse::model::{PulsedModel, PulsedModelExt};
use tract_pulse::internal::PlanOptions;
use tract_pulse::WithPulse;

use tract_api::*;
@@ -273,6 +274,7 @@ impl ModelInterface for Model {
&self.0,
&BenchLimits::default(),
&mut annotations,
&PlanOptions::default(),
&inputs,
None,
true,
4 changes: 4 additions & 0 deletions cli/Cargo.toml
@@ -47,6 +47,10 @@ tract-onnx = { optional = true, version = "=0.21.7-pre", path = "../onnx" }
tract-tensorflow = { optional = true, version = "=0.21.7-pre", path = "../tensorflow" }
tract-tflite = { optional = true, version = "=0.21.7-pre", path = "../tflite" }

[target.'cfg(any(target_os = "macos", target_os = "ios"))'.dependencies]
tract-metal = { version = "=0.21.7-pre", path = "../metal" }


[features]
default = ["onnx", "tf", "pulse", "pulse-opl", "tflite", "extra"]
apple-amx-ios = [ "tract-linalg/apple-amx-ios" ]
4 changes: 4 additions & 0 deletions cli/src/dump.rs
@@ -1,3 +1,4 @@
use crate::plan_options::plan_options_from_subcommand;
use crate::params::SomeGraphDef;
use crate::tensor::run_params_from_subcommand;
use crate::Parameters;
@@ -113,6 +114,8 @@ pub fn handle(
tract_libcli::profile::extract_costs(&mut annotations, model, &run_params.symbols)?;
}
if options.profile {
let run_params = run_params_from_subcommand(params, sub_matches)?;
let plan_options = plan_options_from_subcommand(sub_matches)?;
let model = params
.tract_model
.downcast_ref::<TypedModel>()
@@ -122,6 +125,7 @@ pub fn handle(
model,
bench_limits,
&mut annotations,
&plan_options,
&inputs[0],
None,
options.folded,
33 changes: 32 additions & 1 deletion cli/src/main.rs
@@ -29,6 +29,7 @@ mod params;
mod run;
#[cfg(feature = "pulse")]
mod stream_check;
mod plan_options;
mod tensor;
mod utils;

@@ -131,6 +132,8 @@ fn main() -> TractResult<()> {

.arg(Arg::new("f32-to-f16").long("f32-to-f16").alias("half-floats").long_help("Convert the decluttered network from f32 to f16"))
.arg(arg!(--"f16-to-f32" "Convert the decluttered network from f16 to f32"))
.arg(arg!(--"metal" "Convert metal compatible operator in the decluttered network. Only available on MacOS and iOS"))
.arg(Arg::new("metal-gpu-trace").long("metal-gpu-trace").takes_value(true).help("Capture Metal GPU trace at given path. Only available on MacOS and iOS"))
.arg(Arg::new("transform").short('t').long("transform").multiple_occurrences(true).takes_value(true).help("Apply a built-in transformation to the model"))
.arg(Arg::new("set").long("set").multiple_occurrences(true).takes_value(true)
.long_help("Set a symbol to a concrete value after decluttering"))
@@ -282,7 +285,31 @@ fn main() -> TractResult<()> {
env_logger::Builder::from_env(env).format_timestamp_nanos().init();
info_usage("init", probe.as_ref());

if let Err(e) = handle(matches, probe.as_ref()) {
let res = if matches.is_present("metal-gpu-trace") {
#[cfg(any(target_os = "macos", target_os = "ios"))]
{
let gpu_trace_path = std::path::Path::new(matches.value_of("metal-gpu-trace").unwrap()).to_path_buf();
ensure!(gpu_trace_path.is_absolute(), "Metal GPU trace file has to be absolute");
ensure!(!gpu_trace_path.exists(), format!("Given Metal GPU trace file {:?} already exists.", gpu_trace_path));
log::info!("Capturing Metal GPU trace at : {:?}", gpu_trace_path);
std::env::set_var("METAL_CAPTURE_ENABLED", "1");
std::env::set_var("METAL_DEVICE_WRAPPER_TYPE", "1");
let probe_ref = probe.as_ref();
tract_metal::METAL_CONTEXT.with_borrow(move |context| {
context.capture_trace(gpu_trace_path, move |_ctxt| {
handle(matches, probe_ref)
})
})
}
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
{
bail!("`--metal-gpu-trace` present but it is only available on MacOS and iOS")
}
} else {
handle(matches, probe.as_ref())
};

if let Err(e) = res {
error!("{:?}", e);
std::process::exit(1);
}
@@ -445,6 +472,10 @@ fn run_options(command: clap::Command) -> clap::Command {
"Path to a directory containing input tensors in NNEF format (.dat files). This sets tensor values.",
),
)
.arg(Arg::new("skip-order-opt-ram")
.long("skip-order-opt-ram")
.help("Plan node evaluation order without RAM optimisation"),
)
.arg(
Arg::new("allow-random-input")
.short('R')
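Note on the `--metal-gpu-trace` handling in `cli/src/main.rs` above: the two `ensure!` checks require the trace path to be absolute and to not exist yet. A minimal standalone sketch of that validation follows. The helper name `validate_trace_path` and the `String` error type are assumptions made for illustration; they are not part of tract, which uses `ensure!` and `TractResult`.

```rust
use std::path::{Path, PathBuf};

/// Hypothetical helper (not in tract): mirrors the two `ensure!` checks
/// applied to the `--metal-gpu-trace` argument.
fn validate_trace_path(raw: &str) -> Result<PathBuf, String> {
    let path = Path::new(raw).to_path_buf();
    // The trace path must be absolute.
    if !path.is_absolute() {
        return Err(format!("Metal GPU trace file has to be absolute, got {path:?}"));
    }
    // An existing file or directory at that path is rejected, not overwritten.
    if path.exists() {
        return Err(format!("Given Metal GPU trace file {path:?} already exists."));
    }
    Ok(path)
}
```

In this sketch a relative path and an already existing path both yield `Err`, while a fresh absolute path is returned as a `PathBuf`.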
18 changes: 17 additions & 1 deletion cli/src/params.rs
@@ -10,7 +10,7 @@ use tract_core::ops::konst::Const;
use tract_itertools::Itertools;
use tract_libcli::profile::BenchLimits;
use tract_nnef::tensors::read_tensor;

use tract_core::transform::ModelTransform;
use tract_core::internal::*;
use tract_core::model::TypedModel;
use tract_hir::internal::*;
@@ -741,6 +741,22 @@ impl Parameters {
tract_core::floats::FloatPrecisionTranslator::<f16, f32>::default().translate_model(&m)
});
}
{
if matches.is_present("metal") {
#[cfg(any(target_os = "macos", target_os = "ios"))]
{
stage!("metal", typed_model -> typed_model, |m:TypedModel| {
tract_metal::transform::MetalTransform
.transform_into(&m)
});
}
#[cfg(not(any(target_os = "macos", target_os = "ios")))]
{
bail!("`--metal` present but it is only available on MacOS and iOS")
}
}
}

if let Some(transform) = matches.values_of("transform") {
for transform in transform {
stage!(transform, typed_model -> typed_model, |m:TypedModel| {
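Note on the `--metal` stage in `cli/src/params.rs` above: the flag is accepted on every platform, but the transform is only compiled in on macOS and iOS, and other targets bail out at run time. The sketch below shows the same `cfg` gating pattern in isolation. `apply_metal_stage` and the `String` stand-in for the model are hypothetical; the real code runs `MetalTransform` inside the `stage!` macro.

```rust
/// Hypothetical stand-in for the `--metal` stage. A `String` plays the role
/// of the model; nothing here calls into tract or Metal.
fn apply_metal_stage(model: String, metal_requested: bool) -> Result<String, String> {
    if !metal_requested {
        // Flag absent: the model passes through unchanged.
        return Ok(model);
    }
    #[cfg(any(target_os = "macos", target_os = "ios"))]
    {
        // On Apple targets the real code would apply the Metal transform here.
        Ok(format!("{model}+metal"))
    }
    #[cfg(not(any(target_os = "macos", target_os = "ios")))]
    {
        // Elsewhere only this block is compiled, so the request fails at run time.
        Err("`--metal` present but it is only available on MacOS and iOS".to_string())
    }
}
```

Only one of the two `cfg` blocks survives compilation on a given target, so the surviving block becomes the function's tail expression.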
12 changes: 12 additions & 0 deletions cli/src/plan_options.rs
@@ -0,0 +1,12 @@
use tract_core::internal::*;


pub fn plan_options_from_subcommand(
sub_matches: &clap::ArgMatches,
) -> TractResult<PlanOptions> {
let skip_order_opt_ram: bool = sub_matches.is_present("skip-order-opt-ram");

let mut options = PlanOptions::default();
options.skip_order_opt_ram = skip_order_opt_ram;
Ok(options)
}
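Note on `cli/src/plan_options.rs` above: the function starts from `PlanOptions::default()` and overrides one field from the `--skip-order-opt-ram` flag. The sketch below shows an equivalent construction using struct update syntax. This `PlanOptions` is a simplified stand-in that models only the field visible in this diff, and `plan_options_from_flag` is a hypothetical helper that takes the parsed flag directly instead of `clap::ArgMatches`.

```rust
/// Simplified stand-in for tract's `PlanOptions`; only the field that
/// appears in this diff is modelled.
#[derive(Debug, Default, Clone, PartialEq)]
struct PlanOptions {
    skip_order_opt_ram: bool,
}

/// Hypothetical helper: builds the options from an already parsed flag.
fn plan_options_from_flag(skip_order_opt_ram: bool) -> PlanOptions {
    // Struct update syntax: any field not named here keeps its default value.
    PlanOptions { skip_order_opt_ram, ..PlanOptions::default() }
}
```

With a single field the update syntax is redundant; it pays off once the struct carries more options that should stay at their defaults.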
5 changes: 3 additions & 2 deletions cli/src/run.rs
@@ -142,9 +142,10 @@ pub fn handle(
fn run_regular(
tract: &dyn Model,
run_params: &RunParams,
_matches: &clap::ArgMatches,
matches: &clap::ArgMatches,
sub_matches: &clap::ArgMatches,
) -> TractResult<TVec<Vec<TValue>>> {
let plan_options = crate::plan_options::plan_options_from_subcommand(matches)?;
let steps = sub_matches.is_present("steps");
let check_f16_overflow = sub_matches.is_present("check-f16-overflow");
let assert_sane_floats = sub_matches.is_present("assert-sane-floats");
@@ -155,7 +156,7 @@ fn run_regular(
None
};
dispatch_model!(tract, |m| {
let plan = SimplePlan::new(m)?;
let plan = SimplePlan::new_with_options(m, &plan_options)?;
let mut state = SimpleState::new(plan)?;
let inputs = tract_libcli::tensor::retrieve_or_make_inputs(tract, run_params)?;
let mut results = tvec!(vec!(); state.model().outputs.len());
37 changes: 21 additions & 16 deletions core/src/model/fact.rs
@@ -294,14 +294,17 @@ impl TypedFact {
}

pub fn without_value(&self) -> Self {
Self::dt_shape(self.datum_type, self.shape.clone())
TypedFact {
datum_type: self.datum_type,
shape: self.shape.clone(),
konst: None,
uniform: None,
opaque_fact: self.opaque_fact.clone(),
}
}

pub fn with_opaque_metadata<O: Into<Box<dyn OpaqueFact>>>(
mut self,
opaque_metadata: O,
) -> Self {
self.opaque_fact = Some(opaque_metadata.into());
pub fn with_opaque_fact<O: Into<Box<dyn OpaqueFact>>>(mut self, opaque_fact: O) -> Self {
self.opaque_fact = Some(opaque_fact.into());
self
}
}
@@ -395,16 +398,18 @@ impl<'a> From<&'a Arc<Tensor>> for TypedFact {

impl fmt::Debug for TypedFact {
fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
if let Some(k) = &self.konst {
write!(fmt, "{k:?}")?
} else if self.rank() > 0 {
write!(fmt, "{:?},{:?}", self.shape, self.datum_type)?
} else {
write!(fmt, "{:?}", self.datum_type)?
};
if let Some(of) = &self.opaque_fact {
write!(fmt, " {:?}", of)?;
}
match (self.konst.as_ref(), self.opaque_fact.as_ref()) {
(Some(ref k), None) => write!(fmt, "{k:?}"),
(Some(ref k), Some(meta)) => write!(fmt, "{meta:?} {k:?}"),
(None, None) if self.rank() > 0 => {
write!(fmt, "{:?},{:?}", self.shape, self.datum_type)
}
(None, Some(ref meta)) if self.rank() > 0 => {
write!(fmt, "{:?},{:?},{:?}", self.shape, self.datum_type, meta)
}
(None, Some(ref meta)) => write!(fmt, "{:?}, {:?}", self.datum_type, meta),
(None, None) => write!(fmt, "{:?}", self.datum_type),
}?;
Ok(())
}
}
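Note on `core/src/model/fact.rs` above: `without_value` now keeps the opaque fact while dropping the constant, and the `Debug` impl formats each combination of constant and opaque fact in its own match arm. The sketch below reproduces both behaviours on a simplified stand-in. The `Fact` struct is an assumption for illustration: plain strings replace tract's datum type, tensor and `OpaqueFact` types, and the `uniform` field is left out.

```rust
use std::fmt;

/// Simplified stand-in for `TypedFact` (strings replace tract's real types).
#[derive(Clone)]
struct Fact {
    datum_type: &'static str,
    shape: Vec<usize>,
    konst: Option<String>,
    opaque_fact: Option<String>,
}

impl Fact {
    fn rank(&self) -> usize {
        self.shape.len()
    }

    /// Builder-style setter, shaped like the renamed `with_opaque_fact`.
    fn with_opaque_fact(mut self, opaque_fact: impl Into<String>) -> Self {
        self.opaque_fact = Some(opaque_fact.into());
        self
    }

    /// Drops the constant value but keeps the opaque fact.
    fn without_value(&self) -> Self {
        Fact {
            datum_type: self.datum_type,
            shape: self.shape.clone(),
            konst: None,
            opaque_fact: self.opaque_fact.clone(),
        }
    }
}

impl fmt::Debug for Fact {
    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
        // One arm per (constant, opaque fact) combination, as in the diff.
        match (self.konst.as_ref(), self.opaque_fact.as_ref()) {
            (Some(k), None) => write!(fmt, "{k}"),
            (Some(k), Some(meta)) => write!(fmt, "{meta} {k}"),
            (None, None) if self.rank() > 0 => {
                write!(fmt, "{:?},{}", self.shape, self.datum_type)
            }
            (None, Some(meta)) if self.rank() > 0 => {
                write!(fmt, "{:?},{},{}", self.shape, self.datum_type, meta)
            }
            (None, Some(meta)) => write!(fmt, "{}, {}", self.datum_type, meta),
            (None, None) => write!(fmt, "{}", self.datum_type),
        }
    }
}
```

In this sketch a constant fact tagged with an opaque fact prints the opaque fact first, and the same fact after `without_value` prints shape, datum type and opaque fact separated by commas.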