1
+ /* ***********************************************************************************
2
+ * Copyright (c) 2023, xeus-cpp contributors *
3
+ * Copyright (c) 2023, Johan Mabille, Loic Gouarin, Sylvain Corlay, Wolf Vollprecht *
4
+ * *
5
+ * Distributed under the terms of the BSD 3-Clause License. *
6
+ * *
7
+ * The full license is in the file LICENSE, distributed with this software. *
8
+ ************************************************************************************/
9
+ #include " xassist.hpp"
10
+
11
+ #define CURL_STATICLIB
12
+ #include < curl/curl.h>
13
+ #include < fstream>
14
+ #include < iostream>
15
+ #include < nlohmann/json.hpp>
16
+ #include < string>
17
+ #include < unordered_set>
18
+
19
+ using json = nlohmann::json;
20
+
21
+ namespace xcpp
22
+ {
23
+ class APIKeyManager
24
+ {
25
+ public:
26
+
27
+ static void saveApiKey (const std::string& model, const std::string& apiKey)
28
+ {
29
+ std::string apiKeyFilePath = model + " _api_key.txt" ;
30
+ std::ofstream out (apiKeyFilePath);
31
+ if (out)
32
+ {
33
+ out << apiKey;
34
+ out.close ();
35
+ std::cout << " API key saved for model " << model << std::endl;
36
+ }
37
+ else
38
+ {
39
+ std::cerr << " Failed to open file for writing API key for model " << model << std::endl;
40
+ }
41
+ }
42
+
43
+ // Method to load the API key for a specific model
44
+ static std::string loadApiKey (const std::string& model)
45
+ {
46
+ std::string apiKeyFilePath = model + " _api_key.txt" ;
47
+ std::ifstream in (apiKeyFilePath);
48
+ std::string apiKey;
49
+ if (in)
50
+ {
51
+ std::getline (in, apiKey);
52
+ in.close ();
53
+ return apiKey;
54
+ }
55
+
56
+ std::cerr << " Failed to open file for reading API key for model " << model << std::endl;
57
+ return " " ;
58
+ }
59
+ };
60
+
61
+ class CurlHelper
62
+ {
63
+ private:
64
+
65
+ CURL* m_curl;
66
+ curl_slist* m_headers;
67
+
68
+ public:
69
+
70
+ CurlHelper ()
71
+ : m_curl(curl_easy_init())
72
+ , m_headers(curl_slist_append(nullptr , " Content-Type: application/json" ))
73
+ {
74
+ }
75
+
76
+ ~CurlHelper ()
77
+ {
78
+ if (m_curl)
79
+ {
80
+ curl_easy_cleanup (m_curl);
81
+ }
82
+ if (m_headers)
83
+ {
84
+ curl_slist_free_all (m_headers);
85
+ }
86
+ }
87
+
88
+ // Delete copy constructor and copy assignment operator
89
+ CurlHelper (const CurlHelper&) = delete ;
90
+ CurlHelper& operator =(const CurlHelper&) = delete ;
91
+
92
+ // Delete move constructor and move assignment operator
93
+ CurlHelper (CurlHelper&&) = delete ;
94
+ CurlHelper& operator =(CurlHelper&&) = delete ;
95
+
96
+ std::string
97
+ performRequest (const std::string& url, const std::string& postData, const std::string& authHeader = " " )
98
+ {
99
+ if (!authHeader.empty ())
100
+ {
101
+ m_headers = curl_slist_append (m_headers, authHeader.c_str ());
102
+ }
103
+
104
+ curl_easy_setopt (m_curl, CURLOPT_URL, url.c_str ());
105
+ curl_easy_setopt (m_curl, CURLOPT_HTTPHEADER, m_headers);
106
+ curl_easy_setopt (m_curl, CURLOPT_POSTFIELDS, postData.c_str ());
107
+
108
+ std::string response;
109
+ curl_easy_setopt (
110
+ m_curl,
111
+ CURLOPT_WRITEFUNCTION,
112
+ +[](const char * in, size_t size, size_t num, std::string* out)
113
+ {
114
+ const size_t totalBytes (size * num);
115
+ out->append (in, totalBytes);
116
+ return totalBytes;
117
+ }
118
+ );
119
+ curl_easy_setopt (m_curl, CURLOPT_WRITEDATA, &response);
120
+
121
+ CURLcode res = curl_easy_perform (m_curl);
122
+ if (res != CURLE_OK)
123
+ {
124
+ std::cerr << " CURL request failed: " << curl_easy_strerror (res) << std::endl;
125
+ return " " ;
126
+ }
127
+
128
+ return response;
129
+ }
130
+ };
131
+
132
+ std::string gemini (const std::string& cell, const std::string& key)
133
+ {
134
+ CurlHelper curlHelper;
135
+ const std::string url = " https://generativelanguage.googleapis.com/v1beta/models/gemini-1.5-flash:generateContent?key="
136
+ + key;
137
+ const std::string postData = R"( {"contents": [{"parts":[{"text": ")" + cell + R"( "}]}]})" ;
138
+
139
+ std::string response = curlHelper.performRequest (url, postData);
140
+
141
+ json j = json::parse (response);
142
+ if (j.find (" error" ) != j.end ())
143
+ {
144
+ std::cerr << " Error: " << j[" error" ][" message" ] << std::endl;
145
+ return " " ;
146
+ }
147
+
148
+ return j[" candidates" ][0 ][" content" ][" parts" ][0 ][" text" ];
149
+ }
150
+
151
+ std::string openai (const std::string& cell, const std::string& key)
152
+ {
153
+ CurlHelper curlHelper;
154
+ const std::string url = " https://api.openai.com/v1/chat/completions" ;
155
+ const std::string postData = R"( {
156
+ "model": "gpt-3.5-turbo-16k",
157
+ "messages": [{"role": "user", "content": ")"
158
+ + cell + R"( "}],
159
+ "temperature": 0.7
160
+ })" ;
161
+ std::string authHeader = " Authorization: Bearer " + key;
162
+
163
+ std::string response = curlHelper.performRequest (url, postData, authHeader);
164
+
165
+ json j = json::parse (response);
166
+
167
+ if (j.find (" error" ) != j.end ())
168
+ {
169
+ std::cerr << " Error: " << j[" error" ][" message" ] << std::endl;
170
+ return " " ;
171
+ }
172
+
173
+ return j[" choices" ][0 ][" message" ][" content" ];
174
+ }
175
+
176
+ void xassist::operator ()(const std::string& line, const std::string& cell)
177
+ {
178
+ try
179
+ {
180
+ std::istringstream iss (line);
181
+ std::vector<std::string> tokens (
182
+ std::istream_iterator<std::string>{iss},
183
+ std::istream_iterator<std::string>()
184
+ );
185
+
186
+ std::vector<std::string> models = {" gemini" , " openai" };
187
+ std::string model = tokens[1 ];
188
+
189
+ if (std::find (models.begin (), models.end (), model) == models.end ())
190
+ {
191
+ std::cerr << " Model not found." << std::endl;
192
+ return ;
193
+ }
194
+
195
+ APIKeyManager api;
196
+ if (tokens[2 ] == " --save-key" )
197
+ {
198
+ xcpp::APIKeyManager::saveApiKey (model, cell);
199
+ return ;
200
+ }
201
+
202
+ std::string key = xcpp::APIKeyManager::loadApiKey (model);
203
+ if (key.empty ())
204
+ {
205
+ std::cerr << " API key for model " << model << " is not available." << std::endl;
206
+ return ;
207
+ }
208
+
209
+ std::string response;
210
+ if (model == " gemini" )
211
+ {
212
+ response = gemini (cell, key);
213
+ }
214
+ else if (model == " openai" )
215
+ {
216
+ response = openai (cell, key);
217
+ }
218
+
219
+ std::cout << response;
220
+ }
221
+ catch (const std::runtime_error& e)
222
+ {
223
+ std::cerr << " Caught an exception: " << e.what () << std::endl;
224
+ }
225
+ catch (...)
226
+ {
227
+ std::cerr << " Caught an unknown exception" << std::endl;
228
+ }
229
+ }
230
+ } // namespace xcpp
0 commit comments