
Commit 1b2368f

Clean up logs (#398)
Signed-off-by: Xiang Xu <xiangxu@google.com>
1 parent: 314adcc

File tree: 5 files changed, +6 −25 lines


README.md

Lines changed: 0 additions & 6 deletions

@@ -40,7 +40,6 @@ Run `Llama 3.1 8B` offline inference on 4 TPU chips:
 HF_TOKEN=<huggingface_token> python tpu_commons/examples/offline_inference.py \
     --model=meta-llama/Llama-3.1-8B \
     --tensor_parallel_size=4 \
-    --task=generate \
     --max_model_len=1024
 ```

@@ -51,7 +50,6 @@ Run `Llama 3.1 8B Instruct` offline inference on 4 TPU chips in disaggregated mo
 ```
 PREFILL_SLICES=2 DECODE_SLICES=2 HF_TOKEN=<huggingface_token> \
 python tpu_commons/examples/offline_inference.py \
-    --task=generate \
     --model=meta-llama/Meta-Llama-3-8B-Instruct \
     --max_model_len=1024 \
     --max_num_seqs=8

@@ -80,7 +78,6 @@ Run `Llama 3.1 70B Instruct` offline inference on 4 hosts (v6e-16) in interleave
 HF_TOKEN=<huggingface_token> python /workspace/tpu_commons/examples/offline_inference.py \
     --model=meta-llama/Llama-3.1-70B \
     --tensor_parallel_size=16 \
-    --task=generate \
     --max_model_len=1024
 ```

@@ -94,7 +91,6 @@ export HF_TOKEN=<huggingface_token>
 python tpu_commons/examples/offline_inference.py \
     --model=meta-llama/Llama-3.1-8B \
     --tensor_parallel_size=4 \
-    --task=generate \
     --max_model_len=1024
 ```

@@ -106,7 +102,6 @@ export HF_TOKEN=<huggingface_token>
 python vllm/examples/offline_inference/basic/generate.py \
     --model=Qwen/Qwen3-30B-A3B \
     --tensor_parallel_size=4 \
-    --task=generate \
     --max_model_len=1024 \
     --enable-expert-parallel
 ```

@@ -205,7 +200,6 @@ docker run \
 python /workspace/tpu_commons/examples/offline_inference.py \
     --model=meta-llama/Llama-3.1-8B \
     --tensor_parallel_size=4 \
-    --task=generate \
     --max_model_len=1024 \
 ```

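Every README hunk drops the same `--task=generate` flag from the example commands; the flag appears redundant for these generate-task models. For orientation, here is a minimal sketch of roughly what an offline-inference entry point like `tpu_commons/examples/offline_inference.py` does with vLLM's public Python API; the prompt and sampling settings are illustrative, not the script's actual contents:

```python
# Minimal offline-generation sketch using vLLM's Python API.
# The README's CLI flags map onto LLM(...) keyword arguments.
from vllm import LLM, SamplingParams

llm = LLM(
    model="meta-llama/Llama-3.1-8B",  # --model
    tensor_parallel_size=4,           # --tensor_parallel_size
    max_model_len=1024,               # --max_model_len
)
outputs = llm.generate(["The capital of France is"],
                       SamplingParams(max_tokens=32))
print(outputs[0].outputs[0].text)
```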
tpu_commons/models/jax/model_loader.py

Lines changed: 1 addition & 1 deletion

@@ -220,7 +220,7 @@ def get_model(
     mesh: Mesh,
 ) -> Any:
     impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
-    logger.info(f"Loading model, implementation type={impl}")
+    logger.info(f"Loading model with MODEL_IMPL_TYPE={impl}")
     if impl == "flax_nnx":
         return get_flax_model(vllm_config, rng, mesh)
     elif impl == "vllm":

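The reworded log line names the `MODEL_IMPL_TYPE` environment variable that drives the dispatch. A standalone sketch of that selection scheme, assuming only the `flax_nnx` and `vllm` values visible in this hunk are meaningful (the `vllm` loader name below is hypothetical, not shown in the diff):

```python
import os

# Same selection scheme as model_loader.get_model: MODEL_IMPL_TYPE defaults
# to "flax_nnx" and is lower-cased before comparison.
impl = os.getenv("MODEL_IMPL_TYPE", "flax_nnx").lower()
print(f"Loading model with MODEL_IMPL_TYPE={impl}")  # the new log message

if impl == "flax_nnx":
    loader = "get_flax_model"   # stand-in for the real loader call
elif impl == "vllm":
    loader = "get_vllm_model"   # hypothetical name; not shown in this hunk
else:
    raise ValueError(f"Unknown MODEL_IMPL_TYPE: {impl}")
```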
tpu_commons/models/jax/utils/weight_utils.py

Lines changed: 1 addition & 1 deletion

@@ -57,7 +57,7 @@ def hf_model_weights_iterator(
     weights_files = []
     weights_location = "local"
     if os.path.isdir(model_name_or_path):
-        logger.info(f"Loading weights locally from: {model_name_or_path}")
+        logger.info(f"Found weights from local: {model_name_or_path}")
         weights_files = glob.glob(
             os.path.join(model_name_or_path, HF_WEIGHTS_FORMAT))
     elif file_utils.is_gcs_path(model_name_or_path):

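The branch touched here discovers weight shards for a local checkout by globbing a fixed pattern. A self-contained sketch of that lookup; `HF_WEIGHTS_FORMAT` is defined elsewhere in `weight_utils.py`, so the value below is an assumption:

```python
import glob
import os

HF_WEIGHTS_FORMAT = "*.safetensors"  # assumed value; the real constant lives in weight_utils.py

def find_local_weight_files(model_name_or_path: str) -> list[str]:
    """Return local weight shard paths, or [] if the argument is not a directory."""
    if os.path.isdir(model_name_or_path):
        return glob.glob(os.path.join(model_name_or_path, HF_WEIGHTS_FORMAT))
    return []
```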
tpu_commons/platforms/tpu_jax.py

Lines changed: 4 additions & 3 deletions

@@ -68,6 +68,7 @@ def get_attn_backend_cls(cls, selected_backend: _Backend, head_size: int,

     @classmethod
     def get_device_name(cls, device_id: int = 0) -> str:
+        logger.info(jax.lib.xla_bridge.get_backend().platform_version)
         try:
             if envs.VLLM_TPU_USING_PATHWAYS:
                 return jax.local_devices()[0].device_kind

@@ -174,14 +175,14 @@ def check_and_update_config(cls, vllm_config: VllmConfig) -> None:

         multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
         if not multihost_backend:  # Single host
-            logger.warning(
-                "JAX requires to use uniproc_executor for single host.")
+            logger.info("Force using UniProcExecutor for JAX on single host.")
             parallel_config.distributed_executor_backend = "uni"
         elif multihost_backend == "ray":
             from tpu_commons.executors.ray_distributed_executor import \
                 RayDistributedExecutor
             parallel_config.distributed_executor_backend = RayDistributedExecutor
-            logger.info("Using Ray as the TPU multihost backend. ")
+            logger.info(
+                "Force using RayDistributedExecutor for JAX on single host.")
         else:
             logger.warning(
                 f"Unknown TPU multihost backend: {multihost_backend}. "

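The second hunk's logic reduces to an environment-variable switch over `TPU_MULTIHOST_BACKEND`. A minimal sketch of that decision, with string names standing in for the executor classes the real code assigns:

```python
import os

def pick_executor_backend() -> str:
    # Mirrors check_and_update_config: an empty TPU_MULTIHOST_BACKEND means
    # single host (UniProcExecutor, "uni"); "ray" selects RayDistributedExecutor.
    multihost_backend = os.environ.get("TPU_MULTIHOST_BACKEND", "").lower()
    if not multihost_backend:
        return "uni"
    if multihost_backend == "ray":
        return "ray"
    # The real code only logs a warning on unknown values; this sketch raises.
    raise ValueError(f"Unknown TPU multihost backend: {multihost_backend}")
```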
tpu_commons/runner/jax/tpu_jax_runner.py

Lines changed: 0 additions & 14 deletions

@@ -84,7 +84,6 @@ def __init__(

         self.maybe_forbid_compile = runner_utils.ForbidCompile(
         ) if envs.VLLM_XLA_CHECK_RECOMPILATION else nullcontext()
-        logger.info("TPUModelRunner created!")

     def _verify_chunked_prefill_config(self):
         if (self.scheduler_config.max_num_batched_tokens

@@ -106,9 +105,6 @@ def _init_mesh(self) -> None:
             sharding_strategy = \
                 self.vllm_config.additional_config["sharding"]["sharding_strategy"]
         except KeyError:
-            logger.warning(
-                f"No sharding strategy passed! Using default of full model parallelism={len(self.devices)}"
-            )
             sharding_strategy = {"tensor_parallelism": len(self.devices)}

         if os.getenv("NEW_MODEL_DESIGN", False):

@@ -120,20 +116,12 @@ def _init_mesh(self) -> None:
         try:
             dp = sharding_strategy["data_parallelism"]
         except KeyError:
-            logger.warning(
-                "No data parallelism passed! Using default value of 1")
             dp = 1
-
         try:
             tp = sharding_strategy["tensor_parallelism"]
         except KeyError:
-            logger.warning(
-                f"No tensor parallelism passed! Using default value of {len(self.devices)}"
-            )
             tp = len(self.devices)

-        tp = sharding_strategy["tensor_parallelism"]
-
         axis_names = ("data", "model")
         mesh_shape = (dp, tp)

@@ -277,8 +265,6 @@ def initialize_kv_cache(self, kv_cache_config: KVCacheConfig) -> None:
         if has_kv_transfer_group():
             get_kv_transfer_group().register_kv_caches(self.kv_caches)

-        logger.info(jax.lib.xla_bridge.get_backend().platform_version)
-
     def _precompile_backbone(self) -> None:
         for num_tokens in self.num_tokens_paddings:
             input_ids = np.ones((num_tokens, ), dtype=np.int32)

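With the warnings removed, `_init_mesh` now defaults silently: dp falls back to 1 and tp to the full device count. Note the third hunk also deletes a stray unconditional `tp = sharding_strategy["tensor_parallelism"]` that would have raised KeyError whenever the key was absent, bypassing the fallback above it. A sketch of the defaulting plus the mesh construction it feeds, assuming a flat device list (the helper name is illustrative):

```python
import jax
import numpy as np
from jax.sharding import Mesh

def make_mesh(sharding_strategy: dict, devices: list) -> Mesh:
    # Defaults match _init_mesh: data_parallelism -> 1,
    # tensor_parallelism -> len(devices); axes are ("data", "model").
    dp = sharding_strategy.get("data_parallelism", 1)
    tp = sharding_strategy.get("tensor_parallelism", len(devices))
    device_grid = np.asarray(devices).reshape(dp, tp)
    return Mesh(device_grid, axis_names=("data", "model"))

mesh = make_mesh({}, jax.devices())  # dp=1, tp=<all local devices>
```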