Skip to content

Commit

Permalink
Clean up CLI check-mem-arena code
Browse files Browse the repository at this point in the history
  • Loading branch information
aaronjonessonos committed Mar 7, 2025
1 parent 39a7797 commit ff5d27f
Showing 1 changed file with 74 additions and 49 deletions.
123 changes: 74 additions & 49 deletions cli/src/check_mem_arena.rs
Original file line number Diff line number Diff line change
@@ -1,16 +1,81 @@

use tract_hir::internal::*;
use std::collections::HashMap;
use serde::{Deserialize, Serialize};
use tract_metal::memory::MetalMemSchema;

pub fn verify_size_and_usage(model: &TypedModel, options: &PlanOptions, path: impl AsRef<std::path::Path>) -> TractResult<()> {
/// Peak memory figures obtained by evaluating a `MetalMemSchema` against one
/// concrete assignment of symbol values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MemArenaUsage {
    /// Value returned by `MetalMemSchema::eval_peak_memory_size`
    /// (presumably a byte count -- TODO confirm the unit).
    peak_memory_size: i64,
    /// Value returned by `MetalMemSchema::eval_usage`.
    peak_memory_usage: f32,
}

impl MemArenaUsage {
    /// Evaluates `schema` for the given `symbol_values` and records the
    /// resulting peak memory size and usage.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `eval_peak_memory_size` or
    /// `eval_usage`.
    pub fn eval_from_schema(
        schema: &MetalMemSchema,
        symbol_values: &SymbolValues,
    ) -> TractResult<Self> {
        // `symbol_values` is already a reference: forward it directly rather
        // than re-borrowing it into a `&&SymbolValues`.
        Ok(Self {
            peak_memory_size: schema.eval_peak_memory_size(symbol_values)?,
            peak_memory_usage: schema.eval_usage(symbol_values)?,
        })
    }
}

/// Memory arena metrics gathered from a `MetalMemSchema`.
///
/// `pp` and `tg` are ordered maps so that the serialized form lists entries in
/// ascending key order and is identical from one run to the next; a `HashMap`
/// would emit them in an arbitrary, per-process order.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
struct MemArenaMetrics {
    /// `schema.memory_size()` rendered with `to_string()`.
    memory_size: String,
    /// Each entry of `schema.size_by_partition()` rendered with `to_string()`.
    size_by_partition: Vec<String>,
    /// Prompt-processing sweep: usage keyed by the value bound to `S`, with `P` bound to 0.
    pp: std::collections::BTreeMap<i64, MemArenaUsage>,
    /// Token-generation sweep: usage keyed by the value bound to `P`, with `S` bound to 1.
    tg: std::collections::BTreeMap<i64, MemArenaUsage>,
}

impl MemArenaMetrics {
    /// Builds the metrics by evaluating `schema` over two sweeps:
    ///
    /// * prompt processing: `S` in `STEP_TOKENS..=MAX_PROMPT_TOKENS`, `P = 0`;
    /// * token generation: `P` in `0..=MAX_GEN_TOKENS`, `S = 1`;
    ///
    /// both advancing by `STEP_TOKENS`.
    ///
    /// # Errors
    ///
    /// Returns the first error produced while evaluating the schema.
    pub fn from_schema(schema: &MetalMemSchema) -> TractResult<Self> {
        log::info!("Analyzing Metal memory schema utilization...");
        const MAX_GEN_TOKENS: i64 = 2048;
        const MAX_PROMPT_TOKENS: i64 = 2048;
        const STEP_TOKENS: i64 = 16;

        let memory_size = schema.memory_size().to_string();
        let size_by_partition: Vec<String> =
            schema.size_by_partition().iter().map(|it| it.to_string()).collect();

        // NOTE(review): these symbols are created in a fresh `SymbolScope`.
        // Confirm they resolve to the same `S` / `P` symbols the schema was
        // built with; if symbols are scope-specific the bindings below may
        // not be picked up during evaluation.
        let symbol_scope = SymbolScope::default();
        let sequence_length = symbol_scope.sym("S");
        let past_sequence_length = symbol_scope.sym("P");

        let mut pp = std::collections::BTreeMap::new();
        for s in (STEP_TOKENS..=MAX_PROMPT_TOKENS).step_by(STEP_TOKENS as usize) {
            log::info!("Prompt processing: P: 0, S: {}", s);
            let symbol_values = SymbolValues::default()
                .with(&sequence_length, s)
                .with(&past_sequence_length, 0);
            // `schema` is already a reference; no extra borrow needed.
            pp.insert(s, MemArenaUsage::eval_from_schema(schema, &symbol_values)?);
        }

        let mut tg = std::collections::BTreeMap::new();
        for p in (0..=MAX_GEN_TOKENS).step_by(STEP_TOKENS as usize) {
            log::info!("Token generation: P: {}, S: 1", p);
            let symbol_values = SymbolValues::default()
                .with(&sequence_length, 1)
                .with(&past_sequence_length, p);
            tg.insert(p, MemArenaUsage::eval_from_schema(schema, &symbol_values)?);
        }

        Ok(Self { memory_size, size_by_partition, pp, tg })
    }
}

