diff --git a/db/db.go b/db/db.go index f8a51f5..c7ed2ac 100644 --- a/db/db.go +++ b/db/db.go @@ -28,6 +28,7 @@ type DB interface { UpsertFetchMeta(*models.FetchMeta) error Get(string) (*models.CveDetail, error) + GetAll(limit, pageNum int) ([]models.CveDetail, int, error) GetMulti([]string) (map[string]models.CveDetail, error) GetCveIDsByCpeURI(string) ([]string, []string, []string, error) GetByCpeURI(string) ([]models.CveDetail, error) diff --git a/db/rdb.go b/db/rdb.go index 5e7e779..05c1625 100644 --- a/db/rdb.go +++ b/db/rdb.go @@ -308,6 +308,75 @@ func (r *RDBDriver) Get(cveID string) (*models.CveDetail, error) { return &detail, nil } +func (r *RDBDriver) GetAll(limit, pageNum int) ([]models.CveDetail, int, error) { + var details []models.CveDetail + + // Aggregate unique CVE IDs from all sources + var allCveIDs []string + r.conn.Table("nvds").Select("DISTINCT cve_id").Pluck("cve_id", &allCveIDs) + r.conn.Table("jvns").Select("DISTINCT cve_id").Where("cve_id NOT IN (?)", allCveIDs).Pluck("cve_id", &allCveIDs) + r.conn.Table("fortinets").Select("DISTINCT cve_id").Where("cve_id NOT IN (?)", allCveIDs).Pluck("cve_id", &allCveIDs) + + totalCount := len(allCveIDs) + offset := (pageNum - 1) * limit + end := offset + limit + if end > totalCount { + end = totalCount + } + + // Apply pagination on the list of CVE IDs + paginatedCveIDs := allCveIDs[offset:end] + + // Fetch details for each CVE ID + for _, cveID := range paginatedCveIDs { + detail := models.CveDetail{CveID: cveID} + + // Fetch NVD data + var nvds []models.Nvd + if err := r.conn.Where("cve_id = ?", cveID). + Preload("Descriptions"). + Preload("Cvss2"). + Preload("Cvss3"). + Preload("Cwes"). + Preload("Cpes"). + Preload("References"). + Preload("Certs"). + Find(&nvds).Error; err != nil { + return nil, 0, xerrors.Errorf("Failed to get Nvd records for CVE ID: %s, err: %w", cveID, err) + } + detail.Nvds = nvds + + // Fetch JVN data + var jvns []models.Jvn + if err := r.conn.Where("cve_id = ?", cveID). + Preload("Cvss2"). + Preload("Cvss3"). + Preload("Cpes"). + Preload("References"). + Preload("Certs"). + Find(&jvns).Error; err != nil { + return nil, 0, xerrors.Errorf("Failed to get Jvn records for CVE ID: %s, err: %w", cveID, err) + } + detail.Jvns = jvns + + // Fetch Fortinet data + var fortinets []models.Fortinet + if err := r.conn.Where("cve_id = ?", cveID). + Preload("Cvss3"). + Preload("Cwes"). + Preload("Cpes"). + Preload("References"). + Find(&fortinets).Error; err != nil { + return nil, 0, xerrors.Errorf("Failed to get Fortinet records for CVE ID: %s, err: %w", cveID, err) + } + detail.Fortinets = fortinets + + details = append(details, detail) + } + + return details, totalCount, nil +} + func (r *RDBDriver) getCveIDsByPartVendorProduct(uri string) ([]string, error) { specified, err := naming.UnbindURI(uri) if err != nil { diff --git a/db/redis.go b/db/redis.go index 70c31dc..6016ae1 100644 --- a/db/redis.go +++ b/db/redis.go @@ -191,6 +191,47 @@ func (r *RedisDriver) Get(cveID string) (*models.CveDetail, error) { return &detail, nil } +// GetAll retrieves all CVE details from the database +func (r *RedisDriver) GetAll(limit, pageNum int) ([]models.CveDetail, int, error) { + ctx := context.Background() + + // Get all keys for CVE details + keys, err := r.conn.Keys(ctx, fmt.Sprintf(cveKeyFormat, "*")).Result() + if err != nil { + return nil, 0, xerrors.Errorf("Failed to get keys from Redis: %w", err) + } + + // Total count is the length of keys + totalCount := len(keys) + + // Calculate offset based on page number and limit + offset := (pageNum - 1) * limit + end := offset + limit + if end > totalCount { + end = totalCount + } + + // Select the keys for the current page + paginatedKeys := keys[offset:end] + + // Extract CVE IDs from keys and fetch details + var cveDetails []models.CveDetail + for _, key := range paginatedKeys { + cveID := strings.TrimPrefix(key, fmt.Sprintf(cveKeyFormat, "")) + results, err := r.conn.HGetAll(ctx, key).Result() + if err != nil { + return nil, 0, xerrors.Errorf("Failed to HGetAll for key %s: %w", key, err) + } + detail, err := convertToCveDetail(cveID, results) + if err != nil { + return nil, 0, xerrors.Errorf("Failed to convertToCveDetail for CVE ID %s: %w", cveID, err) + } + cveDetails = append(cveDetails, detail) + } + + return cveDetails, totalCount, nil +} + func convertToCveDetail(cveID string, results map[string]string) (models.CveDetail, error) { detail := models.CveDetail{ CveID: cveID, diff --git a/server/server.go b/server/server.go index 2339c52..351a06a 100644 --- a/server/server.go +++ b/server/server.go @@ -6,6 +6,7 @@ import ( "os" "path/filepath" "sort" + "strconv" "github.com/labstack/echo/v4" "github.com/labstack/echo/v4/middleware" @@ -38,6 +39,7 @@ func Start(logToFile bool, logDir string, driver db.DB) error { // Routes e.GET("/health", health()) e.GET("/cves/:id", getCve(driver)) + e.GET("/cves", getAllCves(driver)) e.POST("/cpes", getCveByCpeName(driver)) e.POST("/cpes/ids", getCveIDsByCpeName(driver)) @@ -67,6 +69,26 @@ func getCve(driver db.DB) echo.HandlerFunc { } } +func getAllCves(driver db.DB) echo.HandlerFunc { + return func(c echo.Context) error { + cveDetails, err := driver.GetAll() + limit, _ := strconv.Atoi(c.QueryParam("limit")) + pageNum, _ := strconv.Atoi(c.QueryParam("page_num")) + + // Fetch paginated data + cveDetails, totalCount, err := driver.GetAll(limit, pageNum) + if err != nil { + log.Errorf("Failed to get all CVEs: %s", err) + return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to retrieve CVEs"}) + } + + return c.JSON(http.StatusOK, echo.Map{ + "total": totalCount, + "data": cveDetails, + }) + } +} + type cpeName struct { Name string `form:"name"` }