|
11 | 11 |
|
12 | 12 | # We centralize the logger here to coordinate between logging and progress bar |
13 | 13 | logger = logging.getLogger("ContinuousBatchingLogger") |
| 14 | +logger.setLevel(logging.INFO) |
14 | 15 |
|
15 | 16 |
|
16 | 17 | @staticmethod |
@@ -102,12 +103,35 @@ class RequestState: |
102 | 103 | static_outputs: list[int] = field(default_factory=list) # Generated tokens |
103 | 104 | allocated_blocks: list[int] = field(default_factory=list) # Block IDs allocated to the request |
104 | 105 | position_offset: int = 0 # Current position in the sequence for position_ids |
105 | | - status: RequestStatus = RequestStatus.PENDING # Status of the request |
| 106 | + _status: RequestStatus = RequestStatus.PENDING # Status of the request, hidden behind a property |
106 | 107 | max_new_tokens: int = 20 # Maximum number of new tokens to generate |
107 | 108 | eos_token_id: int = -1 # ID of the end-of-sequence token |
108 | 109 | created_time: float = field(default_factory=time.time) # Time the request was created |
109 | 110 | error: Optional[str] = None # Error message if the request failed |
110 | 111 | next_token: Optional[str] = None # Next token to be generated |
| 112 | + lifespan: tuple[float, float] = (-1, -1) # (time request was no longer pending, time request finished) |
| 113 | + |
| 114 | + @property |
| 115 | + def status(self) -> RequestStatus: |
| 116 | + return self._status |
| 117 | + |
| 118 | + @status.setter |
| 119 | + def status(self, value: RequestStatus): |
| 120 | + if self._status == RequestStatus.PENDING: |
| 121 | + self.lifespan = (time.time(), -1) |
| 122 | + elif value == RequestStatus.FINISHED: |
| 123 | + self.lifespan = (self.lifespan[0], time.time()) |
| 124 | + self.log_end_of_request() |
| 125 | + self._status = value |
| 126 | + |
| 127 | + def log_end_of_request(self): |
| 128 | + prefill_len = len(self.full_prompt_ids) |
| 129 | + decode_len = self.generated_len() |
| 130 | + start_time = self.lifespan[0] - self.created_time |
| 131 | + end_time = self.lifespan[1] - self.created_time |
| 132 | + logger.info( |
| 133 | + f"Request {self.request_id} finished: {prefill_len = } {decode_len = } {start_time = } {end_time = }" |
| 134 | + ) |
111 | 135 |
|
112 | 136 | def current_len(self) -> int: |
113 | 137 | """Get the current length of the sequence (prompt + generated tokens).""" |
@@ -148,7 +172,7 @@ def update_with_token(self, token_id: int) -> bool: |
148 | 172 | def __repr__(self): |
149 | 173 | msg = [ |
150 | 174 | f"request_id={self.request_id}", |
151 | | - f"status={self.status}", |
| 175 | + f"status={self._status}", |
152 | 176 | f"out_tokens={self.generated_len()}", |
153 | 177 | f"query_length={len(self.prompt_ids)}", |
154 | 178 | f"remaining_tokens={len(self.remaining_prompt_ids)}", |
|
0 commit comments