Commit 96f523d

pp notes

Signed-off-by: Chenyaaang <chenyangli@google.com>
1 parent: 922f4aa

3 files changed (+13 −6 lines)

tpu_commons/executors/ray_distributed_executor.py

Lines changed: 7 additions & 0 deletions
@@ -120,6 +120,8 @@ def _initialize_ray_cluster(self) -> None:
                 f"current platform {current_platform.device_name} does not "
                 "support ray.")
 
+        # Each node (host) is treated as a unit: with 2 hosts, Ray only sees
+        # 2 hosts; it does not subdivide the TPUs inside each host.
         placement_group_specs: List[Dict[str, float]] = [{
             device_str:
             node['Resources'][device_str]
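
To make the note above concrete: a standalone sketch (not this executor's code) of the per-host request it describes, assuming a "TPU" resource key and 2 hosts with 4 chips each. Ray reserves whole bundles, one per host, and never splits the chips inside a host.

# Standalone sketch, not the executor's code. Assumes a Ray cluster whose
# nodes expose a "TPU" resource (4 chips per host in this example).
from typing import Dict, List

import ray
from ray.util.placement_group import placement_group

ray.init()  # connect to / start the Ray cluster

device_str = "TPU"  # assumed resource key for TPU chips

# One bundle per host; each bundle asks for all 4 chips on that host.
# Ray schedules whole bundles -- it does not split the chips inside a host.
placement_group_specs: List[Dict[str, float]] = [
    {device_str: 4.0},  # host 0
    {device_str: 4.0},  # host 1
]

pg = placement_group(placement_group_specs)
ray.get(pg.ready())  # blocks until both hosts have been reserved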
@@ -155,6 +157,8 @@ def _initialize_ray_cluster(self) -> None:
 
     def _init_workers_ray(self, placement_group: "PlacementGroup",
                           **ray_remote_kwargs):
+        # placement_group: e.g. "need 2 hosts, with 4 chips on each".
+        # bundle: the workers on the same host form one bundle.
         # The workers are the actual ray actors.
         self.workers: List[RayWorkerWrapper] = []
 
@@ -190,6 +194,9 @@ def _init_workers_ray(self, placement_group: "PlacementGroup",
         driver_ip = get_ip()
         num_tpu_per_worker = placement_group.bundle_specs[0].get(
             current_platform.ray_device_key, 0)
+
+        # Create one worker per bundle. A bundle is a dict of resources; in
+        # this example (v7x-4) each bundle holds the 4 chips of one host.
         for rank, bundle_id in enumerate(bundle_indices):
             scheduling_strategy = PlacementGroupSchedulingStrategy(
                 placement_group=placement_group,
tpu_commons/models/jax/model_loader.py

Lines changed: 1 addition & 1 deletion
@@ -285,7 +285,7 @@ def get_vllm_model(
         rng=rng,
         mesh=mesh,
     )
-    params, lora_manager = model.load_weights()
+    params, lora_manager = model.load_weights()  # jax
 
     jit_model = model.jit_step_func()
     compute_logits_fn = model.jit_compute_logits_func()

tpu_commons/models/vllm/vllm_model_wrapper.py

Lines changed: 5 additions & 5 deletions
@@ -67,7 +67,7 @@ def compute_hidden_state(
         inputs_embeds: Optional[torch.Tensor],
     ) -> torch.Tensor:
         hidden_state = self.vllm_model(input_ids, positions,
-                                       intermediate_tensors, inputs_embeds)
+                                       intermediate_tensors, inputs_embeds)  # the output here can be either a hidden state or an intermediate tensor; only because jax has no pp yet is it always a hidden state
         return hidden_state
 
     def compute_logits(self, hidden_state: torch.Tensor) -> torch.Tensor:
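
The translated note says the forward pass can return either a hidden state or intermediate tensors, and that the jax path only ever sees hidden states because it has no pp yet. As a hedged illustration (assuming vLLM's IntermediateTensors type; this is not the wrapper's code), a pp-aware caller would have to branch on the output type:

# Illustrative sketch only. Assumes vLLM's IntermediateTensors; with pp > 1,
# non-last pipeline ranks return IntermediateTensors to be shipped to the next
# rank, and only the last rank returns the hidden state for compute_logits.
from typing import Union

import torch
from vllm.sequence import IntermediateTensors


def describe_forward_output(
        output: Union[torch.Tensor, IntermediateTensors]) -> str:
    if isinstance(output, IntermediateTensors):
        return "intermediate tensors: send to the next pp rank"
    return "hidden state: last pp rank, feed into compute_logits"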
@@ -136,7 +136,7 @@ def load_weights(self):
         # Returning to the jax land, so we need to wrap it into a JaxValue.
         return jax_view(params_and_buffers), lora_manager
 
-    def jit_step_func(self):
+    def jit_step_func(self):  # should also take in intermediate_tensors.
 
         @functools.partial(
             jax.jit,
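
The note above says jit_step_func should also take intermediate_tensors once pp lands. A purely hypothetical sketch of such a signature, where the names, shapes, and placeholder body are all assumptions rather than this wrapper's code:

# Hypothetical sketch of a pp-aware jitted step signature. The body is a
# placeholder; only the extra intermediate_tensors argument is the point.
import functools
from typing import Any, Optional

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("hidden_size",))
def step_fun(params: Any,
             input_ids: jax.Array,
             positions: jax.Array,
             intermediate_tensors: Optional[Any] = None,
             hidden_size: int = 8) -> jax.Array:
    # First pp rank: intermediate_tensors is None and input_ids are embedded.
    # Later ranks: intermediate_tensors would carry the previous stage's output.
    del params, positions, intermediate_tensors  # placeholder body
    return jnp.zeros((input_ids.shape[0], hidden_size), dtype=jnp.float32)


# Example call on the first pp rank:
out = step_fun({}, jnp.array([[1, 2, 3]]), jnp.array([[0, 1, 2]]))
print(out.shape)  # (1, 8)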
@@ -175,9 +175,9 @@ def step_fun(
             self.model,
             torch_view(params_and_buffers),
             kwargs={
-                "input_ids": torch_view(input_ids),
-                "positions": torch_view(attn_metadata.input_positions),
-                "intermediate_tensors": None,
+                "input_ids": torch_view(input_ids),  # torch_view(jax.Array) -> torchax tensor
+                "positions": torch_view(attn_metadata.input_positions),  # torch_view(jax.Array) -> torchax tensor
+                "intermediate_tensors": None,  # this is the one that should be used for pp
                 "inputs_embeds": None,
             },
             tie_weights=False,
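
The torch_view notes mark where jax arrays are re-exposed as torchax tensors for the torch-land vLLM model. A small sketch of that round trip, assuming torch_view and jax_view are importable from torchax.interop (as the wrapper's own jax_view usage suggests):

# Sketch of the jax <-> torch bridging the comments describe. Assumes
# torchax.interop exposes torch_view / jax_view; the intent is to view the
# same device buffer from the other framework rather than copy it.
import jax.numpy as jnp
from torchax.interop import jax_view, torch_view

input_ids = jnp.array([[1, 2, 3, 4]])   # jax.Array (jax land)
torch_ids = torch_view(input_ids)       # torchax tensor (torch land)
# ... a torch-land model such as self.vllm_model would consume torch_ids here ...
back = jax_view(torch_ids)              # back to a jax.Array for jax land
print(type(torch_ids), type(back))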
