@@ -98,6 +98,21 @@ static const char *db_sqlite3_fmt_error(struct db_stmt *stmt)
98
98
sqlite3_errmsg (conn2sql (stmt -> db -> conn )));
99
99
}
100
100
101
+ static bool is_strict_constraint_error (struct db_stmt * stmt )
102
+ {
103
+ sqlite3 * sql = conn2sql (stmt -> db -> conn );
104
+ const char * errmsg = sqlite3_errmsg (sql );
105
+ int errcode = sqlite3_errcode (sql );
106
+
107
+ if (errcode != SQLITE_CONSTRAINT || !stmt -> db -> use_strict_tables )
108
+ return false;
109
+
110
+ return (strstr (errmsg , "CHECK constraint failed" ) ||
111
+ strstr (errmsg , "datatype mismatch" ) ||
112
+ strstr (errmsg , "cannot store" ) ||
113
+ strstr (errmsg , "NOT NULL constraint failed" ));
114
+ }
115
+
101
116
static bool db_sqlite3_setup (struct db * db , bool create )
102
117
{
103
118
char * filename ;
@@ -205,16 +220,183 @@ static bool db_sqlite3_setup(struct db *db, bool create)
205
220
"PRAGMA foreign_keys = ON;" , -1 , & stmt , NULL );
206
221
err = sqlite3_step (stmt );
207
222
sqlite3_finalize (stmt );
208
- return err == SQLITE_DONE ;
223
+
224
+ if (err != SQLITE_DONE )
225
+ return false;
226
+
227
+ bool is_testing = (getenv ("TEST_DB_PROVIDER" ) ||
228
+ getenv ("PYTEST_PAR" ) ||
229
+ getenv ("TEST_DEBUG" ) ||
230
+ getenv ("VALGRIND" ));
231
+
232
+ /* SQLite 3.37.0 introduced STRICT table support */
233
+ if ((db -> developer || is_testing ) && sqlite3_libversion_number () >= 3037000 )
234
+ db -> use_strict_tables = true;
235
+
236
+ {
237
+ static const char * security_pragmas [] = {
238
+ "PRAGMA trusted_schema = OFF;" ,
239
+ "PRAGMA cell_size_check = ON;" ,
240
+ "PRAGMA secure_delete = ON;" ,
241
+ NULL
242
+ };
243
+
244
+ for (int i = 0 ; security_pragmas [i ]; i ++ ) {
245
+ err = sqlite3_prepare_v2 (conn2sql (db -> conn ),
246
+ security_pragmas [i ], -1 , & stmt , NULL );
247
+ if (err == SQLITE_OK ) {
248
+ err = sqlite3_step (stmt );
249
+ sqlite3_finalize (stmt );
250
+ }
251
+ }
252
+ }
253
+
254
+ return true;
255
+ }
256
+
257
+ static bool is_standalone_type_keyword (const char * query , const char * pos ,
258
+ const char * keyword , size_t keyword_len ,
259
+ size_t query_len )
260
+ {
261
+ bool prefix_ok = (pos == query || (!isalnum (pos [-1 ]) && pos [-1 ] != '_' ));
262
+ const char * after = pos + keyword_len ;
263
+ bool suffix_ok = (after >= query + query_len ||
264
+ (!isalnum (after [0 ]) && after [0 ] != '_' ));
265
+
266
+ return prefix_ok && suffix_ok ;
267
+ }
268
+
269
+ static char * normalize_varchar_to_text (const tal_t * ctx , const char * query )
270
+ {
271
+ char * result ;
272
+ const char * src ;
273
+ char * dst ;
274
+ size_t query_len ;
275
+
276
+ if (!query )
277
+ return NULL ;
278
+
279
+ query_len = strlen (query );
280
+
281
+ #define MAX_SQL_STATEMENT_LENGTH 1048576 /* 1MB limit */
282
+ if (query_len > MAX_SQL_STATEMENT_LENGTH )
283
+ return NULL ;
284
+
285
+ /* INT(3) -> INTEGER(7) worst case: +4 bytes per conversion */
286
+ size_t max_expansions = (query_len / 3 ) * 4 ;
287
+ size_t buffer_size = query_len + max_expansions + 64 ;
288
+
289
+ if (buffer_size < query_len )
290
+ return NULL ;
291
+
292
+ result = tal_arr (ctx , char , buffer_size );
293
+ src = query ;
294
+ dst = result ;
295
+
296
+ while (* src ) {
297
+ if (strncasecmp (src , "BIGSERIAL" , 9 ) == 0 &&
298
+ is_standalone_type_keyword (query , src , "BIGSERIAL" , 9 , query_len )) {
299
+ strcpy (dst , "INTEGER" );
300
+ dst += 7 ;
301
+ src += 9 ;
302
+ } else if (strncasecmp (src , "VARCHAR" , 7 ) == 0 &&
303
+ is_standalone_type_keyword (query , src , "VARCHAR" , 7 , query_len )) {
304
+ strcpy (dst , "TEXT" );
305
+ dst += 4 ;
306
+ src += 7 ;
307
+
308
+ if (* src == '(' ) {
309
+ const char * paren_start = src ;
310
+ while (* src && * src != ')' ) {
311
+ src ++ ;
312
+ /* Prevent runaway on malformed SQL */
313
+ if (src - paren_start > 1000 )
314
+ return NULL ;
315
+ }
316
+ if (* src == ')' ) src ++ ;
317
+ }
318
+ } else if (strncasecmp (src , "BIGINT" , 6 ) == 0 &&
319
+ is_standalone_type_keyword (query , src , "BIGINT" , 6 , query_len )) {
320
+ strcpy (dst , "INTEGER" );
321
+ dst += 7 ;
322
+ src += 6 ;
323
+ } else if (strncasecmp (src , "INT" , 3 ) == 0 &&
324
+ is_standalone_type_keyword (query , src , "INT" , 3 , query_len )) {
325
+ strcpy (dst , "INTEGER" );
326
+ dst += 7 ;
327
+ src += 3 ;
328
+ } else {
329
+ * dst ++ = * src ++ ;
330
+ }
331
+ }
332
+
333
+ * dst = '\0' ;
334
+ return result ;
335
+ }
336
+
337
+ static char * add_strict_to_create_table (const tal_t * ctx , const char * query )
338
+ {
339
+ char * semicolon_pos ;
340
+ ptrdiff_t prefix_len ;
341
+
342
+ if (!strcasestr (query , "CREATE TABLE" ))
343
+ return tal_strdup (ctx , query );
344
+
345
+ if (strcasestr (query , "STRICT" ))
346
+ return tal_strdup (ctx , query );
347
+
348
+ semicolon_pos = strrchr (query , ';' );
349
+ if (!semicolon_pos )
350
+ semicolon_pos = (char * )query + strlen (query );
351
+
352
+ prefix_len = semicolon_pos - query ;
353
+ return tal_fmt (ctx , "%.*s STRICT%s" , (int )prefix_len ,
354
+ query , semicolon_pos );
355
+ }
356
+
357
+ static char * prepare_query_for_execution (const tal_t * ctx , struct db * db ,
358
+ const char * query )
359
+ {
360
+ char * normalized_query ;
361
+
362
+ normalized_query = normalize_varchar_to_text (ctx , query );
363
+ if (!normalized_query )
364
+ return NULL ;
365
+
366
+ if (db -> use_strict_tables )
367
+ return add_strict_to_create_table (ctx , normalized_query );
368
+ else
369
+ return normalized_query ;
209
370
}
210
371
211
372
static bool db_sqlite3_query (struct db_stmt * stmt )
212
373
{
213
374
sqlite3_stmt * s ;
214
375
sqlite3 * conn = conn2sql (stmt -> db -> conn );
215
376
int err ;
377
+ char * query_to_execute ;
216
378
217
- err = sqlite3_prepare_v2 (conn , stmt -> query -> query , -1 , & s , NULL );
379
+ query_to_execute = prepare_query_for_execution (stmt , stmt -> db ,
380
+ stmt -> query -> query );
381
+ bool should_free_query = (query_to_execute != stmt -> query -> query );
382
+
383
+ err = sqlite3_prepare_v2 (conn , query_to_execute , -1 , & s , NULL );
384
+
385
+ if (err != SQLITE_OK ) {
386
+ if (should_free_query )
387
+ tal_free (query_to_execute );
388
+ tal_free (stmt -> error );
389
+ if (is_strict_constraint_error (stmt )) {
390
+ stmt -> error = tal_fmt (stmt , "%s (Note: STRICT tables are enabled)" ,
391
+ db_sqlite3_fmt_error (stmt ));
392
+ } else {
393
+ stmt -> error = db_sqlite3_fmt_error (stmt );
394
+ }
395
+ return false;
396
+ }
397
+
398
+ if (should_free_query )
399
+ tal_free (query_to_execute );
218
400
219
401
for (size_t i = 0 ; i < stmt -> query -> placeholders ; i ++ ) {
220
402
struct db_binding * b = & stmt -> bindings [i ];
@@ -246,12 +428,6 @@ static bool db_sqlite3_query(struct db_stmt *stmt)
246
428
}
247
429
}
248
430
249
- if (err != SQLITE_OK ) {
250
- tal_free (stmt -> error );
251
- stmt -> error = db_sqlite3_fmt_error (stmt );
252
- return false;
253
- }
254
-
255
431
stmt -> inner_stmt = s ;
256
432
return true;
257
433
}
@@ -270,7 +446,12 @@ static bool db_sqlite3_exec(struct db_stmt *stmt)
270
446
err = sqlite3_step (stmt -> inner_stmt );
271
447
if (err != SQLITE_DONE ) {
272
448
tal_free (stmt -> error );
273
- stmt -> error = db_sqlite3_fmt_error (stmt );
449
+ if (is_strict_constraint_error (stmt )) {
450
+ stmt -> error = tal_fmt (stmt , "%s (Note: STRICT tables are enabled)" ,
451
+ db_sqlite3_fmt_error (stmt ));
452
+ } else {
453
+ stmt -> error = db_sqlite3_fmt_error (stmt );
454
+ }
274
455
return false;
275
456
}
276
457
0 commit comments