chain_strategy.py
from functools import reduce
from typing import List

import torch

from .base_strategy import BaseStrategy


class ChainStrategy(BaseStrategy):
    """Composes several strategies and applies them in sequence."""

    def __init__(self, strategies: List[BaseStrategy]):
        self.strategies = strategies

    def get_keep_index(self) -> int:
        # The chain can only keep tokens that every sub-strategy keeps.
        return min([st.get_keep_index() for st in self.strategies])

    def on_logits(
        self, logits: torch.FloatTensor, continuation_tokens: List[int]
    ) -> torch.FloatTensor:
        # Fold the logits through each strategy in order.
        return reduce(
            lambda res, strategy: strategy.on_logits(res, continuation_tokens),
            self.strategies,
            logits,
        )

    def on_probs(
        self, probs: torch.FloatTensor, continuation_tokens: List[int]
    ) -> torch.FloatTensor:
        # Fold the probabilities through each strategy in order.
        return reduce(
            lambda res, strategy: strategy.on_probs(res, continuation_tokens),
            self.strategies,
            probs,
        )

    def on_next_token(
        self, continuation_tokens: List[int], probs: torch.FloatTensor
    ) -> None:
        # Notify every sub-strategy of the newly sampled token.
        for stg in self.strategies:
            stg.on_next_token(continuation_tokens, probs)

    def backtrack(self, continuation_tokens: List[int]) -> List[int]:
        # Let each strategy trim the continuation in turn.
        return reduce(
            lambda res, strategy: strategy.backtrack(res),
            self.strategies,
            continuation_tokens,
        )

    def reset(self) -> None:
        for stg in self.strategies:
            stg.reset()
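
The sketch below is not part of the file; it only illustrates how ChainStrategy composes sub-strategies. The TemperatureStrategy and TopKStrategy classes are hypothetical, and the sketch assumes BaseStrategy provides defaults for the hooks they do not override.

# --- Hypothetical usage sketch (not from the repository) ---------------------
# Assumes BaseStrategy supplies no-op defaults for the hooks that the two
# illustrative strategies below do not override.
import torch

from .base_strategy import BaseStrategy
from .chain_strategy import ChainStrategy


class TemperatureStrategy(BaseStrategy):
    """Illustrative: rescales logits by an inverse temperature."""

    def __init__(self, temperature: float = 0.7):
        self.temperature = temperature

    def on_logits(self, logits, continuation_tokens):
        return logits / self.temperature


class TopKStrategy(BaseStrategy):
    """Illustrative: masks everything outside the k most likely tokens."""

    def __init__(self, k: int = 50):
        self.k = k

    def on_logits(self, logits, continuation_tokens):
        kth = torch.topk(logits, self.k).values[..., -1, None]
        return logits.masked_fill(logits < kth, float("-inf"))


# The chain applies strategies in list order: temperature first, then top-k.
chain = ChainStrategy([TemperatureStrategy(0.7), TopKStrategy(50)])
logits = torch.randn(1, 32000)
filtered = chain.on_logits(logits, continuation_tokens=[])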