
[Feature]: A way to use the LLM's last hidden state as an embedding vector #5950

@junwoo-ctrl


🚀 The feature, motivation and pitch

My suggestion is simple: vLLM should support returning the last hidden state as an embedding vector.
Here is an example.

Prompt Setting

I fine-tuned my local LLM with the following prompt setting.

##INPUT{"category": "Female Clothes", "brand": "LouisVuitton", "name": ""}\n
##RESULT{"category": "Female Clothes", "brand": "LouisVuitton", "name": "{{goods_name}}"}

So I can run inference on my model with this prompt input.

##INPUT{"category": "Female Perfume", "brand": "Channel", "name": ""}\n
##RESULT

The model generates a completed goods name.

##INPUT{"category": "Female Perfume", "brand": "Channel", "name": ""}\n
##RESULT{"category": "Female Perfume", "brand": "Channel", "name": "Channel CoCo Mademoiselle Queen Test 100ml"}

Yes, this model can generate an unseen goods name, and in a plausible way!
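For reference, generating with the current vLLM API looks roughly like this (the model path and sampling settings below are placeholders):

from vllm import LLM, SamplingParams

# Placeholder path to the fine-tuned goods-name model.
llm = LLM(model="path/to/finetuned-goods-model")
sampling_params = SamplingParams(temperature=0.0, max_tokens=64)

prompt = ('##INPUT{"category": "Female Perfume", "brand": "Channel", "name": ""}\n'
          '##RESULT')

# Regular text generation: we only get tokens back, not hidden states.
outputs = llm.generate([prompt], sampling_params)
print(outputs[0].outputs[0].text)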

Embedding Ideation

Let's reconsider the inference prompt input.

##INPUT{"category": "Female Perfume", "brand": "Channel", "name": ""}\n
##RESULT

The hidden state used to generate the next token after "##RESULT" will be semantically very closely related to the goods name.
Therefore, assuming a well-prepared prompt, the hidden state vector at that first generation position is a very powerful and useful embedding for downstream tasks.
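To make the idea concrete, here is roughly the vector I mean, expressed with plain Hugging Face transformers (the model path is a placeholder); the request is for vLLM to expose this same vector through its own engine:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "path/to/finetuned-goods-model"  # placeholder
tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(model_name)

prompt = ('##INPUT{"category": "Female Perfume", "brand": "Channel", "name": ""}\n'
          '##RESULT')
inputs = tokenizer(prompt, return_tensors="pt")

with torch.no_grad():
    out = model(**inputs, output_hidden_states=True)

# Last layer, last prompt token: the state that would produce the first
# token of the goods name.
embedding = out.hidden_states[-1][0, -1, :]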

Alternatives

I have some suggestions for new methods on vLLM's classes. This is just an idea, so the code could certainly be refactored!

LLM Class

generate_embedding()
  • This method is almost identical to the generate() method.
def generate_embedding(
    self,
    prompts: Union[Union[PromptStrictInputs, Sequence[PromptStrictInputs]],
                   Optional[Union[str, List[str]]]] = None,
    sampling_params: Optional[Union[SamplingParams,
                                    Sequence[SamplingParams]]] = None,
    prompt_token_ids: Optional[Union[List[int], List[List[int]]]] = None,
    use_tqdm: bool = True,
    lora_request: Optional[Union[List[LoRARequest], LoRARequest]] = None,
) -> List[RequestOutput]:
    """Run a single prefill pass and return the last hidden state per prompt."""
    if self.llm_engine.model_config.embedding_mode:
        raise ValueError(
            "LLM.generate_embedding() is only supported for generation "
            "models (XForCausalLM).")
    ...

    outputs = self._run_engine_embed(use_tqdm=use_tqdm)
    return outputs
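If this were added, usage could look something like the sketch below (the method is only a proposal, so the exact signature may change):

llm = LLM(model="path/to/finetuned-goods-model")  # placeholder path
prompts = [
    '##INPUT{"category": "Female Perfume", "brand": "Channel", "name": ""}\n##RESULT',
]
# Proposed API: one embedding (last hidden state) per prompt, instead of text.
embeddings = llm.generate_embedding(prompts)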
_run_engine_embed()
  • The step_embed() method does not need to run inside a while loop (see the simplified sketch of the existing loop after this snippet).
def _run_engine_embed(
    self, *, use_tqdm: bool
) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
    # A single engine step is enough; no decode loop is needed.
    step_outputs = self.llm_engine.step_embed()
    return step_outputs
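For contrast, here is a simplified sketch of the existing _run_engine(), which keeps calling step() until every request has finished decoding; the embedding path only needs the single prefill pass above:

def _run_engine(self, *, use_tqdm: bool) -> List[RequestOutput]:
    outputs = []
    # Existing behavior (simplified): decode until all requests finish.
    while self.llm_engine.has_unfinished_requests():
        step_outputs = self.llm_engine.step()
        for output in step_outputs:
            if output.finished:
                outputs.append(output)
    return outputs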

LLMEngine Class

step_embed()
def step_embed(self) -> List[Union[RequestOutput, EmbeddingRequestOutput]]:
    seq_group_metadata_list, scheduler_outputs = self.scheduler.schedule()
    if not scheduler_outputs.is_empty():
        execute_model_req = ExecuteModelRequest(
            seq_group_metadata_list=seq_group_metadata_list,
            blocks_to_swap_in=scheduler_outputs.blocks_to_swap_in,
            blocks_to_swap_out=scheduler_outputs.blocks_to_swap_out,
            blocks_to_copy=scheduler_outputs.blocks_to_copy,
            num_lookahead_slots=scheduler_outputs.num_lookahead_slots,
            running_queue_size=scheduler_outputs.running_queue_size,
        )
        # Run a single embedding forward pass instead of a decode step.
        output = self.model_executor.execute_embedding(
            execute_model_req=execute_model_req)
    else:
        output = []

    # Log stats.
    self.do_log_stats(scheduler_outputs, [])

    if not output:
        self.model_executor.stop_remote_worker_execution_loop()
    return output

Scheduler Class

  • We have to change the Scheduler class to garbage-collect KV-cache blocks when allocation has to wait.
_schedule_prefills()
def _schedule_prefills(
    self,
    waiting_queue: deque,
    budget: SchedulingBudget,
    curr_loras: Optional[Set[int]],
    enable_chunking: bool = False,
) -> Tuple[deque, SchedulerPrefillOutputs]:
    ...
    can_allocate = self.block_manager.can_allocate(seq_group)
    if can_allocate == AllocStatus.LATER:
        # Free stale blocks so the new prefill-only request can be allocated.
        self.block_manager.set_initialized_condition()

BlockSpaceManagerV2 Class

set_initialized_condition()
  • The set_initialized_condition() method re-initializes the KV-cache block store.
def set_initialized_condition(self) -> None:
    # Free all allocated blocks.
    for seq_id in list(self.block_tables.keys()):
        self.block_tables[seq_id].free()
        del self.block_tables[seq_id]

    for request_id in list(self.cross_block_tables.keys()):
        self.cross_block_tables[request_id].free()
        del self.cross_block_tables[request_id]

    # Reset the block allocator.
    self.block_allocator = CpuGpuBlockAllocator.create(
        allocator_type="prefix_caching" if self.enable_caching else "naive",
        num_gpu_blocks=self.num_total_gpu_blocks,
        num_cpu_blocks=self.num_total_cpu_blocks,
        block_size=self.block_size,
    )

GPUExecutor Class

  • We have to change the executor classes, because the Executor interface doesn't support execute_embedding.
execute_embedding()
def execute_embedding(
    self, execute_model_req: ExecuteModelRequest
) -> List[Union[SamplerOutput, PoolerOutput]]:
    # Delegate to the driver worker, mirroring execute_model().
    output = self.driver_worker.execute_embedding(execute_model_req)
    return output

Worker Class

  • We have to change the worker classes, because the Worker interface doesn't support execute_embedding.
execute_embedding()
@torch.inference_mode()
def execute_embedding(
    self,
    execute_model_req: Optional[ExecuteModelRequest] = None
) -> List[Union[SamplerOutput, PoolerOutput]]:
    if not self.is_driver_worker:
        self._execute_model_non_driver()
        return []
    .....
    output = self.model_runner.execute_embedding(seq_group_metadata_list,
                                                 self.gpu_cache)
    return output.detach().cpu().numpy()

ModelRunner Class

  • We have to change the model runner classes, because the ModelRunner interface doesn't support execute_embedding.
execute_embedding()
@torch.inference_mode()
def execute_embedding(
    self,
    seq_group_metadata_list: Optional[List[SequenceGroupMetadata]],
    kv_caches: List[torch.Tensor],
) -> Optional[SamplerOutput]:
    (input_tokens, input_positions, attn_metadata, sampling_metadata,
     lora_requests, lora_mapping, multi_modal_kwargs
     ) = self.prepare_input_tensors(seq_group_metadata_list)

    ...

    hidden_states = model_executable(
        input_ids=input_tokens,
        positions=input_positions,
        kv_caches=kv_caches,
        attn_metadata=attn_metadata,
        **multi_modal_kwargs,
    )
    # Take the hidden state at the next-token position, i.e. the last
    # non-padding token of the prompt.
    idx_of_the_last_non_padding_token = hidden_states.shape[0] - 1
    hidden_states = hidden_states.unsqueeze(0)
    embeddings = hidden_states[torch.arange(hidden_states.shape[0]),
                               idx_of_the_last_non_padding_token]
    return embeddings
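As a sanity check on the indexing, here is a tiny self-contained example with a random tensor showing that the selected row is the hidden state of the final prompt token:

import torch

num_tokens, hidden_size = 7, 4096  # single prompt, no padding
hidden_states = torch.randn(num_tokens, hidden_size)

idx_of_the_last_non_padding_token = hidden_states.shape[0] - 1
hidden_states = hidden_states.unsqueeze(0)   # (1, num_tokens, hidden_size)
embeddings = hidden_states[torch.arange(hidden_states.shape[0]),
                           idx_of_the_last_non_padding_token]
assert embeddings.shape == (1, hidden_size)  # one embedding per prompt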

Additional context

No response
