Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add Pagination Support for Retrieving All CVEs API #357

Closed
wants to merge 2 commits into from
Closed
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions db/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@ type DB interface {
UpsertFetchMeta(*models.FetchMeta) error

Get(string) (*models.CveDetail, error)
GetAll(limit, pageNum int) ([]models.CveDetail, int, error)
GetMulti([]string) (map[string]models.CveDetail, error)
GetCveIDsByCpeURI(string) ([]string, []string, []string, error)
GetByCpeURI(string) ([]models.CveDetail, error)
Expand Down
74 changes: 74 additions & 0 deletions db/rdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -308,6 +308,80 @@ func (r *RDBDriver) Get(cveID string) (*models.CveDetail, error) {
return &detail, nil
}

func (r *RDBDriver) GetAll(limit, pageNum int) ([]models.CveDetail, int, error) {
var details []models.CveDetail

// Calculate offset based on page number and limit
offset := (pageNum - 1) * limit

var totalCount int64
if err := r.conn.Model(&models.Nvd{}).Count(&totalCount).Error; err != nil {
return nil, 0, xerrors.Errorf("Failed to count Nvd records: %w", err)
}

// Get all Nvd records
var nvds []models.Nvd
if err := r.conn.Limit(limit).Offset(offset).
Preload("Descriptions").
Preload("Cvss2").
Preload("Cvss3").
Preload("Cwes").
Preload("Cpes").
Preload("References").
Preload("Certs").
Find(&nvds).Error; err != nil {
return nil, 0, xerrors.Errorf("Failed to get Nvd records: %w", err)
}

// Map Nvd records to CveDetail
for _, nvd := range nvds {
detail := models.CveDetail{
CveID: nvd.CveID,
Nvds: []models.Nvd{nvd},
}

// Populate EnvCpes for each Cpe in Nvd
for i := range nvd.Cpes {
if err := r.conn.
Where(&models.NvdEnvCpe{NvdCpeID: uint(nvd.Cpes[i].ID)}).
Find(&nvd.Cpes[i].EnvCpes).Error; err != nil {
return nil, 0, xerrors.Errorf("Failed to fill Nvd EnvCpes for Cpe ID: %d, err: %w", uint(nvd.Cpes[i].ID), err)
}
}

// Get corresponding JVN records
var jvns []models.Jvn
if err := r.conn.
Where(&models.Jvn{CveID: nvd.CveID}).
Preload("Cvss2").
Preload("Cvss3").
Preload("Cpes").
Preload("References").
Preload("Certs").
Find(&jvns).Error; err != nil {
return nil, 0, xerrors.Errorf("Failed to get Jvn records for CVE ID: %s, err: %w", nvd.CveID, err)
}
detail.Jvns = jvns

// Get corresponding Fortinet records
var fortinets []models.Fortinet
if err := r.conn.
Where(&models.Fortinet{CveID: nvd.CveID}).
Preload("Cvss3").
Preload("Cwes").
Preload("Cpes").
Preload("References").
Find(&fortinets).Error; err != nil {
return nil, 0, xerrors.Errorf("Failed to get Fortinet records for CVE ID: %s, err: %w", nvd.CveID, err)
}
detail.Fortinets = fortinets

details = append(details, detail)
}

return details, int(totalCount), nil
}

func (r *RDBDriver) getCveIDsByPartVendorProduct(uri string) ([]string, error) {
specified, err := naming.UnbindURI(uri)
if err != nil {
Expand Down
41 changes: 41 additions & 0 deletions db/redis.go
Original file line number Diff line number Diff line change
Expand Up @@ -191,6 +191,47 @@ func (r *RedisDriver) Get(cveID string) (*models.CveDetail, error) {
return &detail, nil
}

// GetAll retrieves all CVE details from the database
func (r *RedisDriver) GetAll(limit, pageNum int) ([]models.CveDetail, int, error) {
ctx := context.Background()

// Get all keys for CVE details
keys, err := r.conn.Keys(ctx, fmt.Sprintf(cveKeyFormat, "*")).Result()
Copy link
Collaborator

@MaineK00n MaineK00n Jan 9, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if err != nil {
return nil, 0, xerrors.Errorf("Failed to get keys from Redis: %w", err)
}

// Total count is the length of keys
totalCount := len(keys)

// Calculate offset based on page number and limit
offset := (pageNum - 1) * limit
end := offset + limit
if end > totalCount {
end = totalCount
}

// Select the keys for the current page
paginatedKeys := keys[offset:end]

// Extract CVE IDs from keys and fetch details
var cveDetails []models.CveDetail
for _, key := range paginatedKeys {
cveID := strings.TrimPrefix(key, fmt.Sprintf(cveKeyFormat, ""))
results, err := r.conn.HGetAll(ctx, key).Result()
if err != nil {
return nil, 0, xerrors.Errorf("Failed to HGetAll for key %s: %w", key, err)
}
detail, err := convertToCveDetail(cveID, results)
if err != nil {
return nil, 0, xerrors.Errorf("Failed to convertToCveDetail for CVE ID %s: %w", cveID, err)
}
cveDetails = append(cveDetails, detail)
}

return cveDetails, totalCount, nil
}

func convertToCveDetail(cveID string, results map[string]string) (models.CveDetail, error) {
detail := models.CveDetail{
CveID: cveID,
Expand Down
22 changes: 22 additions & 0 deletions server/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import (
"os"
"path/filepath"
"sort"
"strconv"

"github.com/labstack/echo/v4"
"github.com/labstack/echo/v4/middleware"
Expand Down Expand Up @@ -38,6 +39,7 @@ func Start(logToFile bool, logDir string, driver db.DB) error {
// Routes
e.GET("/health", health())
e.GET("/cves/:id", getCve(driver))
e.GET("/cves", getAllCves(driver))
e.POST("/cpes", getCveByCpeName(driver))
e.POST("/cpes/ids", getCveIDsByCpeName(driver))

Expand Down Expand Up @@ -67,6 +69,26 @@ func getCve(driver db.DB) echo.HandlerFunc {
}
}

func getAllCves(driver db.DB) echo.HandlerFunc {
return func(c echo.Context) error {
cveDetails, err := driver.GetAll()
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
cveDetails, err := driver.GetAll()

limit, _ := strconv.Atoi(c.QueryParam("limit"))
pageNum, _ := strconv.Atoi(c.QueryParam("page_num"))
Comment on lines +75 to +76
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Please handle the errors.


// Fetch paginated data
cveDetails, totalCount, err := driver.GetAll(limit, pageNum)
if err != nil {
log.Errorf("Failed to get all CVEs: %s", err)
return c.JSON(http.StatusInternalServerError, map[string]string{"error": "failed to retrieve CVEs"})
}

return c.JSON(http.StatusOK, echo.Map{
"total": totalCount,
"data": cveDetails,
})
}
}

type cpeName struct {
Name string `form:"name"`
}
Expand Down