4
4
import litellm
5
5
import openai
6
6
from litellm import acompletion
7
- # from openai.error import APIError, RateLimitError, Timeout, TryAgain
8
- from openai import APIError , RateLimitError , Timeout
9
- from retry import retry
7
+ from tenacity import retry , retry_if_exception_type , stop_after_attempt
10
8
from pr_agent .algo .ai_handlers .base_ai_handler import BaseAiHandler
11
9
from pr_agent .config_loader import get_settings
12
10
from pr_agent .log import get_logger
@@ -28,7 +26,8 @@ def __init__(self):
28
26
"""
29
27
self .azure = False
30
28
self .aws_bedrock_client = None
31
-
29
+ self .api_base = None
30
+ self .repetition_penalty = None
32
31
if get_settings ().get ("OPENAI.KEY" , None ):
33
32
openai .api_key = get_settings ().openai .key
34
33
litellm .openai_key = get_settings ().openai .key
@@ -57,8 +56,11 @@ def __init__(self):
57
56
litellm .replicate_key = get_settings ().replicate .key
58
57
if get_settings ().get ("HUGGINGFACE.KEY" , None ):
59
58
litellm .huggingface_key = get_settings ().huggingface .key
60
- if get_settings ().get ("HUGGINGFACE.API_BASE" , None ):
61
- litellm .api_base = get_settings ().huggingface .api_base
59
+ if get_settings ().get ("HUGGINGFACE.API_BASE" , None ) and 'huggingface' in get_settings ().config .model :
60
+ litellm .api_base = get_settings ().huggingface .api_base
61
+ self .api_base = get_settings ().huggingface .api_base
62
+ if get_settings ().get ("HUGGINGFACE.REPITITION_PENALTY" , None ):
63
+ self .repetition_penalty = float (get_settings ().huggingface .repetition_penalty )
62
64
if get_settings ().get ("VERTEXAI.VERTEX_PROJECT" , None ):
63
65
litellm .vertex_project = get_settings ().vertexai .vertex_project
64
66
litellm .vertex_location = get_settings ().get (
@@ -78,8 +80,10 @@ def deployment_id(self):
78
80
"""
79
81
return get_settings ().get ("OPENAI.DEPLOYMENT_ID" , None )
80
82
81
- @retry (exceptions = (APIError , Timeout , AttributeError , RateLimitError ),
82
- tries = OPENAI_RETRIES , delay = 2 , backoff = 2 , jitter = (1 , 3 ))
83
+ @retry (
84
+ retry = retry_if_exception_type ((openai .APIError , openai .APIConnectionError , openai .Timeout )), # No retry on RateLimitError
85
+ stop = stop_after_attempt (OPENAI_RETRIES )
86
+ )
83
87
async def chat_completion (self , model : str , system : str , user : str , temperature : float = 0.2 ):
84
88
try :
85
89
resp , finish_reason = None , None
@@ -93,28 +97,39 @@ async def chat_completion(self, model: str, system: str, user: str, temperature:
93
97
"messages" : messages ,
94
98
"temperature" : temperature ,
95
99
"force_timeout" : get_settings ().config .ai_timeout ,
100
+ "api_base" : self .api_base ,
96
101
}
97
102
if self .aws_bedrock_client :
98
103
kwargs ["aws_bedrock_client" ] = self .aws_bedrock_client
104
+ if self .repetition_penalty :
105
+ kwargs ["repetition_penalty" ] = self .repetition_penalty
99
106
100
107
get_logger ().debug ("Prompts" , artifact = {"system" : system , "user" : user })
108
+
109
+ if get_settings ().config .verbosity_level >= 2 :
110
+ get_logger ().info (f"\n System prompt:\n { system } " )
111
+ get_logger ().info (f"\n User prompt:\n { user } " )
112
+
101
113
response = await acompletion (** kwargs )
102
- except (APIError , Timeout ) as e :
114
+ except (openai . APIError , openai . Timeout ) as e :
103
115
get_logger ().error ("Error during OpenAI inference: " , e )
104
116
raise
105
- except (RateLimitError ) as e :
117
+ except (openai . RateLimitError ) as e :
106
118
get_logger ().error ("Rate limit error during OpenAI inference: " , e )
107
119
raise
108
120
except (Exception ) as e :
109
121
get_logger ().error ("Unknown error during OpenAI inference: " , e )
110
- raise APIError from e
122
+ raise openai . APIError from e
111
123
if response is None or len (response ["choices" ]) == 0 :
112
- raise APIError
124
+ raise openai . APIError
113
125
else :
114
126
resp = response ["choices" ][0 ]['message' ]['content' ]
115
127
finish_reason = response ["choices" ][0 ]["finish_reason" ]
116
128
# usage = response.get("usage")
117
129
get_logger ().debug (f"\n AI response:\n { resp } " )
118
130
get_logger ().debug ("Full_response" , artifact = response )
119
131
132
+ if get_settings ().config .verbosity_level >= 2 :
133
+ get_logger ().info (f"\n AI response:\n { resp } " )
134
+
120
135
return resp , finish_reason
0 commit comments