pub fn verify_size_and_usage(
model: &TypedModel,
options: &PlanOptions,
path: impl AsRef<std::path::Path>
) -> TractResult<()> {
log::info!("Analyzing Metal memory schema utilization...");
const SCHEMA_HINT_S: i64 = 1024;
const SCHEMA_HINT_P: i64 = 0;

const MAX_GEN_TOKENS: i64 = 2048;
const MAX_PROMPT_TOKENS: i64 = 2048;

const STEP_TOKENS: i64 = 16;

let plan = SimplePlan::new_with_options(model, options)?;
let order = plan.order_without_consts();
let mut symbol_values = SymbolValues::default();
Expand All @@ -32,49 +97,9 @@ pub fn verify_size_and_usage(model: &TypedModel, options: &PlanOptions, path: im
&symbol_values,
)?;

let size_by_partition: Vec<String> = schema
.size_by_partition()
.iter()
.map(|it| format!("\"{}\"", it))
.collect();
let mut result: String = format!(
"{{\n\"memory_size\": \"{}\",\n\"size_by_partition\": [{}],\n\"pp\": {{",
schema.memory_size(),
size_by_partition.join(",\n"),
);
for s in (STEP_TOKENS..MAX_PROMPT_TOKENS+1).step_by(STEP_TOKENS as usize) {
log::info!("Prompt processing: P: 0, S: {}", s);
symbol_values.set(&sequence_length, s);
symbol_values.set(&past_sequence_length, 0);
if s > STEP_TOKENS {
result += ",";
}
result += &format!(
"\n\"{}\": {{\n\"peak_memory_size\": {},\n\"peak_memory_usage\": {}\n}}",
s,
schema.eval_peak_memory_size(&symbol_values)?,
schema.eval_usage(&symbol_values)?,
);
}
result += "\n},\n\"tg\": {";
for p in (0..MAX_GEN_TOKENS+1).step_by(STEP_TOKENS as usize) {
if p % STEP_TOKENS == 0 {
log::info!("Token generation: P: {}, S: 1", p);
}
symbol_values.set(&sequence_length, 1);
symbol_values.set(&past_sequence_length, p);
if p > 0 {
result += ",";
}
result += &format!(
"\n\"{}\": {{\n\"peak_memory_size\": {},\n\"peak_memory_usage\": {}\n}}",
p,
schema.eval_peak_memory_size(&symbol_values)?,
schema.eval_usage(&symbol_values)?,
);
}
result += "\n}\n}\n";
std::fs::write(path.as_ref(), result).expect("Unable to write file");
let metrics = MemArenaMetrics::from_schema(&schema)?;

std::fs::write(path.as_ref(), serde_json::to_string(&metrics)?).expect("Unable to write file");

Ok(())
}

0 comments on commit ff5d27f

Please sign in to comment.