@@ -53,29 +53,37 @@ def update_state(self, batch_update: Optional[BatchUpdate]):
5353 # Process added requests.
5454 for index , params , _ , _ in batch_update .added :
5555 min_p = params .min_p
56- if self .min_p_cpu [index ] != min_p :
56+ min_p_before = self .min_p_cpu [index ]
57+ if min_p_before != min_p :
5758 needs_update = True
5859 self .min_p_cpu [index ] = min_p
59- if min_p :
60- self .min_p_count += 1
60+ if min_p and not min_p_before :
61+ self .min_p_count += 1
62+ elif not min_p and min_p_before :
63+ self .min_p_count -= 1
6164
6265 if self .min_p_count :
6366 # Process removed requests.
64- needs_update |= bool (batch_update .removed )
65- for index in batch_update .removed :
66- if self .min_p_cpu [index ]:
67- self .min_p_count -= 1
67+ if batch_update .removed :
68+ needs_update = True
69+ for index in batch_update .removed :
70+ if self .min_p_cpu [index ]:
71+ self .min_p_cpu [index ] = 0
72+ self .min_p_count -= 1
6873
69- # Process moved requests, unidirectional (a->b) and swap (a<->b)
74+ # Process moved requests, unidirectional (a->b) and swap (a<->b).
7075 for adx , bdx , direct in batch_update .moved :
71- change = (min_p_a :=
72- self .min_p_cpu [adx ]) != (min_p_b :=
73- self .min_p_cpu [bdx ])
74- needs_update |= change
75- if change :
76+ min_p_a , min_p_b = self .min_p_cpu [adx ], self .min_p_cpu [bdx ]
77+ if min_p_a != min_p_b :
78+ needs_update = True
7679 self .min_p_cpu [bdx ] = min_p_a
7780 if direct == MoveDirectionality .SWAP :
7881 self .min_p_cpu [adx ] = min_p_b
82+ if direct == MoveDirectionality .UNIDIRECTIONAL :
83+ if min_p_a :
84+ self .min_p_cpu [adx ] = 0
85+ if min_p_b :
86+ self .min_p_count -= 1
7987
8088 # Update tensors if needed.
8189 size = batch_update .batch_size
0 commit comments