diff --git a/FTL.h b/FTL.h index 323015147..1eed9bf00 100644 --- a/FTL.h +++ b/FTL.h @@ -52,16 +52,6 @@ #include "routines.h" -// Next we define the step size in which the struct arrays are reallocated if they -// grow too large. This number should be large enough so that reallocation does not -// have to run very often, but should be as small as possible to avoid wasting memory -#define QUERIESALLOCSTEP 10000 -#define FORWARDEDALLOCSTEP 4 -#define CLIENTSALLOCSTEP 10 -#define DOMAINSALLOCSTEP 1000 -#define OVERTIMEALLOCSTEP 100 -#define WILDCARDALLOCSTEP 100 - #define SOCKETBUFFERLEN 1024 // How often do we garbage collect (to ensure we only have data fitting to the MAXLOGAGE defined above)? [seconds] @@ -166,7 +156,7 @@ typedef struct { int domainID; int clientID; int forwardID; - bool db; + sqlite3_int64 db; int id; // the ID is a (signed) int in dnsmasq, so no need for a long int here bool complete; unsigned char privacylevel; @@ -180,8 +170,8 @@ typedef struct { unsigned char magic; int count; int failed; - char *ip; - char *name; + unsigned long long ippos; + unsigned long long namepos; bool new; } forwardedDataStruct; @@ -189,8 +179,8 @@ typedef struct { unsigned char magic; int count; int blockedcount; - char *ip; - char *name; + unsigned long long ippos; + unsigned long long namepos; bool new; } clientsDataStruct; @@ -198,7 +188,7 @@ typedef struct { unsigned char magic; int count; int blockedcount; - char *domain; + unsigned long long domainpos; unsigned char regexmatch; } domainsDataStruct; @@ -209,9 +199,7 @@ typedef struct { int blocked; int cached; int forwarded; - int clientnum; - int *clientdata; - int querytypedata[7]; + int querytypedata[TYPE_MAX-1]; } overTimeDataStruct; typedef struct { @@ -227,7 +215,7 @@ typedef struct { extern logFileNamesStruct files; extern FTLFileNamesStruct FTLfiles; -extern countersStruct counters; +extern countersStruct *counters; extern ConfigStruct config; extern queriesDataStruct *queries; @@ -236,6 +224,10 @@ extern clientsDataStruct *clients; extern domainsDataStruct *domains; extern overTimeDataStruct *overTime; +/// Indexed by client ID, then time index (like `overTime`). +/// This gets automatically updated whenever a new client or overTime slot is added. +extern int **overTimeClientData; + extern FILE *logfile; extern volatile sig_atomic_t killed; diff --git a/Makefile b/Makefile index 43555948c..6e27a0e7e 100644 --- a/Makefile +++ b/Makefile @@ -13,8 +13,8 @@ DNSMASQOPTS = -DHAVE_DNSSEC -DHAVE_DNSSEC_STATIC # Flags for compiling with libidn : -DHAVE_IDN # Flags for compiling with libidn2: -DHAVE_LIBIDN2 -DIDN2_VERSION_NUMBER=0x02000003 -FTLDEPS = FTL.h routines.h version.h api.h dnsmasq_interface.h -FTLOBJ = main.o memory.o log.o daemon.o datastructure.o signals.o socket.o request.o grep.o setupVars.o args.o threads.o gc.o config.o database.o msgpack.o api.o dnsmasq_interface.o resolve.o regex.o +FTLDEPS = FTL.h routines.h version.h api.h dnsmasq_interface.h shmem.h +FTLOBJ = main.o memory.o log.o daemon.o datastructure.o signals.o socket.o request.o grep.o setupVars.o args.o gc.o config.o database.o msgpack.o api.o dnsmasq_interface.o resolve.o regex.o shmem.o DNSMASQDEPS = config.h dhcp-protocol.h dns-protocol.h radv-protocol.h dhcp6-protocol.h dnsmasq.h ip6addr.h metrics.h DNSMASQOBJ = arp.o dbus.o domain.o lease.o outpacket.o rrfilter.o auth.o dhcp6.o edns0.o log.o poll.o slaac.o blockdata.o dhcp.o forward.o loop.o radv.o tables.o bpf.o dhcp-common.o helper.o netlink.o rfc1035.o tftp.o cache.o dnsmasq.o inotify.o network.o rfc2131.o util.o conntrack.o dnssec.o ipset.o option.o rfc3315.o crypto.o dump.o ubus.o metrics.o @@ -51,7 +51,7 @@ CCFLAGS=-std=gnu11 -I$(IDIR) -Wall -Wextra -Wno-unused-parameter -D_FILE_OFFSET_ # for dnsmasq we need the nettle crypto library and the gmp maths library # We link the two libraries statically. Althougth this increases the binary file size by about 1 MB, it saves about 5 MB of shared libraries and makes deployment easier #LIBS=-pthread -lnettle -lgmp -lhogweed -LIBS=-pthread -Wl,-Bstatic -L/usr/local/lib -lhogweed -lgmp -lnettle -Wl,-Bdynamic +LIBS=-pthread -Wl,-Bstatic -L/usr/local/lib -lhogweed -lgmp -lnettle -Wl,-Bdynamic -lrt # Flags for compiling with libidn : -lidn # Flags for compiling with libidn2: -lidn2 diff --git a/api.c b/api.c index 615286d55..621e73f4f 100644 --- a/api.c +++ b/api.c @@ -44,8 +44,8 @@ int cmpdesc(const void *a, const void *b) void getStats(int *sock) { - int blocked = counters.blocked; - int total = counters.queries; + int blocked = counters->blocked; + int total = counters->queries; float percentage = 0.0f; // Avoid 1/0 condition @@ -54,14 +54,14 @@ void getStats(int *sock) // Send domains being blocked if(istelnet[*sock]) { - ssend(*sock, "domains_being_blocked %i\n", counters.gravity); + ssend(*sock, "domains_being_blocked %i\n", counters->gravity); } else - pack_int32(*sock, counters.gravity); + pack_int32(*sock, counters->gravity); // unique_clients: count only clients that have been active within the most recent 24 hours int i, activeclients = 0; - for(i=0; i < counters.clients; i++) + for(i=0; i < counters->clients; i++) { validate_access("clients", i, true, __LINE__, __FUNCTION__, __FILE__); if(clients[i].count > 0) @@ -72,21 +72,21 @@ void getStats(int *sock) ssend(*sock, "dns_queries_today %i\nads_blocked_today %i\nads_percentage_today %f\n", total, blocked, percentage); ssend(*sock, "unique_domains %i\nqueries_forwarded %i\nqueries_cached %i\n", - counters.domains, counters.forwardedqueries, counters.cached); - ssend(*sock, "clients_ever_seen %i\n", counters.clients); + counters->domains, counters->forwardedqueries, counters->cached); + ssend(*sock, "clients_ever_seen %i\n", counters->clients); ssend(*sock, "unique_clients %i\n", activeclients); // Sum up all query types (A, AAAA, ANY, SRV, SOA, ...) int sumalltypes = 0; for(i=0; i < TYPE_MAX-1; i++) { - sumalltypes += counters.querytype[i]; + sumalltypes += counters->querytype[i]; } ssend(*sock, "dns_queries_all_types %i\n", sumalltypes); // Send individual reply type counters ssend(*sock, "reply_NODATA %i\nreply_NXDOMAIN %i\nreply_CNAME %i\nreply_IP %i\n", - counters.reply_NODATA, counters.reply_NXDOMAIN, counters.reply_CNAME, counters.reply_IP); + counters->reply_NODATA, counters->reply_NXDOMAIN, counters->reply_CNAME, counters->reply_IP); ssend(*sock, "privacy_level %i\n", config.privacylevel); } else @@ -94,16 +94,16 @@ void getStats(int *sock) pack_int32(*sock, total); pack_int32(*sock, blocked); pack_float(*sock, percentage); - pack_int32(*sock, counters.domains); - pack_int32(*sock, counters.forwardedqueries); - pack_int32(*sock, counters.cached); - pack_int32(*sock, counters.clients); + pack_int32(*sock, counters->domains); + pack_int32(*sock, counters->forwardedqueries); + pack_int32(*sock, counters->cached); + pack_int32(*sock, counters->clients); pack_int32(*sock, activeclients); } // Send status if(istelnet[*sock]) { - ssend(*sock, "status %s\n", counters.gravity > 0 ? "enabled" : "disabled"); + ssend(*sock, "status %s\n", counters->gravity > 0 ? "enabled" : "disabled"); } else pack_uint8(*sock, blockingstatus); @@ -116,7 +116,7 @@ void getOverTime(int *sock) time_t mintime = time(NULL) - config.maxlogage; // Start with the first non-empty overTime slot - for(i=0; i < counters.overTime; i++) + for(i=0; i < counters->overTime; i++) { validate_access("overTime", i, true, __LINE__, __FUNCTION__, __FILE__); if((overTime[i].total > 0 || overTime[i].blocked > 0) && @@ -134,7 +134,7 @@ void getOverTime(int *sock) if(istelnet[*sock]) { - for(i = j; i < counters.overTime; i++) + for(i = j; i < counters->overTime; i++) { ssend(*sock,"%i %i %i\n",overTime[i].timestamp,overTime[i].total,overTime[i].blocked); } @@ -145,15 +145,15 @@ void getOverTime(int *sock) // and map16 can hold up to (2^16)-1 = 65535 pairs // Send domains over time - pack_map16_start(*sock, (uint16_t) (counters.overTime - j)); - for(i = j; i < counters.overTime; i++) { + pack_map16_start(*sock, (uint16_t) (counters->overTime - j)); + for(i = j; i < counters->overTime; i++) { pack_int32(*sock, overTime[i].timestamp); pack_int32(*sock, overTime[i].total); } // Send ads over time - pack_map16_start(*sock, (uint16_t) (counters.overTime - j)); - for(i = j; i < counters.overTime; i++) { + pack_map16_start(*sock, (uint16_t) (counters->overTime - j)); + for(i = j; i < counters->overTime; i++) { pack_int32(*sock, overTime[i].timestamp); pack_int32(*sock, overTime[i].blocked); } @@ -162,7 +162,7 @@ void getOverTime(int *sock) void getTopDomains(char *client_message, int *sock) { - int i, temparray[counters.domains][2], count=10, num; + int i, temparray[counters->domains][2], count=10, num; bool blocked, audit = false, asc = false; blocked = command(client_message, ">top-ads"); @@ -194,7 +194,7 @@ void getTopDomains(char *client_message, int *sock) if(command(client_message, " asc")) asc = true; - for(i=0; i < counters.domains; i++) + for(i=0; i < counters->domains; i++) { validate_access("domains", i, true, __LINE__, __FUNCTION__, __FILE__); temparray[i][0] = i; @@ -207,9 +207,9 @@ void getTopDomains(char *client_message, int *sock) // Sort temporary array if(asc) - qsort(temparray, counters.domains, sizeof(int[2]), cmpasc); + qsort(temparray, counters->domains, sizeof(int[2]), cmpasc); else - qsort(temparray, counters.domains, sizeof(int[2]), cmpdesc); + qsort(temparray, counters->domains, sizeof(int[2]), cmpdesc); // Get filter @@ -244,28 +244,28 @@ void getTopDomains(char *client_message, int *sock) { // Send the data required to get the percentage each domain has been blocked / queried if(blocked) - pack_int32(*sock, counters.blocked); + pack_int32(*sock, counters->blocked); else - pack_int32(*sock, counters.queries); + pack_int32(*sock, counters->queries); } int n = 0; - for(i=0; i < counters.domains; i++) + for(i=0; i < counters->domains; i++) { // Get sorted indices int j = temparray[i][0]; validate_access("domains", j, true, __LINE__, __FUNCTION__, __FILE__); // Skip this domain if there is a filter on it - if(excludedomains != NULL && insetupVarsArray(domains[j].domain)) + if(excludedomains != NULL && insetupVarsArray(getstr(domains[j].domainpos))) continue; // Skip this domain if already included in audit - if(audit && countlineswith(domains[j].domain, files.auditlist) > 0) + if(audit && countlineswith(getstr(domains[j].domainpos), files.auditlist) > 0) continue; // Hidden domain, probably due to privacy level. Skip this in the top lists - if(strcmp(domains[j].domain, HIDDEN_DOMAIN) == 0) + if(strcmp(getstr(domains[j].domainpos), HIDDEN_DOMAIN) == 0) continue; if(blocked && showblocked && domains[j].blockedcount > 0) @@ -273,11 +273,11 @@ void getTopDomains(char *client_message, int *sock) if(audit && domains[j].regexmatch == REGEX_BLOCKED) { if(istelnet[*sock]) - ssend(*sock, "%i %i %s wildcard\n", n, domains[j].blockedcount, domains[j].domain); + ssend(*sock, "%i %i %s wildcard\n", n, domains[j].blockedcount, getstr(domains[j].domainpos)); else { - char *fancyWildcard = calloc(3 + strlen(domains[j].domain), sizeof(char)); + char *fancyWildcard = calloc(3 + strlen(getstr(domains[j].domainpos)), sizeof(char)); if(fancyWildcard == NULL) return; - sprintf(fancyWildcard, "*.%s", domains[j].domain); + sprintf(fancyWildcard, "*.%s", getstr(domains[j].domainpos)); if(!pack_str32(*sock, fancyWildcard)) return; @@ -289,9 +289,9 @@ void getTopDomains(char *client_message, int *sock) else { if(istelnet[*sock]) - ssend(*sock, "%i %i %s\n", n, domains[j].blockedcount, domains[j].domain); + ssend(*sock, "%i %i %s\n", n, domains[j].blockedcount, getstr(domains[j].domainpos)); else { - if(!pack_str32(*sock, domains[j].domain)) + if(!pack_str32(*sock, getstr(domains[j].domainpos))) return; pack_int32(*sock, domains[j].blockedcount); @@ -302,10 +302,10 @@ void getTopDomains(char *client_message, int *sock) else if(!blocked && showpermitted && (domains[j].count - domains[j].blockedcount) > 0) { if(istelnet[*sock]) - ssend(*sock,"%i %i %s\n",n,(domains[j].count - domains[j].blockedcount),domains[j].domain); + ssend(*sock,"%i %i %s\n",n,(domains[j].count - domains[j].blockedcount),getstr(domains[j].domainpos)); else { - if(!pack_str32(*sock, domains[j].domain)) + if(!pack_str32(*sock, getstr(domains[j].domainpos))) return; pack_int32(*sock, domains[j].count - domains[j].blockedcount); @@ -324,7 +324,7 @@ void getTopDomains(char *client_message, int *sock) void getTopClients(char *client_message, int *sock) { - int i, temparray[counters.clients][2], count=10, num; + int i, temparray[counters->clients][2], count=10, num; // Exit before processing any data if requested via config setting get_privacy_level(NULL); @@ -357,7 +357,7 @@ void getTopClients(char *client_message, int *sock) if(command(client_message, " blocked")) blockedonly = true; - for(i=0; i < counters.clients; i++) + for(i=0; i < counters->clients; i++) { validate_access("clients", i, true, __LINE__, __FUNCTION__, __FILE__); temparray[i][0] = i; @@ -373,9 +373,9 @@ void getTopClients(char *client_message, int *sock) // Sort temporary array if(asc) - qsort(temparray, counters.clients, sizeof(int[2]), cmpasc); + qsort(temparray, counters->clients, sizeof(int[2]), cmpasc); else - qsort(temparray, counters.clients, sizeof(int[2]), cmpdesc); + qsort(temparray, counters->clients, sizeof(int[2]), cmpdesc); // Get clients which the user doesn't want to see char * excludeclients = read_setupVarsconf("API_EXCLUDE_CLIENTS"); @@ -387,11 +387,11 @@ void getTopClients(char *client_message, int *sock) if(!istelnet[*sock]) { // Send the total queries so they can make percentages from this data - pack_int32(*sock, counters.queries); + pack_int32(*sock, counters->queries); } int n = 0; - for(i=0; i < counters.clients; i++) + for(i=0; i < counters->clients; i++) { // Get sorted indices and counter values (may be either total or blocked count) int j = temparray[i][0]; @@ -400,19 +400,15 @@ void getTopClients(char *client_message, int *sock) // Skip this client if there is a filter on it if(excludeclients != NULL && - (insetupVarsArray(clients[j].ip) || insetupVarsArray(clients[j].name))) + (insetupVarsArray(getstr(clients[j].ippos)) || insetupVarsArray(getstr(clients[j].namepos)))) continue; // Hidden client, probably due to privacy level. Skip this in the top lists - if(strcmp(clients[j].ip, HIDDEN_CLIENT) == 0) + if(strcmp(getstr(clients[j].ippos), HIDDEN_CLIENT) == 0) continue; - // Only return name if available - char *name; - if(clients[j].name != NULL) - name = clients[j].name; - else - name = ""; + char *client_ip = getstr(clients[j].ippos); + char *client_name = getstr(clients[j].namepos); // Return this client if either // - "withzero" option is set, and/or @@ -420,10 +416,10 @@ void getTopClients(char *client_message, int *sock) if(includezeroclients || ccount > 0) { if(istelnet[*sock]) - ssend(*sock,"%i %i %s %s\n", n, ccount, clients[j].ip, name); + ssend(*sock,"%i %i %s %s\n", n, ccount, client_ip, client_name); else { - if(!pack_str32(*sock, "") || !pack_str32(*sock, clients[j].ip)) + if(!pack_str32(*sock, "") || !pack_str32(*sock, client_ip)) return; pack_int32(*sock, ccount); @@ -443,12 +439,12 @@ void getTopClients(char *client_message, int *sock) void getForwardDestinations(char *client_message, int *sock) { bool sort = true; - int i, temparray[counters.forwarded][2], totalqueries = 0; + int i, temparray[counters->forwarded][2], totalqueries = 0; if(command(client_message, "unsorted")) sort = false; - for(i=0; i < counters.forwarded; i++) { + for(i=0; i < counters->forwarded; i++) { validate_access("forwarded", i, true, __LINE__, __FUNCTION__, __FILE__); // If we want to print a sorted output, we fill the temporary array with // the values we will use for sorting afterwards @@ -461,13 +457,13 @@ void getForwardDestinations(char *client_message, int *sock) if(sort) { // Sort temporary array in descending order - qsort(temparray, counters.forwarded, sizeof(int[2]), cmpdesc); + qsort(temparray, counters->forwarded, sizeof(int[2]), cmpdesc); } - totalqueries = counters.forwardedqueries + counters.cached + counters.blocked; + totalqueries = counters->forwardedqueries + counters->cached + counters->blocked; // Loop over available forward destinations - for(i=-2; i < min(counters.forwarded, 8); i++) + for(i=-2; i < min(counters->forwarded, 8); i++) { char *ip, *name; float percentage = 0.0f; @@ -480,7 +476,7 @@ void getForwardDestinations(char *client_message, int *sock) if(totalqueries > 0) // Whats the percentage of locked queries on the total amount of queries? - percentage = 1e2f * counters.blocked / totalqueries; + percentage = 1e2f * counters->blocked / totalqueries; } else if(i == -1) { @@ -490,7 +486,7 @@ void getForwardDestinations(char *client_message, int *sock) if(totalqueries > 0) // Whats the percentage of cached queries on the total amount of queries? - percentage = 1e2f * counters.cached / totalqueries; + percentage = 1e2f * counters->cached / totalqueries; } else { @@ -502,13 +498,10 @@ void getForwardDestinations(char *client_message, int *sock) else j = i; validate_access("forwarded", j, true, __LINE__, __FUNCTION__, __FILE__); - ip = forwarded[j].ip; - // Only return name if available - if(forwarded[j].name != NULL) - name = forwarded[j].name; - else - name = ""; + // Get IP and host name of forward destination if available + ip = getstr(forwarded[j].ippos); + name = getstr(forwarded[j].namepos); // Get percentage if(totalqueries > 0) @@ -538,14 +531,14 @@ void getQueryTypes(int *sock) { int i,total = 0; for(i=0; i < TYPE_MAX-1; i++) - total += counters.querytype[i]; + total += counters->querytype[i]; float percentage[TYPE_MAX-1] = { 0.0 }; // Prevent floating point exceptions by checking if the divisor is != 0 if(total > 0) for(i=0; i < TYPE_MAX-1; i++) - percentage[i] = 1e2f*counters.querytype[i]/total; + percentage[i] = 1e2f*counters->querytype[i]/total; if(istelnet[*sock]) { ssend(*sock, "A (IPv4): %.2f\nAAAA (IPv6): %.2f\nANY: %.2f\nSRV: %.2f\nSOA: %.2f\nPTR: %.2f\nTXT: %.2f\n", @@ -628,15 +621,15 @@ void getAllQueries(char *client_message, int *sock) { // Iterate through all known forward destinations int i; - validate_access("forwards", MAX(0,counters.forwarded-1), true, __LINE__, __FUNCTION__, __FILE__); + validate_access("forwards", MAX(0,counters->forwarded-1), true, __LINE__, __FUNCTION__, __FILE__); forwarddestid = -3; - for(i = 0; i < counters.forwarded; i++) + for(i = 0; i < counters->forwarded; i++) { // Try to match the requested string against their IP addresses and // (if available) their host names - if(strcmp(forwarded[i].ip, forwarddest) == 0 || - (forwarded[i].name != NULL && - strcmp(forwarded[i].name, forwarddest) == 0)) + if(strcmp(getstr(forwarded[i].ippos), forwarddest) == 0 || + (forwarded[i].namepos != 0 && + strcmp(getstr(forwarded[i].namepos), forwarddest) == 0)) { forwarddestid = i; break; @@ -661,11 +654,11 @@ void getAllQueries(char *client_message, int *sock) filterdomainname = true; // Iterate through all known domains int i; - validate_access("domains", MAX(0,counters.domains-1), true, __LINE__, __FUNCTION__, __FILE__); - for(i = 0; i < counters.domains; i++) + validate_access("domains", MAX(0,counters->domains-1), true, __LINE__, __FUNCTION__, __FILE__); + for(i = 0; i < counters->domains; i++) { // Try to match the requested string - if(strcmp(domains[i].domain, domainname) == 0) + if(strcmp(getstr(domains[i].domainpos), domainname) == 0) { domainid = i; break; @@ -689,13 +682,13 @@ void getAllQueries(char *client_message, int *sock) filterclientname = true; // Iterate through all known clients int i; - validate_access("clients", MAX(0,counters.clients-1), true, __LINE__, __FUNCTION__, __FILE__); - for(i = 0; i < counters.clients; i++) + validate_access("clients", MAX(0,counters->clients-1), true, __LINE__, __FUNCTION__, __FILE__); + for(i = 0; i < counters->clients; i++) { // Try to match the requested string - if(strcmp(clients[i].ip, clientname) == 0 || - (clients[i].name != NULL && - strcmp(clients[i].name, clientname) == 0)) + if(strcmp(getstr(clients[i].ippos), clientname) == 0 || + (clients[i].namepos != 0 && + strcmp(getstr(clients[i].namepos), clientname) == 0)) { clientid = i; break; @@ -716,7 +709,7 @@ void getAllQueries(char *client_message, int *sock) { // User wants a different number of requests // Don't allow a start index that is smaller than zero - ibeg = counters.queries-num; + ibeg = counters->queries-num; if(ibeg < 0) ibeg = 0; } @@ -739,7 +732,7 @@ void getAllQueries(char *client_message, int *sock) clearSetupVarsArray(); int i; - for(i=ibeg; i < counters.queries; i++) + for(i=ibeg; i < counters->queries; i++) { validate_access("queries", i, true, __LINE__, __FUNCTION__, __FILE__); // Check if this query has been create while in maximum privacy mode @@ -796,10 +789,8 @@ void getAllQueries(char *client_message, int *sock) char *domain = getDomainString(i); // Similarly for the client char *client; - if(clients[queries[i].clientID].name != NULL && - strlen(clients[queries[i].clientID].name) > 0 && - queries[i].privacylevel < PRIVACY_HIDE_DOMAINS_CLIENTS) - client = clients[queries[i].clientID].name; + if(strlen(getstr(clients[queries[i].clientID].namepos)) > 0) + client = getClientNameString(i); else client = getClientIPString(i); @@ -847,13 +838,13 @@ void getRecentBlocked(char *client_message, int *sock) // Test for integer that specifies number of entries to be shown if(sscanf(client_message, "%*[^(](%i)", &num) > 0) { // User wants a different number of requests - if(num >= counters.queries) + if(num >= counters->queries) num = 0; } // Find most recently blocked query int found = 0; - for(i = counters.queries - 1; i > 0 ; i--) + for(i = counters->queries - 1; i > 0 ; i--) { validate_access("queries", i, true, __LINE__, __FUNCTION__, __FILE__); @@ -890,7 +881,7 @@ void getQueryTypesOverTime(int *sock) { int i, sendit = -1; time_t mintime = time(NULL) - config.maxlogage; - for(i = 0; i < counters.overTime; i++) + for(i = 0; i < counters->overTime; i++) { validate_access("overTime", i, true, __LINE__, __FUNCTION__, __FILE__); if((overTime[i].total > 0 || overTime[i].blocked > 0) && overTime[i].timestamp >= mintime) @@ -902,7 +893,7 @@ void getQueryTypesOverTime(int *sock) if(sendit > -1) { - for(i = sendit; i < counters.overTime; i++) + for(i = sendit; i < counters->overTime; i++) { validate_access("overTime", i, true, __LINE__, __FUNCTION__, __FILE__); @@ -1010,7 +1001,7 @@ void getClientsOverTime(int *sock) if(config.privacylevel >= PRIVACY_HIDE_DOMAINS_CLIENTS) return; - for(i = 0; i < counters.overTime; i++) + for(i = 0; i < counters->overTime; i++) { validate_access("overTime", i, true, __LINE__, __FUNCTION__, __FILE__); if((overTime[i].total > 0 || overTime[i].blocked > 0) && @@ -1028,25 +1019,25 @@ void getClientsOverTime(int *sock) // Array of clients to be skipped in the output // if skipclient[i] == true then this client should be hidden from // returned data. We initialize it with false - bool skipclient[counters.clients]; - memset(skipclient, false, counters.clients*sizeof(bool)); + bool skipclient[counters->clients]; + memset(skipclient, false, counters->clients*sizeof(bool)); if(excludeclients != NULL) { getSetupVarsArray(excludeclients); - for(i=0; i < counters.clients; i++) + for(i=0; i < counters->clients; i++) { validate_access("clients", i, true, __LINE__, __FUNCTION__, __FILE__); // Check if this client should be skipped - if(insetupVarsArray(clients[i].ip) || - insetupVarsArray(clients[i].name)) + if(insetupVarsArray(getstr(clients[i].ippos)) || + insetupVarsArray(getstr(clients[i].namepos))) skipclient[i] = true; } } // Main return loop - for(i = sendit; i < counters.overTime; i++) + for(i = sendit; i < counters->overTime; i++) { validate_access("overTime", i, true, __LINE__, __FUNCTION__, __FILE__); @@ -1057,19 +1048,12 @@ void getClientsOverTime(int *sock) // Loop over forward destinations to generate output to be sent to the client int j; - for(j = 0; j < counters.clients; j++) + for(j = 0; j < counters->clients; j++) { - int thisclient = 0; - if(skipclient[j]) continue; - if(j < overTime[i].clientnum) - { - // This client entry does already exist at this timestamp - // -> use counter of requests sent to this destination - thisclient = overTime[i].clientdata[j]; - } + int thisclient = overTimeClientData[j][i]; if(istelnet[*sock]) ssend(*sock, " %i", thisclient); @@ -1101,37 +1085,38 @@ void getClientNames(int *sock) // Array of clients to be skipped in the output // if skipclient[i] == true then this client should be hidden from // returned data. We initialize it with false - bool skipclient[counters.clients]; - memset(skipclient, false, counters.clients*sizeof(bool)); + bool skipclient[counters->clients]; + memset(skipclient, false, counters->clients*sizeof(bool)); if(excludeclients != NULL) { getSetupVarsArray(excludeclients); - for(i=0; i < counters.clients; i++) + for(i=0; i < counters->clients; i++) { validate_access("clients", i, true, __LINE__, __FUNCTION__, __FILE__); // Check if this client should be skipped - if(insetupVarsArray(clients[i].ip) || - insetupVarsArray(clients[i].name)) + if(insetupVarsArray(getstr(clients[i].ippos)) || + insetupVarsArray(getstr(clients[i].namepos))) skipclient[i] = true; } } // Loop over clients to generate output to be sent to the client - for(i = 0; i < counters.clients; i++) + for(i = 0; i < counters->clients; i++) { validate_access("clients", i, true, __LINE__, __FUNCTION__, __FILE__); if(skipclient[i]) continue; - char *client_name = clients[i].name != NULL ? clients[i].name : ""; + char *client_ip = getstr(clients[i].ippos); + char *client_name = getstr(clients[i].namepos); if(istelnet[*sock]) - ssend(*sock, "%s %s\n", client_name, clients[i].ip); + ssend(*sock, "%s %s\n", client_name, client_ip); else { pack_str32(*sock, client_name); - pack_str32(*sock, clients[i].ip); + pack_str32(*sock, client_ip); } } @@ -1147,7 +1132,7 @@ void getUnknownQueries(int *sock) return; int i; - for(i=0; i < counters.queries; i++) + for(i=0; i < counters->queries; i++) { validate_access("queries", i, true, __LINE__, __FUNCTION__, __FILE__); if(queries[i].status != QUERY_UNKNOWN && queries[i].complete) continue; @@ -1166,10 +1151,10 @@ void getUnknownQueries(int *sock) validate_access("clients", queries[i].clientID, true, __LINE__, __FUNCTION__, __FILE__); - char *client = clients[queries[i].clientID].ip; + char *client = getstr(clients[queries[i].clientID].ippos); if(istelnet[*sock]) - ssend(*sock, "%i %i %i %s %s %s %i %s\n", queries[i].timestamp, i, queries[i].id, type, domains[queries[i].domainID].domain, client, queries[i].status, queries[i].complete ? "true" : "false"); + ssend(*sock, "%i %i %i %s %s %s %i %s\n", queries[i].timestamp, i, queries[i].id, type, getstr(domains[queries[i].domainID].domainpos), client, queries[i].status, queries[i].complete ? "true" : "false"); else { pack_int32(*sock, queries[i].timestamp); pack_int32(*sock, queries[i].id); @@ -1179,7 +1164,7 @@ void getUnknownQueries(int *sock) return; // Use str32 for domain and client because we have no idea how long they will be (max is 4294967295 for str32) - if(!pack_str32(*sock, domains[queries[i].domainID].domain) || !pack_str32(*sock, client)) + if(!pack_str32(*sock, getstr(domains[queries[i].domainID].domainpos)) || !pack_str32(*sock, client)) return; pack_uint8(*sock, queries[i].status); @@ -1199,10 +1184,10 @@ void getDomainDetails(char *client_message, int *sock) } int i; - for(i = 0; i < counters.domains; i++) + for(i = 0; i < counters->domains; i++) { validate_access("domains", i, true, __LINE__, __FUNCTION__, __FILE__); - if(strcmp(domains[i].domain, domain) == 0) + if(strcmp(getstr(domains[i].domainpos), domain) == 0) { ssend(*sock,"Domain \"%s\", ID: %i\n", domain, i); ssend(*sock,"Total: %i\n", domains[i].count); diff --git a/database.c b/database.c index 64ea8f491..7dcb6cbaf 100644 --- a/database.c +++ b/database.c @@ -9,6 +9,7 @@ * Please see LICENSE file for your rights under this license. */ #include "FTL.h" +#include "shmem.h" sqlite3 *db; bool database = false; @@ -317,6 +318,33 @@ int number_of_queries_in_DB(void) return result; } +static sqlite3_int64 last_ID_in_DB(void) +{ + sqlite3_stmt* stmt; + + int rc = sqlite3_prepare_v2(db, "SELECT MAX(ID) FROM queries", -1, &stmt, NULL); + if( rc ){ + logg("last_ID_in_DB() - SQL error prepare (%i): %s", rc, sqlite3_errmsg(db)); + dbclose(); + check_database(rc); + return -1; + } + + rc = sqlite3_step(stmt); + if( rc != SQLITE_ROW ){ + logg("last_ID_in_DB() - SQL error step (%i): %s", rc, sqlite3_errmsg(db)); + dbclose(); + check_database(rc); + return -1; + } + + sqlite3_int64 result = sqlite3_column_int64(stmt, 0); + + sqlite3_finalize(stmt); + + return result; +} + int get_number_of_queries_in_DB(void) { int result = -1; @@ -355,6 +383,9 @@ void save_to_DB(void) long int i; sqlite3_stmt* stmt; + // Get last ID stored in the database + sqlite3_int64 lastID = last_ID_in_DB(); + bool ret = dbquery("BEGIN TRANSACTION"); if(!ret) { @@ -375,10 +406,10 @@ void save_to_DB(void) int total = 0, blocked = 0; time_t currenttimestamp = time(NULL); time_t newlasttimestamp = 0; - for(i = 0; i < counters.queries; i++) + for(i = 0; i < counters->queries; i++) { validate_access("queries", i, true, __LINE__, __FUNCTION__, __FILE__); - if(queries[i].db) + if(queries[i].db != 0) { // Skip, already saved in database continue; @@ -422,7 +453,7 @@ void save_to_DB(void) if(queries[i].status == QUERY_FORWARDED && queries[i].forwardID > -1) { validate_access("forwarded", queries[i].forwardID, true, __LINE__, __FUNCTION__, __FILE__); - sqlite3_bind_text(stmt, 6, forwarded[queries[i].forwardID].ip, -1, SQLITE_TRANSIENT); + sqlite3_bind_text(stmt, 6, getstr(forwarded[queries[i].forwardID].ippos), -1, SQLITE_TRANSIENT); } else { @@ -451,8 +482,8 @@ void save_to_DB(void) } saved++; - // Mark this query as saved in the database only if successful - queries[i].db = true; + // Mark this query as saved in the database by setting the corresponding ID + queries[i].db = ++lastID; // Total counter information (delta computation) total++; @@ -492,7 +523,7 @@ void save_to_DB(void) if(debug) { - logg("Notice: Queries stored in DB: %u (took %.1f ms)", saved, timer_elapsed_msec(DATABASE_WRITE_TIMER)); + logg("Notice: Queries stored in DB: %u (took %.1f ms, last SQLite ID %llu)", saved, timer_elapsed_msec(DATABASE_WRITE_TIMER), lastID); if(saved_error > 0) logg(" There are queries that have not been saved"); } @@ -548,15 +579,15 @@ void *DB_thread(void *val) // Update lastDBsave timer lastDBsave = time(NULL) - time(NULL)%config.DBinterval; - // Lock FTL's data structure, since it is - // likely that it will be changed here - enable_thread_lock(); + // Lock FTL's data structures, since it is + // likely that they will be changed here + lock_shm(); // Save data to database save_to_DB(); - // Release thread lock - disable_thread_lock(); + // Release data lock + unlock_shm(); // Check if GC should be done on the database if(DBdeleteoldqueries) @@ -613,12 +644,7 @@ void read_data_from_DB(void) // Loop through returned database rows while((rc = sqlite3_step(stmt)) == SQLITE_ROW) { - // Ensure we have enough space in the queries struct - memory_check(QUERIES); - - // Set ID for this query - int queryID = counters.queries; - + sqlite3_int64 dbid = sqlite3_column_int64(stmt, 0); int queryTimeStamp = sqlite3_column_int(stmt, 1); // 1483228800 = 01/01/2017 @ 12:00am (UTC) if(queryTimeStamp < 1483228800) @@ -673,10 +699,6 @@ void read_data_from_DB(void) continue; } - // Obtain IDs only after filtering which queries we want to keep - int domainID = findDomainID(domain); - int clientID = findClientID(client); - const char *forwarddest = (const char *)sqlite3_column_text(stmt, 6); int forwardID = 0; // Determine forwardID only when status == 2 (forwarded) as the @@ -691,31 +713,40 @@ void read_data_from_DB(void) forwardID = findForwardID(forwarddest, true); } + // Obtain IDs only after filtering which queries we want to keep int overTimeTimeStamp = queryTimeStamp - (queryTimeStamp % 600) + 300; int timeidx = findOverTimeID(overTimeTimeStamp); - validate_access("overTime", timeidx, true, __LINE__, __FUNCTION__, __FILE__); + int domainID = findDomainID(domain); + int clientID = findClientID(client); + + // Ensure we have enough space in the queries struct + memory_check(QUERIES); + + // Set index for this query + int queryIndex = counters->queries; // Store this query in memory - validate_access("queries", queryID, false, __LINE__, __FUNCTION__, __FILE__); - queries[queryID].magic = MAGICBYTE; - queries[queryID].timestamp = queryTimeStamp; - queries[queryID].type = type; - queries[queryID].status = status; - queries[queryID].domainID = domainID; - queries[queryID].clientID = clientID; - queries[queryID].forwardID = forwardID; - queries[queryID].timeidx = timeidx; - queries[queryID].db = true; // Mark this as already present in the database - queries[queryID].id = 0; // This is dnsmasq's internal ID. We don't store it in the database - queries[queryID].complete = true; // Mark as all information is avaiable - queries[queryID].response = 0; - queries[queryID].AD = false; + validate_access("overTime", timeidx, true, __LINE__, __FUNCTION__, __FILE__); + validate_access("queries", queryIndex, false, __LINE__, __FUNCTION__, __FILE__); + queries[queryIndex].magic = MAGICBYTE; + queries[queryIndex].timestamp = queryTimeStamp; + queries[queryIndex].type = type; + queries[queryIndex].status = status; + queries[queryIndex].domainID = domainID; + queries[queryIndex].clientID = clientID; + queries[queryIndex].forwardID = forwardID; + queries[queryIndex].timeidx = timeidx; + queries[queryIndex].db = dbid; + queries[queryIndex].id = 0; + queries[queryIndex].complete = true; // Mark as all information is avaiable + queries[queryIndex].response = 0; + queries[queryIndex].AD = false; lastDBimportedtimestamp = queryTimeStamp; // Handle type counters if(type >= TYPE_A && type < TYPE_MAX) { - counters.querytype[type-1]++; + counters->querytype[type-1]++; overTime[timeidx].querytypedata[type-1]++; } @@ -723,36 +754,35 @@ void read_data_from_DB(void) overTime[timeidx].total++; // Update overTime data structure with the new client - validate_access_oTcl(timeidx, clientID, __LINE__, __FUNCTION__, __FILE__); - overTime[timeidx].clientdata[clientID]++; + overTimeClientData[clientID][timeidx]++; // Increase DNS queries counter - counters.queries++; + counters->queries++; // Increment status counters switch(status) { case QUERY_UNKNOWN: // Unknown - counters.unknown++; + counters->unknown++; break; case QUERY_GRAVITY: // Blocked by gravity.list case QUERY_WILDCARD: // Blocked by regex filter case QUERY_BLACKLIST: // Blocked by black.list case QUERY_EXTERNAL_BLOCKED: // Blocked by external provider - counters.blocked++; + counters->blocked++; overTime[timeidx].blocked++; domains[domainID].blockedcount++; clients[clientID].blockedcount++; break; case QUERY_FORWARDED: // Forwarded - counters.forwardedqueries++; + counters->forwardedqueries++; // Update overTime data structure break; case QUERY_CACHE: // Cached or local config - counters.cached++; + counters->cached++; // Update overTime data structure overTime[timeidx].cached++; break; @@ -764,7 +794,7 @@ void read_data_from_DB(void) break; } } - logg("Imported %i queries from the long-term database", counters.queries); + logg("Imported %i queries from the long-term database", counters->queries); if( rc != SQLITE_DONE ){ logg("read_data_from_DB() - SQL error step (%i): %s", rc, sqlite3_errmsg(db)); diff --git a/datastructure.c b/datastructure.c index cdac2eedc..18cb871ac 100644 --- a/datastructure.c +++ b/datastructure.c @@ -32,19 +32,19 @@ int findOverTimeID(int overTimetimestamp) int timeidx = -1, i; // Check struct size memory_check(OVERTIME); - if(counters.overTime > 0) - validate_access("overTime", counters.overTime-1, true, __LINE__, __FUNCTION__, __FILE__); - for(i=0; i < counters.overTime; i++) + if(counters->overTime > 0) + validate_access("overTime", counters->overTime-1, true, __LINE__, __FUNCTION__, __FILE__); + for(i=0; i < counters->overTime; i++) { if(overTime[i].timestamp == overTimetimestamp) return i; } // We loop over this to fill potential data holes with zeros int nexttimestamp = 0; - if(counters.overTime != 0) + if(counters->overTime != 0) { - validate_access("overTime", counters.overTime-1, false, __LINE__, __FUNCTION__, __FILE__); - nexttimestamp = overTime[counters.overTime-1].timestamp + 600; + validate_access("overTime", counters->overTime-1, false, __LINE__, __FUNCTION__, __FILE__); + nexttimestamp = overTime[counters->overTime-1].timestamp + 600; } else { @@ -57,7 +57,7 @@ int findOverTimeID(int overTimetimestamp) { // Check struct size memory_check(OVERTIME); - timeidx = counters.overTime; + timeidx = counters->overTime; validate_access("overTime", timeidx, false, __LINE__, __FUNCTION__, __FILE__); // Set magic byte overTime[timeidx].magic = MAGICBYTE; @@ -66,15 +66,16 @@ int findOverTimeID(int overTimetimestamp) overTime[timeidx].blocked = 0; overTime[timeidx].cached = 0; // overTime[timeidx].querytypedata is static - overTime[timeidx].clientnum = 0; - overTime[timeidx].clientdata = NULL; - counters.overTime++; + counters->overTime++; - // Update time stamp for next loop interation - if(counters.overTime != 0) + // Create new overTime slot in client shared memory + addOverTimeClientSlot(); + + // Update time stamp for next loop interaction + if(counters->overTime != 0) { - validate_access("overTime", counters.overTime-1, false, __LINE__, __FUNCTION__, __FILE__); - nexttimestamp = overTime[counters.overTime-1].timestamp + 600; + validate_access("overTime", counters->overTime-1, false, __LINE__, __FUNCTION__, __FILE__); + nexttimestamp = overTime[counters->overTime-1].timestamp + 600; } } @@ -89,12 +90,12 @@ int findOverTimeID(int overTimetimestamp) int findForwardID(const char * forward, bool count) { int i, forwardID = -1; - if(counters.forwarded > 0) - validate_access("forwarded", counters.forwarded-1, true, __LINE__, __FUNCTION__, __FILE__); + if(counters->forwarded > 0) + validate_access("forwarded", counters->forwarded-1, true, __LINE__, __FUNCTION__, __FILE__); // Go through already knows forward servers and see if we used one of those - for(i=0; i < counters.forwarded; i++) + for(i=0; i < counters->forwarded; i++) { - if(strcmp(forwarded[i].ip, forward) == 0) + if(strcmp(getstr(forwarded[i].ippos), forward) == 0) { forwardID = i; if(count) forwarded[forwardID].count++; @@ -103,8 +104,8 @@ int findForwardID(const char * forward, bool count) } // This forward server is not known // Store ID - forwardID = counters.forwarded; - logg("New forward server: %s (%i/%u)", forward, forwardID, counters.forwarded_MAX); + forwardID = counters->forwarded; + logg("New forward server: %s (%i/%u)", forward, forwardID, counters->forwarded_MAX); // Check struct size memory_check(FORWARDED); @@ -118,16 +119,16 @@ int findForwardID(const char * forward, bool count) else forwarded[forwardID].count = 0; // Save forward destination IP address - forwarded[forwardID].ip = strdup(forward); + forwarded[forwardID].ippos = addstr(forward); forwarded[forwardID].failed = 0; // Initialize forward hostname // Due to the nature of us being the resolver, // the actual resolving of the host name has // to be done separately to be non-blocking forwarded[forwardID].new = true; - forwarded[forwardID].name = NULL; + forwarded[forwardID].namepos = 0; // 0 -> string with length zero // Increase counter by one - counters.forwarded++; + counters->forwarded++; return forwardID; } @@ -135,16 +136,16 @@ int findForwardID(const char * forward, bool count) int findDomainID(const char *domain) { int i; - if(counters.domains > 0) - validate_access("domains", counters.domains-1, true, __LINE__, __FUNCTION__, __FILE__); - for(i=0; i < counters.domains; i++) + if(counters->domains > 0) + validate_access("domains", counters->domains-1, true, __LINE__, __FUNCTION__, __FILE__); + for(i=0; i < counters->domains; i++) { // Quick test: Does the domain start with the same character? - if(domains[i].domain[0] != domain[0]) + if(getstr(domains[i].domainpos)[0] != domain[0]) continue; // If so, compare the full domain using strcmp - if(strcmp(domains[i].domain, domain) == 0) + if(strcmp(getstr(domains[i].domainpos), domain) == 0) { domains[i].count++; return i; @@ -153,7 +154,7 @@ int findDomainID(const char *domain) // If we did not return until here, then this domain is not known // Store ID - int domainID = counters.domains; + int domainID = counters->domains; // Check struct size memory_check(DOMAINS); @@ -166,11 +167,11 @@ int findDomainID(const char *domain) // Set blocked counter to zero domains[domainID].blockedcount = 0; // Store domain name - no need to check for NULL here as it doesn't harm - domains[domainID].domain = strdup(domain); + domains[domainID].domainpos = addstr(domain); // RegEx needs to be evaluated for this new domain domains[domainID].regexmatch = REGEX_UNKNOWN; // Increase counter by one - counters.domains++; + counters->domains++; return domainID; } @@ -179,16 +180,16 @@ int findClientID(const char *client) { int i; // Compare content of client against known client IP addresses - if(counters.clients > 0) - validate_access("clients", counters.clients-1, true, __LINE__, __FUNCTION__, __FILE__); - for(i=0; i < counters.clients; i++) + if(counters->clients > 0) + validate_access("clients", counters->clients-1, true, __LINE__, __FUNCTION__, __FILE__); + for(i=0; i < counters->clients; i++) { // Quick test: Does the clients IP start with the same character? - if(clients[i].ip[0] != client[0]) + if(getstr(clients[i].ippos)[0] != client[0]) continue; // If so, compare the full IP using strcmp - if(strcmp(clients[i].ip, client) == 0) + if(strcmp(getstr(clients[i].ippos), client) == 0) { clients[i].count++; return i; @@ -197,7 +198,7 @@ int findClientID(const char *client) // If we did not return until here, then this client is definitely new // Store ID - int clientID = counters.clients; + int clientID = counters->clients; // Check struct size memory_check(CLIENTS); @@ -210,15 +211,18 @@ int findClientID(const char *client) // Initialize blocked count to zero clients[clientID].blockedcount = 0; // Store client IP - no need to check for NULL here as it doesn't harm - clients[clientID].ip = strdup(client); + clients[clientID].ippos = addstr(client); // Initialize client hostname // Due to the nature of us being the resolver, // the actual resolving of the host name has // to be done separately to be non-blocking clients[clientID].new = true; - clients[clientID].name = NULL; + clients[clientID].namepos = 0; // Increase counter by one - counters.clients++; + counters->clients++; + + // Create new overTime client data + newOverTimeClient(); return clientID; } @@ -242,7 +246,7 @@ char *getDomainString(int queryID) if(queries[queryID].privacylevel < PRIVACY_HIDE_DOMAINS) { validate_access("domains", queries[queryID].domainID, true, __LINE__, __FUNCTION__, __FILE__); - return domains[queries[queryID].domainID].domain; + return getstr(domains[queries[queryID].domainID].domainpos); } else return HIDDEN_DOMAIN; @@ -255,7 +259,20 @@ char *getClientIPString(int queryID) if(queries[queryID].privacylevel < PRIVACY_HIDE_DOMAINS_CLIENTS) { validate_access("clients", queries[queryID].clientID, true, __LINE__, __FUNCTION__, __FILE__); - return clients[queries[queryID].clientID].ip; + return getstr(clients[queries[queryID].clientID].ippos); + } + else + return HIDDEN_CLIENT; +} + +// Privacy-level sensitive subroutine that returns the client host name +// only when appropriate for the requested query +char *getClientNameString(int queryID) +{ + if(queries[queryID].privacylevel < PRIVACY_HIDE_DOMAINS_CLIENTS) + { + validate_access("clients", queries[queryID].clientID, true, __LINE__, __FUNCTION__, __FILE__); + return getstr(clients[queries[queryID].clientID].namepos); } else return HIDDEN_CLIENT; diff --git a/dnsmasq_interface.c b/dnsmasq_interface.c index 5d4c9501c..e5f539386 100644 --- a/dnsmasq_interface.c +++ b/dnsmasq_interface.c @@ -12,6 +12,7 @@ #undef __USE_XOPEN #include "FTL.h" #include "dnsmasq_interface.h" +#include "shmem.h" void print_flags(unsigned int flags); void save_reply_type(unsigned int flags, int queryID, struct timeval response); @@ -31,7 +32,7 @@ void FTL_new_query(unsigned int flags, char *name, struct all_addr *addr, char * return; // Create new query in data structure - enable_thread_lock(); + lock_shm(); // Get timestamp int querytimestamp, overTimetimestamp; @@ -61,7 +62,7 @@ void FTL_new_query(unsigned int flags, char *name, struct all_addr *addr, char * { // Return early to avoid accessing querytypedata out of bounds if(debug) logg("Notice: Skipping unknown query type: %s (%i)", types, id); - disable_thread_lock(); + unlock_shm(); return; } @@ -69,13 +70,13 @@ void FTL_new_query(unsigned int flags, char *name, struct all_addr *addr, char * if(!config.analyze_AAAA && querytype == TYPE_AAAA) { if(debug) logg("Not analyzing AAAA query"); - disable_thread_lock(); + unlock_shm(); return; } // Ensure we have enough space in the queries struct memory_check(QUERIES); - int queryID = counters.queries; + int queryID = counters->queries; // Convert domain to lower case char *domain = strdup(name); @@ -86,7 +87,7 @@ void FTL_new_query(unsigned int flags, char *name, struct all_addr *addr, char * { // free memory already allocated here free(domain); - disable_thread_lock(); + unlock_shm(); return; } @@ -105,7 +106,7 @@ void FTL_new_query(unsigned int flags, char *name, struct all_addr *addr, char * { free(domain); free(client); - disable_thread_lock(); + unlock_shm(); return; } @@ -117,7 +118,7 @@ void FTL_new_query(unsigned int flags, char *name, struct all_addr *addr, char * int timeidx = findOverTimeID(overTimetimestamp); validate_access("overTime", timeidx, true, __LINE__, __FUNCTION__, __FILE__); overTime[timeidx].querytypedata[querytype-1]++; - counters.querytype[querytype-1]++; + counters->querytype[querytype-1]++; // Skip rest of the analysis if this query is not of type A or AAAA // but user wants to see only A and AAAA queries (pre-v4.1 behavior) @@ -128,7 +129,7 @@ void FTL_new_query(unsigned int flags, char *name, struct all_addr *addr, char * free(domain); free(domainbuffer); free(client); - disable_thread_lock(); + unlock_shm(); return; } @@ -147,7 +148,8 @@ void FTL_new_query(unsigned int flags, char *name, struct all_addr *addr, char * queries[queryID].domainID = domainID; queries[queryID].clientID = clientID; queries[queryID].timeidx = timeidx; - queries[queryID].db = false; + // Initialize database rowID with zero, will be set when the query is stored in the long-term DB + queries[queryID].db = 0; queries[queryID].id = id; queries[queryID].complete = false; queries[queryID].response = converttimeval(request); @@ -165,18 +167,17 @@ void FTL_new_query(unsigned int flags, char *name, struct all_addr *addr, char * queries[queryID].privacylevel = config.privacylevel; // Increase DNS queries counter - counters.queries++; + counters->queries++; // Count this query as unknown as long as no reply has // been found and analyzed - counters.unknown++; + counters->unknown++; // Update overTime data validate_access("overTime", timeidx, true, __LINE__, __FUNCTION__, __FILE__); overTime[timeidx].total++; // Update overTime data structure with the new client - validate_access_oTcl(timeidx, clientID, __LINE__, __FUNCTION__, __FILE__); - overTime[timeidx].clientdata[clientID]++; + overTimeClientData[clientID][timeidx]++; // Try blocking regex if configured validate_access("domains", domainID, false, __LINE__, __FUNCTION__, __FILE__); @@ -212,7 +213,7 @@ void FTL_new_query(unsigned int flags, char *name, struct all_addr *addr, char * free(domainbuffer); // Release thread lock - disable_thread_lock(); + unlock_shm(); } static int findQueryID(int id) @@ -226,11 +227,11 @@ static int findQueryID(int id) // MAX(0, a) is used to return 0 in case a is negative (negative array indices are harmful) // Validate access only once for the maximum index (all lower will work) - validate_access("queries", counters.queries-1, false, __LINE__, __FUNCTION__, __FILE__); - int until = MAX(0, counters.queries-MAXITER); + validate_access("queries", counters->queries-1, false, __LINE__, __FUNCTION__, __FILE__); + int until = MAX(0, counters->queries-MAXITER); int i; // Check UUIDs of queries - for(i = counters.queries-1; i >= until; i--) + for(i = counters->queries-1; i >= until; i--) if(queries[i].id == id) return i; @@ -245,7 +246,7 @@ void FTL_forwarded(unsigned int flags, char *name, struct all_addr *addr, int id return; // Save that this query got forwarded to an upstream server - enable_thread_lock(); + lock_shm(); // Get forward destination IP address char dest[ADDRSTRLEN]; @@ -265,7 +266,7 @@ void FTL_forwarded(unsigned int flags, char *name, struct all_addr *addr, int id // This may happen e.g. if the original query was a PTR query or "pi.hole" // as we ignore them altogether free(forward); - disable_thread_lock(); + unlock_shm(); return; } @@ -278,7 +279,7 @@ void FTL_forwarded(unsigned int flags, char *name, struct all_addr *addr, int id if(queries[i].complete && queries[i].status != QUERY_CACHE) { free(forward); - disable_thread_lock(); + unlock_shm(); return; } @@ -308,7 +309,7 @@ void FTL_forwarded(unsigned int flags, char *name, struct all_addr *addr, int id // This code section acknowledges this by removing one entry from // the cached counters as we will re-brand this query as having been // forwarded in the following. - counters.cached--; + counters->cached--; // Also correct overTime data overTime[j].cached--; @@ -323,7 +324,7 @@ void FTL_forwarded(unsigned int flags, char *name, struct all_addr *addr, int id { // Normal forwarded query (status is set below) // Query is no longer unknown - counters.unknown--; + counters->unknown--; // Hereby, this query is now fully determined queries[i].complete = true; } @@ -338,11 +339,11 @@ void FTL_forwarded(unsigned int flags, char *name, struct all_addr *addr, int id overTime[j].forwarded++; // Update counter for forwarded queries - counters.forwardedqueries++; + counters->forwardedqueries++; // Release allocated memory free(forward); - disable_thread_lock(); + unlock_shm(); } void FTL_dnsmasq_reload(void) @@ -353,7 +354,7 @@ void FTL_dnsmasq_reload(void) // Called when dnsmasq re-reads its config and hosts files // Reset number of blocked domains - counters.gravity = 0; + counters->gravity = 0; // Inspect 01-pihole.conf to see if Pi-hole blocking is enabled, // i.e. if /etc/pihole/gravity.list is sourced as addn-hosts file @@ -378,7 +379,7 @@ void FTL_reply(unsigned short flags, char *name, struct all_addr *addr, int id) return; // Interpret hosts files that have been read by dnsmasq - enable_thread_lock(); + lock_shm(); // Determine returned result if available char dest[ADDRSTRLEN]; dest[0] = '\0'; if(addr) @@ -412,14 +413,14 @@ void FTL_reply(unsigned short flags, char *name, struct all_addr *addr, int id) { // This may happen e.g. if the original query was "pi.hole" if(debug) logg("FTL_reply(): Query %i has not been found", id); - disable_thread_lock(); + unlock_shm(); return; } if(queries[i].reply != REPLY_UNKNOWN) { // Nothing to be done here - disable_thread_lock(); + unlock_shm(); return; } @@ -427,7 +428,7 @@ void FTL_reply(unsigned short flags, char *name, struct all_addr *addr, int id) { // Answered from local configuration, might be a wildcard or user-provided // This query is no longer unknown - counters.unknown--; + counters->unknown--; // Get time index int querytimestamp, overTimetimestamp; @@ -440,7 +441,7 @@ void FTL_reply(unsigned short flags, char *name, struct all_addr *addr, int id) strcmp(answer, "::") == 0) { // Answered from user-defined blocking rules (dnsmasq config files) - counters.blocked++; + counters->blocked++; overTime[timeidx].blocked++; validate_access("domains", queries[i].domainID, true, __LINE__, __FUNCTION__, __FILE__); @@ -454,7 +455,7 @@ void FTL_reply(unsigned short flags, char *name, struct all_addr *addr, int id) else { // Answered from a custom (user provided) cache file - counters.cached++; + counters->cached++; overTime[timeidx].cached++; queries[i].status = QUERY_CACHE; @@ -471,7 +472,7 @@ void FTL_reply(unsigned short flags, char *name, struct all_addr *addr, int id) int domainID = queries[i].domainID; validate_access("domains", domainID, true, __LINE__, __FUNCTION__, __FILE__); - if(strcmp(domains[domainID].domain, name) == 0) + if(strcmp(getstr(domains[domainID].domainpos), name) == 0) { // Save reply type and update individual reply counters save_reply_type(flags, i, response); @@ -499,7 +500,7 @@ void FTL_reply(unsigned short flags, char *name, struct all_addr *addr, int id) print_flags(flags); } - disable_thread_lock(); + unlock_shm(); } static void detect_blocked_IP(unsigned short flags, char* answer, int queryID) @@ -560,14 +561,14 @@ static void query_externally_blocked(int i) // Correct counters if necessary ... if(queries[i].status == QUERY_FORWARDED) { - counters.forwardedqueries--; + counters->forwardedqueries--; overTime[queries[i].timeidx].forwarded--; validate_access("forwarded", queries[i].forwardID, true, __LINE__, __FUNCTION__, __FILE__); forwarded[queries[i].forwardID].count--; } // ... but as blocked - counters.blocked++; + counters->blocked++; overTime[queries[i].timeidx].blocked++; validate_access("domains", queries[i].domainID, true, __LINE__, __FUNCTION__, __FILE__); domains[queries[i].domainID].blockedcount++; @@ -584,7 +585,7 @@ void FTL_cache(unsigned int flags, char *name, struct all_addr *addr, char *arg, return; // Save that this query got answered from cache - enable_thread_lock(); + lock_shm(); char dest[ADDRSTRLEN]; dest[0] = '\0'; if(addr) { @@ -600,7 +601,7 @@ void FTL_cache(unsigned int flags, char *name, struct all_addr *addr, char *arg, { // free memory already allocated here free(domain); - disable_thread_lock(); + unlock_shm(); return; } free(domain); @@ -657,14 +658,14 @@ void FTL_cache(unsigned int flags, char *name, struct all_addr *addr, char *arg, { // This may happen e.g. if the original query was a PTR query or "pi.hole" // as we ignore them altogether - disable_thread_lock(); + unlock_shm(); return; } if(!queries[i].complete) { // This query is no longer unknown - counters.unknown--; + counters->unknown--; // Get time index int querytimestamp, overTimetimestamp; @@ -696,13 +697,13 @@ void FTL_cache(unsigned int flags, char *name, struct all_addr *addr, char *arg, case QUERY_GRAVITY: // gravity.list case QUERY_BLACKLIST: // black.list case QUERY_WILDCARD: // regex blocked - counters.blocked++; + counters->blocked++; overTime[timeidx].blocked++; domains[domainID].blockedcount++; clients[clientID].blockedcount++; break; case QUERY_CACHE: // cached from one of the lists - counters.cached++; + counters->cached++; overTime[timeidx].cached++; break; case QUERY_EXTERNAL_BLOCKED: @@ -723,7 +724,7 @@ void FTL_cache(unsigned int flags, char *name, struct all_addr *addr, char *arg, logg("*************************** unknown CACHE reply (2) ***************************"); print_flags(flags); } - disable_thread_lock(); + unlock_shm(); } void FTL_dnssec(int status, int id) @@ -733,13 +734,13 @@ void FTL_dnssec(int status, int id) return; // Process DNSSEC result for a domain - enable_thread_lock(); + lock_shm(); // Search for corresponding query identified by ID int i = findQueryID(id); if(i < 0) { // This may happen e.g. if the original query was an unhandled query type - disable_thread_lock(); + unlock_shm(); return; } @@ -748,7 +749,7 @@ void FTL_dnssec(int status, int id) { int domainID = queries[i].domainID; validate_access("domains", domainID, true, __LINE__, __FUNCTION__, __FILE__); - logg("**** got DNSSEC details for %s: %i (ID %i)", domains[domainID].domain, status, id); + logg("**** got DNSSEC details for %s: %i (ID %i)", getstr(domains[domainID].domainpos), status, id); } // Iterate through possible values @@ -759,7 +760,7 @@ void FTL_dnssec(int status, int id) else queries[i].dnssec = DNSSEC_BOGUS; - disable_thread_lock(); + unlock_shm(); } void FTL_header_ADbit(unsigned char header4, int id) @@ -768,12 +769,12 @@ void FTL_header_ADbit(unsigned char header4, int id) if(config.privacylevel >= PRIVACY_NOSTATS) return; - enable_thread_lock(); + lock_shm(); // Check if AD bit is set in DNS header if(!(header4 & 0x20)) { // AD bit not set - disable_thread_lock(); + unlock_shm(); return; } @@ -782,14 +783,14 @@ void FTL_header_ADbit(unsigned char header4, int id) if(i < 0) { // This may happen e.g. if the original query was an unhandled query type - disable_thread_lock(); + unlock_shm(); return; } // Store AD bit in query data queries[i].AD = true; - disable_thread_lock(); + unlock_shm(); } void print_flags(unsigned int flags) @@ -815,26 +816,26 @@ void save_reply_type(unsigned int flags, int queryID, struct timeval response) { // NXDOMAIN queries[queryID].reply = REPLY_NXDOMAIN; - counters.reply_NXDOMAIN++; + counters->reply_NXDOMAIN++; } else { // NODATA(-IPv6) queries[queryID].reply = REPLY_NODATA; - counters.reply_NODATA++; + counters->reply_NODATA++; } } else if(flags & F_CNAME) { // queries[queryID].reply = REPLY_CNAME; - counters.reply_CNAME++; + counters->reply_CNAME++; } else if(flags & F_REVERSE) { // reserve lookup queries[queryID].reply = REPLY_DOMAIN; - counters.reply_domain++; + counters->reply_domain++; } else if(flags & F_RRNAME) { @@ -845,7 +846,7 @@ void save_reply_type(unsigned int flags, int queryID, struct timeval response) { // Valid IP queries[queryID].reply = REPLY_IP; - counters.reply_IP++; + counters->reply_IP++; } // Save response time (relative time) @@ -955,7 +956,7 @@ void FTL_forwarding_failed(struct server *server) return; // Save that this query got forwarded to an upstream server - enable_thread_lock(); + lock_shm(); char dest[ADDRSTRLEN]; if(server->addr.sa.sa_family == AF_INET) inet_ntop(AF_INET, &server->addr.in.sin_addr, dest, ADDRSTRLEN); @@ -972,7 +973,7 @@ void FTL_forwarding_failed(struct server *server) forwarded[forwardID].failed++; free(forward); - disable_thread_lock(); + unlock_shm(); return; } @@ -1169,6 +1170,6 @@ int FTL_listsfile(char* filename, unsigned int index, FILE *f, int cache_size, s } logg("%s: parsed %i domains (took %.1f ms)", filename, added, timer_elapsed_msec(LISTS_TIMER)); - counters.gravity += added; + counters->gravity += added; return name_count; } diff --git a/gc.c b/gc.c index 37f102b2f..bc93aa588 100644 --- a/gc.c +++ b/gc.c @@ -9,6 +9,8 @@ * Please see LICENSE file for your rights under this license. */ #include "FTL.h" +#include "shmem.h" + bool doGC = false; time_t lastGCrun = 0; @@ -30,7 +32,7 @@ void *GC_thread(void *val) // Lock FTL's data structure, since it is likely that it will be changed here // Requests should not be processed/answered when data is about to change - enable_thread_lock(); + lock_shm(); // Get minimum time stamp to keep time_t mintime = time(NULL) - config.maxlogage; @@ -42,7 +44,7 @@ void *GC_thread(void *val) if(debug) logg("GC starting, mintime: %u %s", mintime, ctime(&mintime)); // Process all queries - for(i=0; i < counters.queries; i++) + for(i=0; i < counters->queries; i++) { validate_access("queries", i, true, __LINE__, __FUNCTION__, __FILE__); // Test if this query is too new @@ -51,7 +53,7 @@ void *GC_thread(void *val) // Adjust total counters and total over time data - // We cannot edit counters.queries directly as it is used + // We cannot edit counters->queries directly as it is used // as max ID for the queries[] struct int timeidx = queries[i].timeidx; validate_access("overTime", timeidx, true, __LINE__, __FUNCTION__, __FILE__); @@ -63,8 +65,7 @@ void *GC_thread(void *val) clients[clientID].count--; // Adjust corresponding overTime counters - validate_access_oTcl(timeidx, clientID, __LINE__, __FUNCTION__, __FILE__); - overTime[timeidx].clientdata[clientID]--; + overTimeClientData[clientID][timeidx]--; // Adjust domain counter (no overTime information) int domainID = queries[i].domainID; @@ -76,25 +77,25 @@ void *GC_thread(void *val) { case QUERY_UNKNOWN: // Unknown (?) - counters.unknown--; + counters->unknown--; break; case QUERY_FORWARDED: // Forwarded to an upstream DNS server - counters.forwardedqueries--; + counters->forwardedqueries--; overTime[timeidx].forwarded--; validate_access("forwarded", queries[i].forwardID, true, __LINE__, __FUNCTION__, __FILE__); forwarded[queries[i].forwardID].count--; break; case QUERY_CACHE: // Answered from local cache _or_ local config - counters.cached--; + counters->cached--; overTime[timeidx].cached--; break; case QUERY_GRAVITY: // Blocked by Pi-hole's blocking lists (fall through) case QUERY_BLACKLIST: // Exact blocked (fall through) case QUERY_WILDCARD: // Regex blocked (fall through) case QUERY_EXTERNAL_BLOCKED: // Blocked by upstream provider (fall through) - counters.blocked--; + counters->blocked--; overTime[timeidx].blocked--; domains[domainID].blockedcount--; clients[clientID].blockedcount--; @@ -108,23 +109,23 @@ void *GC_thread(void *val) switch(queries[i].reply) { case REPLY_NODATA: // NODATA(-IPv6) - counters.reply_NODATA--; + counters->reply_NODATA--; break; case REPLY_NXDOMAIN: // NXDOMAIN - counters.reply_NXDOMAIN--; + counters->reply_NXDOMAIN--; break; case REPLY_CNAME: // - counters.reply_CNAME--; + counters->reply_CNAME--; break; case REPLY_IP: // valid IP - counters.reply_IP--; + counters->reply_IP--; break; case REPLY_DOMAIN: // reverse lookup - counters.reply_domain--; + counters->reply_domain--; break; default: // Incomplete query or TXT, do nothing @@ -134,7 +135,7 @@ void *GC_thread(void *val) // Update type counters if(queries[i].type >= TYPE_A && queries[i].type < TYPE_MAX) { - counters.querytype[queries[i].type-1]--; + counters->querytype[queries[i].type-1]--; validate_access("overTime", queries[i].timeidx, true, __LINE__, __FUNCTION__, __FILE__); overTime[queries[i].timeidx].querytypedata[queries[i].type-1]--; } @@ -149,18 +150,18 @@ void *GC_thread(void *val) // Example: (I = now invalid, X = still valid queries, F = free space) // Before: IIIIIIXXXXFF // After: XXXXFFFFFFFF - memmove(&queries[0], &queries[removed], (counters.queries - removed)*sizeof(*queries)); + memmove(&queries[0], &queries[removed], (counters->queries - removed)*sizeof(*queries)); // Update queries counter - counters.queries -= removed; + counters->queries -= removed; // Zero out remaining memory (marked as "F" in the above example) - memset(&queries[counters.queries], 0, (counters.queries_MAX - counters.queries)*sizeof(*queries)); + memset(&queries[counters->queries], 0, (counters->queries_MAX - counters->queries)*sizeof(*queries)); if(debug) logg("Notice: GC removed %i queries (took %.2f ms)", removed, timer_elapsed_msec(GC_TIMER)); // Release thread lock - disable_thread_lock(); + unlock_shm(); // After storing data in the database for the next time, // we should scan for old entries, which will then be deleted diff --git a/log.c b/log.c index 2a0eaa6b5..41c0b8105 100644 --- a/log.c +++ b/log.c @@ -127,14 +127,14 @@ void logg_struct_resize(const char* str, int to, int step) void log_counter_info(void) { - logg(" -> Total DNS queries: %i", counters.queries); - logg(" -> Cached DNS queries: %i", counters.cached); - logg(" -> Forwarded DNS queries: %i", counters.forwardedqueries); - logg(" -> Exactly blocked DNS queries: %i", counters.blocked); - logg(" -> Unknown DNS queries: %i", counters.unknown); - logg(" -> Unique domains: %i", counters.domains); - logg(" -> Unique clients: %i", counters.clients); - logg(" -> Known forward destinations: %i", counters.forwarded); + logg(" -> Total DNS queries: %i", counters->queries); + logg(" -> Cached DNS queries: %i", counters->cached); + logg(" -> Forwarded DNS queries: %i", counters->forwardedqueries); + logg(" -> Exactly blocked DNS queries: %i", counters->blocked); + logg(" -> Unknown DNS queries: %i", counters->unknown); + logg(" -> Unique domains: %i", counters->domains); + logg(" -> Unique clients: %i", counters->clients); + logg(" -> Known forward destinations: %i", counters->forwarded); } void log_FTL_version(void) diff --git a/main.c b/main.c index 963e3cfa6..a785a650f 100644 --- a/main.c +++ b/main.c @@ -35,7 +35,13 @@ int main (int argc, char* argv[]) timer_start(EXIT_TIMER); logg("########## FTL started! ##########"); log_FTL_version(); - init_thread_lock(); + + // Initialize shared memory + if(!init_shmem()) + { + logg("Initialization of shared memory failed."); + return EXIT_FAILURE; + } // pihole-FTL should really be run as user "pihole" to not mess up with file permissions // print warning otherwise @@ -87,6 +93,9 @@ int main (int argc, char* argv[]) // Invalidate blocking regex if compiled free_regex(); + // Remove shared memory objects + destroy_shmem(); + //Remove PID file removepid(); logg("########## FTL terminated after %.1f ms! ##########", timer_elapsed_msec(EXIT_TIMER)); diff --git a/memory.c b/memory.c index b66ce2286..a959a022d 100644 --- a/memory.c +++ b/memory.c @@ -9,6 +9,7 @@ * Please see LICENSE file for your rights under this license. */ #include "FTL.h" +#include "shmem.h" FTLFileNamesStruct FTLfiles = { // Default path for config file (regular installations) @@ -32,27 +33,26 @@ logFileNamesStruct files = { }; // Fixed size structs -countersStruct counters = { 0 }; +countersStruct *counters = NULL; ConfigStruct config; // Variable size array structs -queriesDataStruct *queries; -forwardedDataStruct *forwarded; -clientsDataStruct *clients; -domainsDataStruct *domains; -overTimeDataStruct *overTime; +queriesDataStruct *queries = NULL; +forwardedDataStruct *forwarded = NULL; +clientsDataStruct *clients = NULL; +domainsDataStruct *domains = NULL; +overTimeDataStruct *overTime = NULL; +int **overTimeClientData = NULL; void memory_check(int which) { switch(which) { case QUERIES: - if(counters.queries >= counters.queries_MAX) + if(counters->queries >= counters->queries_MAX-1) { - // Have to reallocate memory - counters.queries_MAX += QUERIESALLOCSTEP; - logg_struct_resize("queries",counters.queries_MAX,QUERIESALLOCSTEP); - queries = realloc(queries, counters.queries_MAX*sizeof(queriesDataStruct)); + // Have to reallocate shared memory + queries = enlarge_shmem_struct(QUERIES); if(queries == NULL) { logg("FATAL: Memory allocation failed! Exiting"); @@ -61,12 +61,10 @@ void memory_check(int which) } break; case FORWARDED: - if(counters.forwarded >= counters.forwarded_MAX) + if(counters->forwarded >= counters->forwarded_MAX-1) { - // Have to reallocate memory - counters.forwarded_MAX += FORWARDEDALLOCSTEP; - logg_struct_resize("forwarded",counters.forwarded_MAX,FORWARDEDALLOCSTEP); - forwarded = realloc(forwarded, counters.forwarded_MAX*sizeof(forwardedDataStruct)); + // Have to reallocate shared memory + forwarded = enlarge_shmem_struct(FORWARDED); if(forwarded == NULL) { logg("FATAL: Memory allocation failed! Exiting"); @@ -75,12 +73,10 @@ void memory_check(int which) } break; case CLIENTS: - if(counters.clients >= counters.clients_MAX) + if(counters->clients >= counters->clients_MAX-1) { - // Have to reallocate memory - counters.clients_MAX += CLIENTSALLOCSTEP; - logg_struct_resize("clients",counters.clients_MAX,CLIENTSALLOCSTEP); - clients = realloc(clients, counters.clients_MAX*sizeof(clientsDataStruct)); + // Have to reallocate shared memory + clients = enlarge_shmem_struct(CLIENTS); if(clients == NULL) { logg("FATAL: Memory allocation failed! Exiting"); @@ -89,12 +85,10 @@ void memory_check(int which) } break; case DOMAINS: - if(counters.domains >= counters.domains_MAX) + if(counters->domains >= counters->domains_MAX-1) { - // Have to reallocate memory - counters.domains_MAX += DOMAINSALLOCSTEP; - logg_struct_resize("domains",counters.domains_MAX,DOMAINSALLOCSTEP); - domains = realloc(domains, counters.domains_MAX*sizeof(domainsDataStruct)); + // Have to reallocate shared memory + domains = enlarge_shmem_struct(DOMAINS); if(domains == NULL) { logg("FATAL: Memory allocation failed! Exiting"); @@ -103,12 +97,10 @@ void memory_check(int which) } break; case OVERTIME: - if(counters.overTime >= counters.overTime_MAX) + if(counters->overTime >= counters->overTime_MAX-1) { - // Have to reallocate memory - counters.overTime_MAX += OVERTIMEALLOCSTEP; - logg_struct_resize("overTime",counters.overTime_MAX,OVERTIMEALLOCSTEP); - overTime = realloc(overTime, counters.overTime_MAX*sizeof(overTimeDataStruct)); + // Have to reallocate shared memory + overTime = enlarge_shmem_struct(OVERTIME); if(overTime == NULL) { logg("FATAL: Memory allocation failed! Exiting"); @@ -127,11 +119,11 @@ void memory_check(int which) void validate_access(const char * name, int pos, bool testmagic, int line, const char * function, const char * file) { int limit = 0; - if(name[0] == 'c') limit = counters.clients_MAX; - else if(name[0] == 'd') limit = counters.domains_MAX; - else if(name[0] == 'q') limit = counters.queries_MAX; - else if(name[0] == 'o') limit = counters.overTime_MAX; - else if(name[0] == 'f') limit = counters.forwarded_MAX; + if(name[0] == 'c') limit = counters->clients_MAX; + else if(name[0] == 'd') limit = counters->domains_MAX; + else if(name[0] == 'q') limit = counters->queries_MAX; + else if(name[0] == 'o') limit = counters->overTime_MAX; + else if(name[0] == 'f') limit = counters->forwarded_MAX; else { logg("Validator error (range)"); killed = 1; } if(pos >= limit || pos < 0) @@ -157,38 +149,6 @@ void validate_access(const char * name, int pos, bool testmagic, int line, const } } -void validate_access_oTcl(int timeidx, int clientID, int line, const char * function, const char * file) -{ - if(clientID < 0) - { - logg("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); - logg("FATAL ERROR: Trying to access overTime.clientdata[%i]", clientID); - logg(" found in %s() (%s:%i)", function, file, line); - } - // Determine if there is enough space for saving the current - // clientID in the overTime data structure, allocate space otherwise - if(overTime[timeidx].clientnum <= clientID) - { - // Reallocate more space for clientdata - overTime[timeidx].clientdata = realloc(overTime[timeidx].clientdata, (clientID+1)*sizeof(*overTime[timeidx].clientdata)); - // Initialize new data fields with zeroes - int i; - for(i = overTime[timeidx].clientnum; i <= clientID; i++) - { - overTime[timeidx].clientdata[i] = 0; - } - // Update counter - overTime[timeidx].clientnum = clientID + 1; - } - int limit = overTime[timeidx].clientnum; - if(clientID >= limit) - { - logg("!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!!"); - logg("FATAL ERROR: Trying to access overTime.clientdata[%i], but maximum is %i", clientID, limit); - logg(" found in %s() (%s:%i)", function, file, line); - } -} - // The special memory handling routines have to be the last ones in this source file // as we restore the original definition of the strdup, free, calloc, and realloc // functions in here, i.e. if anything extra would come below these lines, it would diff --git a/regex.c b/regex.c index 2b6224ebb..5b62385ee 100644 --- a/regex.c +++ b/regex.c @@ -153,9 +153,9 @@ void free_regex(void) // Must reevaluate regex filters after having reread the regex filter // We reset all regex status to unknown to have them being reevaluated - if(counters.domains > 0) - validate_access("domains", counters.domains-1, false, __LINE__, __FUNCTION__, __FILE__); - for(int i=0; i < counters.domains; i++) + if(counters->domains > 0) + validate_access("domains", counters->domains-1, false, __LINE__, __FUNCTION__, __FILE__); + for(int i=0; i < counters->domains; i++) { domains[i].regexmatch = REGEX_UNKNOWN; } diff --git a/request.c b/request.c index d402559ae..82a791a21 100644 --- a/request.c +++ b/request.c @@ -10,6 +10,7 @@ #include "FTL.h" #include "api.h" +#include "shmem.h" bool command(char *client_message, const char* cmd) { return strstr(client_message, cmd) != NULL; @@ -118,7 +119,7 @@ void process_request(char *client_message, int *sock) logg("Received API request to re-resolve host names"); // Need to release the thread lock already here to allow // the resolver to process the incoming PTR requests - disable_thread_lock(); + unlock_shm(); // onlynew=false -> reresolve all host names resolveClients(false); resolveForwardDestinations(false); diff --git a/resolve.c b/resolve.c index 84a87155f..df86c690a 100644 --- a/resolve.c +++ b/resolve.c @@ -9,6 +9,7 @@ * Please see LICENSE file for your rights under this license. */ #include "FTL.h" +#include "shmem.h" // Resolve new client and upstream server host names // once every minute @@ -73,56 +74,46 @@ char *resolveHostname(const char *addr) // Resolve client host names void resolveClients(bool onlynew) { - int i; - for(i = 0; i < counters.clients; i++) + int clientID; + for(clientID = 0; clientID < counters->clients; clientID++) { // Memory validation - validate_access("clients", i, true, __LINE__, __FUNCTION__, __FILE__); + validate_access("clients", clientID, true, __LINE__, __FUNCTION__, __FILE__); // If onlynew flag is set, we will only resolve new clients // If not, we will try to re-resolve all known clients - if(onlynew && !clients[i].new) + if(onlynew && !clients[clientID].new) continue; - char *hostname = resolveHostname(clients[i].ip); + lock_shm(); - enable_thread_lock(); + clients[clientID].namepos = addstr(resolveHostname(getstr(clients[clientID].ippos))); + clients[clientID].new = false; - if(clients[i].name != NULL) - free(clients[i].name); - - clients[i].name = hostname; - clients[i].new = false; - - disable_thread_lock(); + unlock_shm(); } } // Resolve upstream destination host names void resolveForwardDestinations(bool onlynew) { - int i; - for(i = 0; i < counters.forwarded; i++) + int forwardID; + for(forwardID = 0; forwardID < counters->forwarded; forwardID++) { // Memory validation - validate_access("forwarded", i, true, __LINE__, __FUNCTION__, __FILE__); + validate_access("forwarded", forwardID, true, __LINE__, __FUNCTION__, __FILE__); // If onlynew flag is set, we will only resolve new upstream destinations // If not, we will try to re-resolve all known upstream destinations - if(onlynew && !forwarded[i].new) + if(onlynew && !forwarded[forwardID].new) continue; - char *hostname = resolveHostname(forwarded[i].ip); - - enable_thread_lock(); - - if(forwarded[i].name != NULL) - free(forwarded[i].name); + lock_shm(); - forwarded[i].name = hostname; - forwarded[i].new = false; + forwarded[forwardID].namepos = addstr(resolveHostname(getstr(forwarded[forwardID].ippos))); + forwarded[forwardID].new = false; - disable_thread_lock(); + unlock_shm(); } } diff --git a/routines.h b/routines.h index 2d27adb46..b4ad082b4 100644 --- a/routines.h +++ b/routines.h @@ -34,6 +34,7 @@ bool isValidIPv4(const char *addr); bool isValidIPv6(const char *addr); char *getDomainString(int queryID); char *getClientIPString(int queryID); +char *getClientNameString(int queryID); void close_telnet_socket(void); void close_unix_socket(void); @@ -67,11 +68,6 @@ void parse_args(int argc, char* argv[]); char* find_equals(const char* s); -// threads.c -void enable_thread_lock(void); -void disable_thread_lock(void); -void init_thread_lock(void); - // config.c void getLogFilePath(void); void read_FTLconf(void); @@ -95,7 +91,6 @@ void *FTLcalloc(size_t nmemb, size_t size, const char *file, const char *functio void *FTLrealloc(void *ptr_in, size_t size, const char *file, const char *function, int line); void FTLfree(void *ptr, const char* file, const char *function, int line); void validate_access(const char * name, int pos, bool testmagic, int line, const char * function, const char * file); -void validate_access_oTcl(int timeidx, int clientID, int line, const char * function, const char * file); int main_dnsmasq(int argc, char **argv); @@ -112,3 +107,22 @@ bool match_regex(char *input); void free_regex(void); void read_regex_from_file(void); bool in_whitelist(char *domain); + +// shmem.c +bool init_shmem(void); +void destroy_shmem(void); +unsigned long long addstr(const char *str); +char *getstr(unsigned long long pos); +void *enlarge_shmem_struct(char type); + +/** + * Create a new overTime client shared memory block. + * This also updates `overTimeClientData`. + */ +void newOverTimeClient(); + +/** + * Add a new overTime slot to each overTime client shared memory block. + * This also updates `overTimeClientData`. + */ +void addOverTimeClientSlot(); diff --git a/shmem.c b/shmem.c new file mode 100644 index 000000000..33572af75 --- /dev/null +++ b/shmem.c @@ -0,0 +1,446 @@ +/* Pi-hole: A black hole for Internet advertisements +* (c) 2018 Pi-hole, LLC (https://pi-hole.net) +* Network-wide ad blocking via your own hardware. +* +* FTL Engine +* Shared memory subroutines +* +* This file is copyright under the latest version of the EUPL. +* Please see LICENSE file for your rights under this license. */ + +#include "FTL.h" +#include "shmem.h" + +/// The name of the shared memory. Use this when connecting to the shared memory. +#define SHARED_LOCK_NAME "/FTL-lock" +#define SHARED_STRINGS_NAME "/FTL-strings" +#define SHARED_COUNTERS_NAME "/FTL-counters" +#define SHARED_DOMAINS_NAME "/FTL-domains" +#define SHARED_CLIENTS_NAME "/FTL-clients" +#define SHARED_QUERIES_NAME "/FTL-queries" +#define SHARED_FORWARDED_NAME "/FTL-forwarded" +#define SHARED_OVERTIME_NAME "/FTL-overTime" +#define SHARED_OVERTIMECLIENT_PREFIX "/FTL-client-" + +/// The pointer in shared memory to the shared string buffer +static SharedMemory shm_lock = { 0 }; +static SharedMemory shm_strings = { 0 }; +static SharedMemory shm_counters = { 0 }; +static SharedMemory shm_domains = { 0 }; +static SharedMemory shm_clients = { 0 }; +static SharedMemory shm_queries = { 0 }; +static SharedMemory shm_forwarded = { 0 }; +static SharedMemory shm_overTime = { 0 }; + +static SharedMemory *shm_overTimeClients = NULL; +static int overTimeClientCount = 0; + +typedef struct { + pthread_mutex_t lock; + bool waitingForLock; +} ShmLock; +static ShmLock *shmLock = NULL; + +static int pagesize; +static unsigned int next_pos = 0; + +unsigned long long addstr(const char *str) +{ + if(str == NULL) + { + logg("WARN: Called addstr() with NULL pointer"); + return 0; + } + + // Get string length + size_t len = strlen(str); + + if(debug) logg("Adding \"%s\" (len %i) to buffer. next_pos is %i", str, len, next_pos); + + // Reserve additional memory if necessary + size_t required_size = next_pos + len + 1; + // Need to cast to long long because size_t calculations cannot be negative + if((long long)required_size-(long long)shm_strings.size > 0 && + !realloc_shm(&shm_strings, shm_strings.size + pagesize)) + return 0; + + // Copy the C string pointed by str into the shared string buffer + strncpy(&((char*)shm_strings.ptr)[next_pos], str, len); + ((char*)shm_strings.ptr)[next_pos + len] = '\0'; + + // Increment string length counter + next_pos += len+1; + + // Return start of stored string + return (next_pos - (len + 1)); +} + +char *getstr(unsigned long long pos) +{ + return &((char*)shm_strings.ptr)[pos]; +} + +static char *clientShmName(int id) { + int name_len = 1 + snprintf(NULL, 0, "%s%d", SHARED_OVERTIMECLIENT_PREFIX, id); + char *name = malloc(sizeof(char) * name_len); + snprintf(name, (size_t) name_len, "%s%d", SHARED_OVERTIMECLIENT_PREFIX, id); + + return name; +} + +void newOverTimeClient() { + // Get the name of the new shared memory. + // This will be used in the struct, so it should not be immediately freed. + char *name = clientShmName(overTimeClientCount); + + // Create the shared memory with enough space for the current overTime slots + shm_unlink(name); + SharedMemory shm = create_shm(name, (counters->overTime/pagesize + 1)*pagesize*sizeof(int)); + if(shm.ptr == NULL) { + free(shm.name); + logg("Failed to initialize new overTime client %d", overTimeClientCount); + return; + } + + // Make space for the new shared memory + shm_overTimeClients = realloc(shm_overTimeClients, sizeof(SharedMemory) * (overTimeClientCount + 1)); + overTimeClientCount++; + shm_overTimeClients[overTimeClientCount-1] = shm; + + // Add to overTimeClientData + overTimeClientData = realloc(overTimeClientData, sizeof(int*) * (overTimeClientCount)); + overTimeClientData[overTimeClientCount-1] = shm.ptr; +} + +void addOverTimeClientSlot() { + // For each client slot, add pagesize overTime slots + for(int i = 0; i < overTimeClientCount; i++) + { + // Only increase the size of the shm object if needed + // shm_overTimeClients[i].size stores the size of the memory in bytes whereas + // counters->overTime (effectively) stores the number of slots each overTime + // client should have. Hence, counters->overTime needs to be multiplied by + // sizeof(int) to get the actual requested memory size + if(shm_overTimeClients[i].size > (size_t)counters->overTime*sizeof(int)) + continue; + + // Reallocate with one more slot + realloc_shm(&shm_overTimeClients[i], (counters->overTime + pagesize)*sizeof(int)); + + // Update overTimeClientData + overTimeClientData[i] = shm_overTimeClients[i].ptr; + } +} + +/// Create a mutex for shared memory +pthread_mutex_t create_mutex() { + pthread_mutexattr_t lock_attr = {}; + pthread_mutex_t lock = {}; + + // Initialize the lock attributes + pthread_mutexattr_init(&lock_attr); + + // Allow the lock to be used by other processes + pthread_mutexattr_setpshared(&lock_attr, PTHREAD_PROCESS_SHARED); + + // Make the lock robust against process death + pthread_mutexattr_setrobust(&lock_attr, PTHREAD_MUTEX_ROBUST); + + // Initialize the lock + pthread_mutex_init(&lock, &lock_attr); + + // Destroy the lock attributes since we're done with it + pthread_mutexattr_destroy(&lock_attr); + + return lock; +} + +void lock_shm() { + // Signal that FTL is waiting for a lock + shmLock->waitingForLock = true; + + int result = pthread_mutex_lock(&shmLock->lock); + + // Turn off the waiting for lock signal to notify everyone who was + // deferring to FTL that they can jump in the lock queue. + shmLock->waitingForLock = false; + + if(result == EOWNERDEAD) { + // Try to make the lock consistent if the other process died while + // holding the lock + result = pthread_mutex_consistent(&shmLock->lock); + } + + if(result != 0) + logg("Failed to obtain SHM lock: %s", strerror(result)); +} + +void unlock_shm() { + int result = pthread_mutex_unlock(&shmLock->lock); + + if(result != 0) + logg("Failed to unlock SHM lock: %s", strerror(result)); +} + +bool init_shmem(void) +{ + // Get kernel's page size + pagesize = getpagesize(); + + /****************************** shared memory lock ******************************/ + shm_unlink(SHARED_LOCK_NAME); + // Try to create shared memory object + shm_lock = create_shm(SHARED_LOCK_NAME, sizeof(ShmLock)); + if(shm_lock.ptr == NULL) + return false; + shmLock = (ShmLock*) shm_lock.ptr; + shmLock->lock = create_mutex(); + shmLock->waitingForLock = false; + + /****************************** shared strings buffer ******************************/ + // Try unlinking the shared memory object before creating a new one + // If the object is still existing, e.g., due to a past unclean exit + // of FTL, shm_open() would fail with error "File exists" + shm_unlink(SHARED_STRINGS_NAME); + // Try to create shared memory object + shm_strings = create_shm(SHARED_STRINGS_NAME, pagesize); + if(shm_strings.ptr == NULL) + return false; + + // Initialize shared string object with an empty string at position zero + ((char*)shm_strings.ptr)[0] = '\0'; + next_pos = 1; + + /****************************** shared counters struct ******************************/ + shm_unlink(SHARED_COUNTERS_NAME); + // Try to create shared memory object + shm_counters = create_shm(SHARED_COUNTERS_NAME, sizeof(countersStruct)); + if(shm_counters.ptr == NULL) + return false; + counters = (countersStruct*)shm_counters.ptr; + + /****************************** shared domains struct ******************************/ + shm_unlink(SHARED_DOMAINS_NAME); + // Try to create shared memory object + shm_domains = create_shm(SHARED_DOMAINS_NAME, pagesize*sizeof(domainsDataStruct)); + if(shm_domains.ptr == NULL) + return false; + domains = (domainsDataStruct*)shm_domains.ptr; + counters->domains_MAX = pagesize; + + /****************************** shared clients struct ******************************/ + shm_unlink(SHARED_CLIENTS_NAME); + // Try to create shared memory object + shm_clients = create_shm(SHARED_CLIENTS_NAME, pagesize*sizeof(clientsDataStruct)); + if(shm_clients.ptr == NULL) + return false; + clients = (clientsDataStruct*)shm_clients.ptr; + counters->clients_MAX = pagesize; + + /****************************** shared forwarded struct ******************************/ + shm_unlink(SHARED_FORWARDED_NAME); + // Try to create shared memory object + shm_forwarded = create_shm(SHARED_FORWARDED_NAME, pagesize*sizeof(forwardedDataStruct)); + if(shm_forwarded.ptr == NULL) + return false; + forwarded = (forwardedDataStruct*)shm_forwarded.ptr; + counters->forwarded_MAX = pagesize; + + /****************************** shared queries struct ******************************/ + shm_unlink(SHARED_QUERIES_NAME); + // Try to create shared memory object + shm_queries = create_shm(SHARED_QUERIES_NAME, pagesize*sizeof(queriesDataStruct)); + if(shm_queries.ptr == NULL) + return false; + queries = (queriesDataStruct*)shm_queries.ptr; + counters->queries_MAX = pagesize; + + /****************************** shared overTime struct ******************************/ + shm_unlink(SHARED_OVERTIME_NAME); + // Try to create shared memory object + shm_overTime = create_shm(SHARED_OVERTIME_NAME, pagesize*sizeof(overTimeDataStruct)); + if(shm_overTime.ptr == NULL) + return false; + overTime = (overTimeDataStruct*)shm_overTime.ptr; + counters->overTime_MAX = pagesize; + + return true; +} + +void destroy_shmem(void) +{ + pthread_mutex_destroy(&shmLock->lock); + shmLock = NULL; + + delete_shm(&shm_lock); + delete_shm(&shm_strings); + delete_shm(&shm_counters); + delete_shm(&shm_domains); + delete_shm(&shm_clients); + delete_shm(&shm_queries); + delete_shm(&shm_forwarded); + delete_shm(&shm_overTime); + + for(int i = 0; i < overTimeClientCount; i++) { + delete_shm(&shm_overTimeClients[i]); + free(shm_overTimeClients[i].name); + } +} + +SharedMemory create_shm(char *name, size_t size) +{ + if(debug) logg("Creating shared memory with name \"%s\" and size %zu", name, size); + + SharedMemory sharedMemory = { + .name = name, + .size = size, + .ptr = NULL + }; + + // Create the shared memory file in read/write mode with 600 permissions + int fd = shm_open(sharedMemory.name, O_CREAT | O_EXCL | O_TRUNC | O_RDWR, S_IRUSR | S_IWUSR); + + // Check for `shm_open` error + if(fd == -1) + { + logg("create_shm(): Failed to create_shm shared memory object \"%s\": %s", + name, strerror(errno)); + return sharedMemory; + } + + // Resize shared memory file + int result = ftruncate(fd, size); + + // Check for `ftruncate` error + if(result == -1) + { + logg("create_shm(): ftruncate(%i, %zu): Failed to resize shared memory object \"%s\": %s", + fd, size, sharedMemory.name, strerror(errno)); + return sharedMemory; + } + + // Create shared memory mapping + void *shm = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + + // Check for `mmap` error + if(shm == MAP_FAILED) + { + logg("create_shm(): Failed to map shared memory object \"%s\" (%i): %s", + sharedMemory.name, fd, strerror(errno)); + return sharedMemory; + } + + // Close shared memory object file descriptor as it is no longer + // needed after having called mmap() + close(fd); + + sharedMemory.ptr = shm; + return sharedMemory; +} + +void *enlarge_shmem_struct(char type) +{ + SharedMemory *sharedMemory; + size_t sizeofobj; + int *counter; + + // Select type of struct that should be enlarged + switch(type) + { + case QUERIES: + sharedMemory = &shm_queries; + sizeofobj = sizeof(queriesDataStruct); + counter = &counters->queries_MAX; + break; + case CLIENTS: + sharedMemory = &shm_clients; + sizeofobj = sizeof(clientsDataStruct); + counter = &counters->clients_MAX; + break; + case DOMAINS: + sharedMemory = &shm_domains; + sizeofobj = sizeof(domainsDataStruct); + counter = &counters->domains_MAX; + break; + case FORWARDED: + sharedMemory = &shm_forwarded; + sizeofobj = sizeof(forwardedDataStruct); + counter = &counters->forwarded_MAX; + break; + case OVERTIME: + sharedMemory = &shm_overTime; + sizeofobj = sizeof(overTimeDataStruct); + counter = &counters->overTime_MAX; + break; + default: + logg("Invalid argument in enlarge_shmem_struct(): %i", type); + return 0; + } + + // Reallocate enough space for 4096 instances of requested object + realloc_shm(sharedMemory, sharedMemory->size + pagesize*sizeofobj); + + // Add allocated memory to corresponding counter + *counter += pagesize; + + return sharedMemory->ptr; +} + +bool realloc_shm(SharedMemory *sharedMemory, size_t size) { + logg("Resizing \"%s\" from %zu to %zu", sharedMemory->name, sharedMemory->size, size); + + int result = munmap(sharedMemory->ptr, sharedMemory->size); + if(result != 0) + logg("realloc_shm(): munmap(%p, %zu) failed: %s", sharedMemory->ptr, sharedMemory->size, strerror(errno)); + + // Open shared memory object + int fd = shm_open(sharedMemory->name, O_RDWR, S_IRUSR | S_IWUSR); + if(fd == -1) + { + logg("realloc_shm(): Failed to open shared memory object \"%s\": %s", + sharedMemory->name, strerror(errno)); + return false; + } + + // Resize shard memory object to requested size + result = ftruncate(fd, size); + if(result == -1) { + logg("realloc_shm(): ftruncate(%i, %zu): Failed to resize \"%s\": %s", + fd, size, sharedMemory->name, strerror(errno)); + return false; + } + +// void *new_ptr = mremap(sharedMemory->ptr, sharedMemory->size, size, MREMAP_MAYMOVE); + void *new_ptr = mmap(NULL, size, PROT_READ | PROT_WRITE, MAP_SHARED, fd, 0); + if(new_ptr == MAP_FAILED) + { + logg("realloc_shm(): mremap(%p, %zu, %zu, MREMAP_MAYMOVE): Failed to reallocate \"%s\" (%i): %s", + sharedMemory->ptr, sharedMemory->size, size, sharedMemory->name, fd, + strerror(errno)); + return false; + } + + // Close shared memory object file descriptor as it is no longer + // needed after having called mmap() + close(fd); + + sharedMemory->ptr = new_ptr; + sharedMemory->size = size; + + return true; +} + +void delete_shm(SharedMemory *sharedMemory) +{ + // Unmap shared memory + int ret; + ret = munmap(sharedMemory->ptr, sharedMemory->size); + if(ret != 0) + logg("delete_shm(): munmap(%p, %zu) failed: %s", sharedMemory->ptr, sharedMemory->size, strerror(errno)); + + // Now you can no longer `shm_open` the memory, + // and once all others unlink, it will be destroyed. + ret = shm_unlink(sharedMemory->name); + if(ret != 0) + logg("delete_shm(): munmap(%s) failed: %s", sharedMemory->name, strerror(errno)); +} diff --git a/shmem.h b/shmem.h new file mode 100644 index 000000000..2bca284aa --- /dev/null +++ b/shmem.h @@ -0,0 +1,49 @@ +/* Pi-hole: A black hole for Internet advertisements +* (c) 2018 Pi-hole, LLC (https://pi-hole.net) +* Network-wide ad blocking via your own hardware. +* +* FTL Engine +* Shared memory header +* +* This file is copyright under the latest version of the EUPL. +* Please see LICENSE file for your rights under this license. */ + +#ifndef SHARED_MEMORY_SERVER_H +#define SHARED_MEMORY_SERVER_H +#include /* For shm_* functions */ +#include /* For mode constants */ +#include /* For O_* constants */ +#include + +typedef struct { + char *name; + size_t size; + void *ptr; +} SharedMemory; + +/// Create shared memory +/// +/// \param name the name of the shared memory +/// \param size the size to allocate +/// \return a structure with a pointer to the mounted shared memory. The pointer will be NULL if it failed +SharedMemory create_shm(char *name, size_t size); + +/// Reallocate shared memory +/// +/// \param sharedMemory the shared memory struct +/// \param size the new size +/// \return if reallocation was successful +bool realloc_shm(SharedMemory *sharedMemory, size_t size); + +/// Disconnect from shared memory. If there are no other connections to shared memory, it will be deleted. +/// +/// \param sharedMemory the shared memory struct +void delete_shm(SharedMemory *sharedMemory); + +/// Block until a lock can be obtained +void lock_shm(); + +/// Unlock the lock. Only call this if there is an active lock. +void unlock_shm(); + +#endif //SHARED_MEMORY_SERVER_H diff --git a/socket.c b/socket.c index 619441924..5214c5bdd 100644 --- a/socket.c +++ b/socket.c @@ -10,6 +10,7 @@ #include "FTL.h" #include "api.h" +#include "shmem.h" // The backlog argument defines the maximum length // to which the queue of pending connections for @@ -321,13 +322,13 @@ void *telnet_connection_handler_thread(void *socket_desc) // Lock FTL data structure, since it is likely that it will be changed here // Requests should not be processed/answered when data is about to change - enable_thread_lock(); + lock_shm(); process_request(message, &sock); free(message); // Release thread lock - disable_thread_lock(); + unlock_shm(); if(sock == 0) { @@ -378,13 +379,13 @@ void *socket_connection_handler_thread(void *socket_desc) // Lock FTL data structure, since it is likely that it will be changed here // Requests should not be processed/answered when data is about to change - enable_thread_lock(); + lock_shm(); process_request(message, &sock); free(message); // Release thread lock - disable_thread_lock(); + unlock_shm(); if(sock == 0) { diff --git a/threads.c b/threads.c deleted file mode 100644 index dd4f1d757..000000000 --- a/threads.c +++ /dev/null @@ -1,46 +0,0 @@ -/* Pi-hole: A black hole for Internet advertisements -* (c) 2017 Pi-hole, LLC (https://pi-hole.net) -* Network-wide ad blocking via your own hardware. -* -* FTL Engine -* Thread routines -* -* This file is copyright under the latest version of the EUPL. -* Please see LICENSE file for your rights under this license. */ - -#include "FTL.h" - -// Logic of the locks: -// Any of the various threads (logparser, GC, client threads) is accessing FTL's data structure. Hence, they should -// never run at the same time since the data can change half-way through, leading to unspecified behavior. -// threadlock: The threadlock ensures that only one thread can be active at any given time -pthread_mutex_t threadlock; - -void enable_thread_lock(void) -{ - // logg("At thread lock: waiting"); - int ret = pthread_mutex_lock(&threadlock); - // logg("At thread lock: passed"); - - if(ret != 0) - logg("Thread lock error: %i",ret); -} - -void disable_thread_lock(void) -{ - int ret = pthread_mutex_unlock(&threadlock); - // logg("At thread lock: unlocked"); - - if(ret != 0) - logg("Thread unlock error: %i",ret); -} - -void init_thread_lock(void) -{ - if (pthread_mutex_init(&threadlock, NULL) != 0) - { - logg("FATAL: Thread mutex init failed\n"); - // Return failure - exit(EXIT_FAILURE); - } -}