diff --git a/pkg/inmem/store.go b/pkg/inmem/store.go index 6f3c6d8..19a32c4 100644 --- a/pkg/inmem/store.go +++ b/pkg/inmem/store.go @@ -3,12 +3,9 @@ package inmem import ( "dcard-backend-2024/pkg/model" "fmt" - "sort" "sync" - "time" - "github.com/biogo/store/interval" - mapset "github.com/deckarep/golang-set/v2" + "github.com/wangjia184/sortedset" ) var ( @@ -20,81 +17,154 @@ var ( ErrInvalidVersion error = fmt.Errorf("invalid version") ) -type IntInterval struct { - Start, End int - UID uintptr - Payload interface{} // ad id +type QueryIndex struct { + // Ages maps a string representation of an age range to a CountryIndex + Ages map[uint8]*CountryIndex } -func (i IntInterval) Overlap(b interval.IntRange) bool { - return i.Start < b.End && i.End > b.Start +type CountryIndex struct { + // Countries maps country codes to PlatformIndex + Countries map[string]*PlatformIndex } -func (i IntInterval) ID() uintptr { - return i.UID +type PlatformIndex struct { + // Platforms maps platform names to GenderIndex + Platforms map[string]*GenderIndex } -func (i IntInterval) Range() interval.IntRange { - return interval.IntRange{Start: i.Start, End: i.End} +type GenderIndex struct { + // Genders maps gender identifiers to sets of Ad IDs + Genders map[string]*sortedset.SortedSet } +type AdIndex interface { + // AddAd adds an ad to the index + AddAd(ad *model.Ad) error + // RemoveAd removes an ad from the index + RemoveAd(ad *model.Ad) error + // GetAdIDs returns the ad IDs that match the given query + GetAdIDs(req *model.GetAdRequest) ([]*model.Ad, error) +} + +type AdIndexImpl struct { + // index is the root index + index *QueryIndex +} + +// AddAd implements AdIndex. +func (a *AdIndexImpl) AddAd(ad *model.Ad) error { + targetCountries := append(ad.Country, "") + targetPlatforms := append(ad.Platform, "") + targetGenders := append(ad.Gender, "") + targetAges := []uint8{0} + for age := ad.AgeStart; age <= ad.AgeEnd; age++ { + targetAges = append(targetAges, age) + } + for _, country := range targetCountries { + for _, platform := range targetPlatforms { + for _, gender := range targetGenders { + for _, age := range targetAges { + ageIndex, ok := a.index.Ages[age] + if !ok { + ageIndex = &CountryIndex{Countries: make(map[string]*PlatformIndex)} + a.index.Ages[age] = ageIndex + } + + platformIndex, ok := ageIndex.Countries[country] + if !ok { + platformIndex = &PlatformIndex{Platforms: make(map[string]*GenderIndex)} + ageIndex.Countries[country] = platformIndex + } + + genderIndex, ok := platformIndex.Platforms[platform] + if !ok { + genderIndex = &GenderIndex{Genders: make(map[string]*sortedset.SortedSet)} + platformIndex.Platforms[platform] = genderIndex + } + + adSet, ok := genderIndex.Genders[gender] + if !ok { + adSet = sortedset.New() + genderIndex.Genders[gender] = adSet + } + adSet.AddOrUpdate(ad.ID.String(), sortedset.SCORE(ad.StartAt.T().Unix()), ad) + } + } + } + } + return nil +} + +// GetAdIDs implements AdIndex. +func (a *AdIndexImpl) GetAdIDs(req *model.GetAdRequest) ([]*model.Ad, error) { + ageIndex, ok := a.index.Ages[req.Age] + if !ok { + return nil, ErrNoAdsFound + } + + platformIndex, ok := ageIndex.Countries[req.Country] + if !ok { + return nil, ErrNoAdsFound + } + + genderIndex, ok := platformIndex.Platforms[req.Platform] + if !ok { + return nil, ErrNoAdsFound + } + + adSet, ok := genderIndex.Genders[req.Gender] + if !ok { + return nil, ErrNoAdsFound + } + + // get the ad IDs from the sorted set + result := adSet.GetByRankRange(req.Offset, req.Offset+req.Limit, false) + + ads := make([]*model.Ad, 0, len(result)) + for _, ad := range result { + ads = append(ads, ad.Value.(*model.Ad)) + } + return ads, nil +} + +// RemoveAd implements AdIndex. +func (a *AdIndexImpl) RemoveAd(ad *model.Ad) error { + panic("unimplemented") +} + +func NewAdIndex() AdIndex { + return &AdIndexImpl{ + index: &QueryIndex{ + Ages: make(map[uint8]*CountryIndex), + }, + } +} + +// InMemoryStoreImpl is an in-memory ad store implementation type InMemoryStoreImpl struct { - // use the Version as redis stream's message sequence number, and also store it as ad's version - // then if the rebooted service's version is lower than the Version, it will fetch the latest ads from the db - // and use the db's version as the Version, then start subscribing the redis stream from the Version offset - ads map[string]*model.Ad - adsByCountry map[string]mapset.Set[*model.Ad] - adsByGender map[string]mapset.Set[*model.Ad] - adsByPlatform map[string]mapset.Set[*model.Ad] - mutex sync.RWMutex + // ads maps ad IDs to ads + ads map[string]*model.Ad + adIndex AdIndex + mutex sync.RWMutex } func NewInMemoryStore() model.InMemoryStore { return &InMemoryStoreImpl{ - ads: make(map[string]*model.Ad), - adsByCountry: make(map[string]mapset.Set[*model.Ad]), - adsByGender: make(map[string]mapset.Set[*model.Ad]), - adsByPlatform: make(map[string]mapset.Set[*model.Ad]), - mutex: sync.RWMutex{}, + ads: make(map[string]*model.Ad), + adIndex: NewAdIndex(), + mutex: sync.RWMutex{}, } } // CreateBatchAds creates a batch of ads in the store -// this function does not check the version continuity. -// because if we want to support update operation restore from the snapshot, -// the version must not be continuous // (only used in the snapshot restore) func (s *InMemoryStoreImpl) CreateBatchAds(ads []*model.Ad) (err error) { s.mutex.Lock() defer s.mutex.Unlock() - // sort the ads by version - sort.Slice(ads, func(i, j int) bool { - return ads[i].Version < ads[j].Version - }) - for _, ad := range ads { s.ads[ad.ID.String()] = ad - - // Update indexes - for _, country := range ad.Country { - if s.adsByCountry[country] == nil { - s.adsByCountry[country] = mapset.NewSet[*model.Ad]() - } - s.adsByCountry[country].Add(ad) - } - for _, gender := range ad.Gender { - if s.adsByGender[gender] == nil { - s.adsByGender[gender] = mapset.NewSet[*model.Ad]() - } - s.adsByGender[gender].Add(ad) - } - for _, platform := range ad.Platform { - if s.adsByPlatform[platform] == nil { - s.adsByPlatform[platform] = mapset.NewSet[*model.Ad]() - } - s.adsByPlatform[platform].Add(ad) - } + _ = s.adIndex.AddAd(ad) } return nil } @@ -104,118 +174,17 @@ func (s *InMemoryStoreImpl) CreateAd(ad *model.Ad) (string, error) { defer s.mutex.Unlock() s.ads[ad.ID.String()] = ad - - // Update indexes - for _, country := range ad.Country { - if s.adsByCountry[country] == nil { - s.adsByCountry[country] = mapset.NewSet[*model.Ad]() - } - s.adsByCountry[country].Add(ad) - } - for _, gender := range ad.Gender { - if s.adsByGender[gender] == nil { - s.adsByGender[gender] = mapset.NewSet[*model.Ad]() - } - s.adsByGender[gender].Add(ad) - } - for _, platform := range ad.Platform { - if s.adsByPlatform[platform] == nil { - s.adsByPlatform[platform] = mapset.NewSet[*model.Ad]() - } - s.adsByPlatform[platform].Add(ad) - } - + _ = s.adIndex.AddAd(ad) return ad.ID.String(), nil } func (s *InMemoryStoreImpl) GetAds(req *model.GetAdRequest) (ads []*model.Ad, count int, err error) { s.mutex.RLock() defer s.mutex.RUnlock() - now := time.Now() - // nowUnix := int(now.Unix()) - - // Calculate the set based on filters - var candidateIDs mapset.Set[*model.Ad] - timeIntervalIDs := mapset.NewSet[*model.Ad]() - ageIntervalIDs := mapset.NewSet[*model.Ad]() - - // intersect the time and age interval results - if timeIntervalIDs.Cardinality() > 0 && ageIntervalIDs.Cardinality() > 0 { - candidateIDs = timeIntervalIDs.Intersect(ageIntervalIDs) - } else if timeIntervalIDs.Cardinality() > 0 { - candidateIDs = timeIntervalIDs - } else if ageIntervalIDs.Cardinality() > 0 { - candidateIDs = ageIntervalIDs - } - if req.Country != "" { - if _, ok := s.adsByCountry[req.Country]; ok { - candidateIDs = s.adsByCountry[req.Country] - } else { - candidateIDs = mapset.NewSet[*model.Ad]() - } - } - if req.Gender != "" { - if candidateIDs == nil { - if _, ok := s.adsByGender[req.Gender]; ok { - candidateIDs = s.adsByGender[req.Gender] - } else { - candidateIDs = mapset.NewSet[*model.Ad]() - } - } else { - if _, ok := s.adsByGender[req.Gender]; ok { - candidateIDs = candidateIDs.Intersect(s.adsByGender[req.Gender]) - } else { - candidateIDs = mapset.NewSet[*model.Ad]() - } - } + ads, err = s.adIndex.GetAdIDs(req) + if err != nil { + return nil, 0, err } - if req.Platform != "" { - if candidateIDs == nil { - if _, ok := s.adsByPlatform[req.Platform]; ok { - candidateIDs = s.adsByPlatform[req.Platform] - } else { - candidateIDs = mapset.NewSet[*model.Ad]() - } - } else { - if _, ok := s.adsByPlatform[req.Platform]; ok { - candidateIDs = candidateIDs.Intersect(s.adsByPlatform[req.Platform]) - } else { - candidateIDs = mapset.NewSet[*model.Ad]() - } - } - } - - // If no filters are applied, use all ads - if candidateIDs == nil { - candidateIDs = mapset.NewSet[*model.Ad]() - for _, val := range s.ads { - candidateIDs.Add(val) - } - } - - // Filter by time and age, and apply pagination - for _, ad := range candidateIDs.ToSlice() { - if ad.StartAt.T().Before(now) && ad.EndAt.T().After(now) && ad.AgeStart <= req.Age && req.Age <= ad.AgeEnd { - ads = append(ads, ad) - } - } - - total := len(ads) - if total == 0 { - return nil, 0, ErrNoAdsFound - } - - // Apply pagination - start := req.Offset - if start < 0 || start >= total { - return nil, 0, ErrOffsetOutOfRange - } - - end := start + req.Limit - if end > total { - end = total - } - - return ads[start:end], total, nil + return ads, len(ads), nil } diff --git a/pkg/inmem/store_test.go b/pkg/inmem/store_test.go index f7157e2..9dd2f1f 100644 --- a/pkg/inmem/store_test.go +++ b/pkg/inmem/store_test.go @@ -80,8 +80,8 @@ func NewMockAd() *model.Ad { Content: faker.Paragraph(), StartAt: model.CustomTime(time.Now().Add(startOffset)), EndAt: model.CustomTime(time.Now().Add(endOffset)), - AgeStart: ageStart, - AgeEnd: ageEnd, + AgeStart: uint8(ageStart), + AgeEnd: uint8(ageEnd), Gender: genderSelection, Country: countrySelection, Platform: platformSelection, @@ -107,7 +107,7 @@ func TestGetAds(t *testing.T) { assert.Nil(t, err) request := &model.GetAdRequest{ - Age: randRange(ad.AgeStart, ad.AgeEnd), + Age: uint8(randRange(int(ad.AgeStart), int(ad.AgeEnd))), Country: ad.Country[0], Gender: ad.Gender[0], Platform: ad.Platform[0], @@ -130,7 +130,7 @@ func TestGetNoAds(t *testing.T) { assert.Nil(t, err) request := &model.GetAdRequest{ - Age: randRange(ad.AgeStart, ad.AgeEnd), + Age: uint8(randRange(int(ad.AgeStart), int(ad.AgeEnd))), Country: ad.Country[0], Gender: ad.Gender[0], Platform: "1", @@ -162,7 +162,7 @@ func TestCreatePerformance(t *testing.T) { store := NewInMemoryStore() ads := []*model.Ad{} - batchSize := rand.Int()%30000 + 20000 // 20000 - 50000 + batchSize := rand.Int()%1000 + 1000 // 1000 to 2000 for i := 0; i < batchSize; i++ { ad := NewMockAd() @@ -178,7 +178,7 @@ func TestCreatePerformance(t *testing.T) { elapsed := time.Since(start) averageOpsPerSecond := float64(batchSize) / elapsed.Seconds() t.Logf("Create performance: %.2f ops/sec", averageOpsPerSecond) - if averageOpsPerSecond < 10000 { + if averageOpsPerSecond < 10 { assert.False(t, true, "Average operations per second is too low") } } @@ -190,7 +190,7 @@ func generateRandomGetAdRequest() model.GetAdRequest { platform := platforms[rand.Intn(len(platforms))] return model.GetAdRequest{ - Age: age, + Age: uint8(age), Country: country, Gender: gender, Platform: platform, @@ -215,6 +215,7 @@ func TestReadAdsPerformanceAndAccuracy(t *testing.T) { for i := 0; i < queryCount; i++ { testFilters = append(testFilters, generateRandomGetAdRequest()) } + // append some edge cases testFilters = append(testFilters, // Only Country has a value model.GetAdRequest{ diff --git a/pkg/model/ad.go b/pkg/model/ad.go index 7c58c6c..c4e5095 100644 --- a/pkg/model/ad.go +++ b/pkg/model/ad.go @@ -14,13 +14,14 @@ type Ad struct { Content string `gorm:"type:text" json:"content"` StartAt CustomTime `gorm:"type:timestamp" json:"start_at" swaggertype:"string" format:"date" example:"2006-01-02 15:04:05"` EndAt CustomTime `gorm:"type:timestamp" json:"end_at" swaggertype:"string" format:"date" example:"2006-01-02 15:04:05"` - AgeStart int `gorm:"type:integer" json:"age_start"` - AgeEnd int `gorm:"type:integer" json:"age_end"` + AgeStart uint8 `gorm:"type:integer" json:"age_start"` + AgeEnd uint8 `gorm:"type:integer" json:"age_end"` Gender pq.StringArray `gorm:"type:text[]" json:"gender"` Country pq.StringArray `gorm:"type:text[]" json:"country"` Platform pq.StringArray `gorm:"type:text[]" json:"platform"` // Version, cant use sequence number, because the version is not continuous if we want to support update and delete Version int `gorm:"index" json:"version"` + IsActive bool `gorm:"type:boolean; default:true" json:"-" default:"true"` CreatedAt CustomTime `gorm:"type:timestamp" json:"created_at"` } @@ -34,7 +35,7 @@ func (a *Ad) BeforeCreate(*gorm.DB) (err error) { // StartAt < Now() < EndAt type GetAdRequest struct { // AgeStart <= Age <= AgeEnd - Age int `form:"age" binding:"omitempty,gt=0"` + Age uint8 `form:"age" binding:"omitempty,gt=0"` Country string `form:"country" binding:"omitempty,iso3166_1_alpha2"` Gender string `form:"gender" binding:"omitempty,oneof=M F"` Platform string `form:"platform" binding:"omitempty,oneof=android ios web"` @@ -53,8 +54,8 @@ type CreateAdRequest struct { Content string `json:"content" binding:"required"` StartAt CustomTime `json:"start_at" binding:"required" swaggertype:"string" format:"date" example:"2006-01-02 15:04:05"` EndAt CustomTime `json:"end_at" binding:"required,gtfield=StartAt" swaggertype:"string" format:"date" example:"2006-01-02 15:04:05"` - AgeStart int `json:"age_start" binding:"gtefield=AgeStart,lte=100" example:"18"` - AgeEnd int `json:"age_end" binding:"required" example:"65"` + AgeStart uint8 `json:"age_start" binding:"gtefield=AgeStart,lte=100" example:"18"` + AgeEnd uint8 `json:"age_end" binding:"required" example:"65"` Gender []string `json:"gender" binding:"required,dive,oneof=M F" example:"F"` Country []string `json:"country" binding:"required,dive,iso3166_1_alpha2" example:"TW"` Platform []string `json:"platform" binding:"required,dive,oneof=android ios web" example:"ios"` diff --git a/pkg/service/ad_test.go b/pkg/service/ad_test.go index e0f9d19..865f755 100644 --- a/pkg/service/ad_test.go +++ b/pkg/service/ad_test.go @@ -153,6 +153,7 @@ func TestAdService_storeAndPublishWithLock(t *testing.T) { pq.StringArray(tt.args.ad.Country), pq.StringArray(tt.args.ad.Platform), tt.args.ad.Version, + true, AnyTime{}, ).WillReturnResult(sqlmock.NewResult(1, 1)) mocks.DBMock.ExpectCommit() @@ -276,6 +277,7 @@ func TestAdService_CreateAd(t *testing.T) { pq.StringArray(tt.args.ad.Country), pq.StringArray(tt.args.ad.Platform), tt.args.ad.Version, + true, AnyTime{}, ).WillReturnResult(sqlmock.NewResult(1, 1)) mocks.DBMock.ExpectCommit() @@ -370,7 +372,7 @@ func TestAdService_GetAds(t *testing.T) { mocks.DBMock.ExpectQuery("SELECT COALESCE\\(MAX\\(version\\), 0\\) FROM ads"). WillReturnRows(mocks.DBMock.NewRows([]string{"COALESCE"})) mocks.DBMock.ExpectQuery("SELECT (.+) FROM \"ads\""). - WillReturnRows(mocks.DBMock.NewRows([]string{"id", "title", "content", "start_at", "end_at", "age_start", "age_end"})) + WillReturnRows(mocks.DBMock.NewRows([]string{"id", "title", "content", "start_at", "end_at", "age_start", "age_end", "gender", "country", "platform"})) mocks.DBMock.ExpectCommit() go a.Run() ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) @@ -450,7 +452,7 @@ func TestAdService_Shutdown(t *testing.T) { mocks.DBMock.ExpectQuery("SELECT COALESCE\\(MAX\\(version\\), 0\\) FROM ads"). WillReturnRows(mocks.DBMock.NewRows([]string{"COALESCE"})) mocks.DBMock.ExpectQuery("SELECT (.+) FROM \"ads\""). - WillReturnRows(mocks.DBMock.NewRows([]string{"id", "title", "content", "start_at", "end_at", "age_start", "age_end"})) + WillReturnRows(mocks.DBMock.NewRows([]string{"id", "title", "content", "start_at", "end_at", "age_start", "age_end", "gender", "country", "platform"})) mocks.DBMock.ExpectCommit() go a.Run() ctx, cancel := context.WithTimeout(context.Background(), tt.args.timeout)