diff --git a/lwm/llama.py b/lwm/llama.py
index 8a8871c..0e957f2 100644
--- a/lwm/llama.py
+++ b/lwm/llama.py
@@ -571,10 +571,10 @@ def __call__(
             platform = xla_bridge.get_backend().platform
             if platform == "tpu":
-                logging.info(f"Using fused attention for {platform}")
+                logger.info(f"Using fused attention for {platform}")
                 ring_attention_fn = ring_flash_attention_tpu
             else:
-                logging.info(f"Fused attention is not yet supported for {platform}, using non-fused version")
+                logger.info(f"Fused attention is not yet supported for {platform}, using non-fused version")
                 ring_attention_fn = ring_attention # uses BPT attention
             ring_attention_sharded = shard_map(
                 partial(