Skip to content

Commit a51aea9

Browse files
committed with message: "Style"
1 parent ba3d736 commit a51aea9

File tree

2 files changed

+11
-8
lines changed

2 files changed

+11
-8
lines changed

examples/pytorch/continuous_batching.py

Lines changed: 8 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -55,15 +55,18 @@ def setup_metrics():
5555
from opentelemetry.sdk.trace import TracerProvider
5656
from opentelemetry.sdk.trace.export import BatchSpanProcessor
5757

58-
5958
resource = Resource.create({"service.name": "transformers"})
6059
metrics_exporter = PeriodicExportingMetricReader(
61-
OTLPMetricExporter(endpoint="http://localhost:9090/api/v1/otlp/v1/metrics"), # Uses OTEL_EXPORTER_OTLP_METRICS_ENDPOINT env var
62-
export_interval_millis=1000
60+
OTLPMetricExporter(
61+
endpoint="http://localhost:9090/api/v1/otlp/v1/metrics"
62+
), # Uses OTEL_EXPORTER_OTLP_METRICS_ENDPOINT env var
63+
export_interval_millis=1000,
6364
)
6465
meter_provider = MeterProvider(resource=resource, metric_readers=[metrics_exporter])
6566
metrics.set_meter_provider(meter_provider)
66-
trace_exporter = OTLPSpanExporter(endpoint="http://localhost:4318/v1/traces") # Uses OTEL_EXPORTER_OTLP_TRACES_ENDPOINT env var
67+
trace_exporter = OTLPSpanExporter(
68+
endpoint="http://localhost:4318/v1/traces"
69+
) # Uses OTEL_EXPORTER_OTLP_TRACES_ENDPOINT env var
6770
tracer_provider = TracerProvider(resource=resource)
6871
tracer_provider.add_span_processor(BatchSpanProcessor(trace_exporter))
6972
trace.set_tracer_provider(tracer_provider)
@@ -213,9 +216,7 @@ def batch_generate(
213216
# If no output file is provided, we pick a name based on the args
214217
if args.output_file is None:
215218
os.makedirs("runs/cb", exist_ok=True)
216-
args.output_file = (
217-
f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{args.attn}_{args.matmul_precision}_{args.samples}.json"
218-
)
219+
args.output_file = f"runs/cb/{args.num_blocks}_{args.max_batch_tokens}_{args.attn}_{args.matmul_precision}_{args.samples}.json"
219220

220221
# Run warmup batch generation
221222
batch_generate(

src/transformers/generation/continuous_batching/cache.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -274,7 +274,9 @@ def compute_num_blocks_and_max_batch_tokens(
274274
logger.info(f"Cache memory: {cache_memory}")
275275

276276
# Compute memory footprints # TODO: check and explain better
277-
mem_per_activation_token = self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor
277+
mem_per_activation_token = (
278+
self._activation_dtype.itemsize * (self.hidden_size + self.vocab_size) * self._activation_safety_factor
279+
)
278280
mem_per_cache_token = 2 * self.num_heads * self.head_dim * self.num_layers * cache_dtype.itemsize
279281
mem_per_input_token = 8 * m * self._input_dtype.itemsize
280282
logger.info(f"Memory per activation token: {mem_per_activation_token}")

0 commit comments

Comments (0)