Skip to content

Commit

Permalink
Merge pull request #18 from cake4everyone/feat/per-server-birthdays
Browse files Browse the repository at this point in the history
Added per server announcements for a birthday
  • Loading branch information
Kesuaheli authored Jan 1, 2025
2 parents bea388e + e23f17f commit b02090c
Show file tree
Hide file tree
Showing 6 changed files with 151 additions and 50 deletions.
133 changes: 103 additions & 30 deletions modules/birthday/birthdaybase.go
Original file line number Diff line number Diff line change
Expand Up @@ -43,12 +43,14 @@ type birthdayBase struct {
}

type birthdayEntry struct {
ID uint64 `database:"id"`
Day int `database:"day"`
Month int `database:"month"`
Year int `database:"year"`
Visible bool `database:"visible"`
time time.Time
ID uint64 `database:"id"`
Day int `database:"day"`
Month int `database:"month"`
Year int `database:"year"`
Visible bool `database:"visible"`
time time.Time
GuildIDsRaw string `database:"guilds"`
GuildIDs []string
}

// Returns a readable Form of the date
Expand Down Expand Up @@ -108,15 +110,75 @@ func (b birthdayEntry) Age() int {
return b.Next().Year() - b.Year - 1
}

// ParseGuildIDs splits the guild IDs into a slice and stores them in b.GuildIDs.
func (b *birthdayEntry) ParseGuildIDs() {
b.GuildIDs = strings.Split(b.GuildIDsRaw, ",")
}

// IsInGuild returns true if the guildID is in b.GuildIDs.
// If guildID is empty, IsInGuild returns true.
func (b birthdayEntry) IsInGuild(guildID string) bool {
if guildID == "" {
return true
}
return util.ContainsString(b.GuildIDs, guildID)
}

// SetGuild sets the guildID in the birthday entry.
func (b *birthdayEntry) SetGuild(guildID string) {
b.GuildIDsRaw += guildID
b.ParseGuildIDs()
}

// AddGuild adds the guildID to the birthday entry.
func (b *birthdayEntry) AddGuild(guildID string) error {
if util.ContainsString(b.GuildIDs, guildID) {
return nil
} else if len(b.GuildIDs) >= 3 {
return fmt.Errorf("this entry already has %d guilds", len(b.GuildIDs))
}
b.GuildIDsRaw += "," + guildID
b.GuildIDsRaw = strings.Trim(b.GuildIDsRaw, ", ")
b.ParseGuildIDs()
return nil
}

// IsEqual returns true if b and b2 are equal.
//
// That is, if all of the following are true
// 1. They have the same user ID.
// 2. They are on the same date.
// 3. They have the same visibility.
// 4. They have the same guilds in (any order).
func (b birthdayEntry) IsEqual(b2 birthdayEntry) bool {
if b.ID != b2.ID || b.Day != b2.Day || b.Month != b2.Month || b.Year != b2.Year || b.Visible != b2.Visible {
return false
}

// check for same guilds in any order
for _, guildID := range b.GuildIDs {
if !util.ContainsString(b2.GuildIDs, guildID) {
return false
}
}
for _, guildID := range b2.GuildIDs {
if !util.ContainsString(b.GuildIDs, guildID) {
return false
}
}
return true
}

// getBirthday copies all birthday fields into the struct pointed at by b.
//
// If the user from b.ID is not found it returns sql.ErrNoRows.
func (cmd birthdayBase) getBirthday(b *birthdayEntry) (err error) {
row := database.QueryRow("SELECT day,month,year,visible FROM birthdays WHERE id=?", b.ID)
err = row.Scan(&b.Day, &b.Month, &b.Year, &b.Visible)
row := database.QueryRow("SELECT day,month,year,visible,guilds FROM birthdays WHERE id=?", b.ID)
err = row.Scan(&b.Day, &b.Month, &b.Year, &b.Visible, &b.GuildIDsRaw)
if err != nil {
return err
}
b.ParseGuildIDs()
return b.ParseTime()
}

Expand All @@ -127,27 +189,36 @@ func (cmd birthdayBase) hasBirthday(id uint64) (hasBirthday bool, err error) {
}

// setBirthday inserts a new database entry with the values from b.
func (cmd birthdayBase) setBirthday(b birthdayEntry) error {
_, err := database.Exec("INSERT INTO birthdays(id,day,month,year,visible) VALUES(?,?,?,?,?);", b.ID, b.Day, b.Month, b.Year, b.Visible)
func (cmd birthdayBase) setBirthday(b *birthdayEntry) (err error) {
b.SetGuild(cmd.Interaction.GuildID)
_, err = database.Exec("INSERT INTO birthdays(id,day,month,year,visible,guilds) VALUES(?,?,?,?,?);", b.ID, b.Day, b.Month, b.Year, b.Visible, b.GuildIDsRaw)
return err
}

// updateBirthday updates an existing database entry with the values from b.
func (cmd birthdayBase) updateBirthday(b birthdayEntry) (before birthdayEntry, err error) {
err = b.ParseTime()
if err != nil {
return birthdayEntry{}, err
}
func (cmd birthdayBase) updateBirthday(b *birthdayEntry) (before birthdayEntry, err error) {
before.ID = b.ID
if err = cmd.getBirthday(&before); err != nil {
return birthdayEntry{}, fmt.Errorf("trying to get old birthday: %v", err)
}
b.GuildIDsRaw = before.GuildIDsRaw
b.ParseGuildIDs()

err = b.AddGuild(cmd.Interaction.GuildID)
if err != nil {
return birthdayEntry{}, fmt.Errorf("adding guild '%s' to birthday entry: %v", cmd.Interaction.GuildID, err)
}

// early return if nothing changed
if b.IsEqual(before) {
return before, nil
}

var (
updateNames []string
updateVars []any
oldV reflect.Value = reflect.ValueOf(before)
v reflect.Value = reflect.ValueOf(b)
v reflect.Value = reflect.ValueOf(*b)
)
for i := 0; i < v.NumField(); i++ {
var (
Expand All @@ -160,11 +231,11 @@ func (cmd birthdayBase) updateBirthday(b birthdayEntry) (before birthdayEntry, e
continue
}

tag := v.Type().Field(i).Tag.Get("database")
if tag == "" {
continue
}
if f.Interface() != oldF.Interface() {
tag := v.Type().Field(i).Tag.Get("database")
if tag == "" {
continue
}
updateNames = append(updateNames, tag)
updateVars = append(updateVars, f.Interface())
}
Expand Down Expand Up @@ -193,8 +264,8 @@ func (cmd birthdayBase) removeBirthday(id uint64) (birthdayEntry, error) {
return b, err
}

// getBirthdaysMonth return a sorted slice of birthday entries that matches the given month.
func (cmd birthdayBase) getBirthdaysMonth(month int) (birthdays []birthdayEntry, err error) {
// getBirthdaysMonth return a sorted slice of birthday entries that matches the given guildID and month.
func (cmd birthdayBase) getBirthdaysMonth(guildID string, month int) (birthdays []birthdayEntry, err error) {
var numOfEntries int64
err = database.QueryRow("SELECT COUNT(*) FROM birthdays WHERE month=?", month).Scan(&numOfEntries)
if err != nil {
Expand All @@ -206,20 +277,21 @@ func (cmd birthdayBase) getBirthdaysMonth(month int) (birthdays []birthdayEntry,
return birthdays, nil
}

rows, err := database.Query("SELECT id,day,year,visible FROM birthdays WHERE month=?", month)
rows, err := database.Query("SELECT id,day,year,visible,guilds FROM birthdays WHERE month=?", month)
if err != nil {
return birthdays, err
}
defer rows.Close()

for rows.Next() {
b := birthdayEntry{Month: month}
err = rows.Scan(&b.ID, &b.Day, &b.Year, &b.Visible)
err = rows.Scan(&b.ID, &b.Day, &b.Year, &b.Visible, &b.GuildIDsRaw)
if err != nil {
return birthdays, err
}
b.ParseGuildIDs()

if !b.Visible {
if !b.Visible || !b.IsInGuild(guildID) {
continue
}

Expand All @@ -238,8 +310,8 @@ func (cmd birthdayBase) getBirthdaysMonth(month int) (birthdays []birthdayEntry,
return birthdays, nil
}

// getBirthdaysDate return a slice of birthday entries that matches the given date.
func getBirthdaysDate(day int, month int) (birthdays []birthdayEntry, err error) {
// getBirthdaysDate return a slice of birthday entries that matches the given guildID and date.
func getBirthdaysDate(guildID string, day int, month int) (birthdays []birthdayEntry, err error) {
var numOfEntries int64
err = database.QueryRow("SELECT COUNT(*) FROM birthdays WHERE day=? AND month=?", day, month).Scan(&numOfEntries)
if err != nil {
Expand All @@ -251,20 +323,21 @@ func getBirthdaysDate(day int, month int) (birthdays []birthdayEntry, err error)
return birthdays, nil
}

rows, err := database.Query("SELECT id,year,visible FROM birthdays WHERE day=? AND month=?", day, month)
rows, err := database.Query("SELECT id,year,visible,guilds FROM birthdays WHERE day=? AND month=?", day, month)
if err != nil {
return birthdays, err
}
defer rows.Close()

for rows.Next() {
b := birthdayEntry{Day: day, Month: month}
err = rows.Scan(&b.ID, &b.Year, &b.Visible)
err = rows.Scan(&b.ID, &b.Year, &b.Visible, &b.GuildIDsRaw)
if err != nil {
return birthdays, err
}
b.ParseGuildIDs()

if !b.Visible {
if !b.Visible || !b.IsInGuild(guildID) {
continue
}

Expand Down
31 changes: 19 additions & 12 deletions modules/birthday/handleCheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,21 +35,16 @@ func Check(s *discordgo.Session) {
defer rows.Close()

now := time.Now()
birthdays, err := getBirthdaysDate(now.Day(), int(now.Month()))
if err != nil {
log.Printf("Error on getting todays birthdays from database: %v\n", err)
}
e, n := birthdayAnnounceEmbed(s, birthdays)
if n <= 0 {
return
}

for rows.Next() {
err = rows.Scan(&guildID, &channelID)
if err != nil {
log.Printf("Error on scanning birthday channel ID from database %v\n", err)
continue
}
if channelID == 0 {
continue
}

channel, err := s.Channel(fmt.Sprint(channelID))
if err != nil {
Expand All @@ -61,6 +56,15 @@ func Check(s *discordgo.Session) {
return
}

birthdays, err := getBirthdaysDate(fmt.Sprint(guildID), now.Day(), int(now.Month()))
if err != nil {
log.Printf("Error on getting todays birthdays from guild %s from database: %v\n", fmt.Sprint(guildID), err)
}
e, n := birthdayAnnounceEmbed(s, fmt.Sprint(guildID), birthdays)
if n <= 0 {
return
}

// announce
_, err = s.ChannelMessageSendEmbed(channel.ID, e)
if err != nil {
Expand All @@ -71,7 +75,7 @@ func Check(s *discordgo.Session) {

// birthdayAnnounceEmbed returns the embed, that contains all birthdays and 'n' as the number of
// birthdays, which is always len(b)
func birthdayAnnounceEmbed(s *discordgo.Session, b []birthdayEntry) (e *discordgo.MessageEmbed, n int) {
func birthdayAnnounceEmbed(s *discordgo.Session, guildID string, b []birthdayEntry) (e *discordgo.MessageEmbed, n int) {
var title, fValue string

switch len(b) {
Expand All @@ -85,14 +89,17 @@ func birthdayAnnounceEmbed(s *discordgo.Session, b []birthdayEntry) (e *discordg
}

for _, b := range b {
mention := fmt.Sprintf("<@%d>", b.ID)
member := util.IsGuildMember(s, guildID, fmt.Sprint(b.ID))
if member == nil {
continue
}

if b.Year == 0 {
fValue += fmt.Sprintf("%s\n", mention)
fValue += fmt.Sprintf("%s\n", member.Mention())
} else {
format := lang.Get(tp+"msg.announce.with_age", lang.FallbackLang())
format += "\n"
fValue += fmt.Sprintf(format, mention, fmt.Sprint(b.Age()))
fValue += fmt.Sprintf(format, member.Mention(), fmt.Sprint(b.Age()))
}
}

Expand Down
6 changes: 3 additions & 3 deletions modules/birthday/handlerSubcommandAnnounce.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ func (cmd Chat) subcommandAnnounce() subcommandAnnounce {

func (cmd subcommandAnnounce) handler() {
now := time.Now()
b, err := getBirthdaysDate(now.Day(), int(now.Month()))
b, err := getBirthdaysDate(cmd.Interaction.GuildID, now.Day(), int(now.Month()))
if err != nil {
log.Printf("Error on announce birthday: %v\n", err)
log.Printf("Error on announce birthday in guild %s: %v\n", cmd.Interaction.GuildID, err)
cmd.ReplyError()
return
}

e, n := birthdayAnnounceEmbed(cmd.Session, b)
e, n := birthdayAnnounceEmbed(cmd.Session, cmd.Interaction.GuildID, b)

if n <= 0 {
cmd.ReplyHiddenEmbed(e)
Expand Down
4 changes: 2 additions & 2 deletions modules/birthday/handlerSubcommandList.go
Original file line number Diff line number Diff line change
Expand Up @@ -49,9 +49,9 @@ func (cmd subcommandList) handler() {
}
month := int(cmd.month.IntValue())

birthdays, err := cmd.getBirthdaysMonth(month)
birthdays, err := cmd.getBirthdaysMonth(cmd.Interaction.GuildID, month)
if err != nil {
log.Printf("Error on get birthdays by month: %v\n", err)
log.Printf("Error on get birthdays by month from guild %s: %v\n", cmd.Interaction.GuildID, err)
cmd.ReplyError()
return
}
Expand Down
6 changes: 3 additions & 3 deletions modules/birthday/handlerSubcommandSet.go
Original file line number Diff line number Diff line change
Expand Up @@ -121,7 +121,7 @@ func (cmd subcommandSet) interactionHandler() {
return
}

b := birthdayEntry{
b := &birthdayEntry{
ID: authorID,
Day: int(cmd.day.IntValue()),
Month: int(cmd.month.IntValue()),
Expand Down Expand Up @@ -194,13 +194,13 @@ func (cmd subcommandSet) interactionHandler() {
}

// seperate handler for an update of the birthday
func (cmd subcommandSet) handleUpdate(b birthdayEntry, e *discordgo.MessageEmbed) error {
func (cmd subcommandSet) handleUpdate(b *birthdayEntry, e *discordgo.MessageEmbed) (err error) {
before, err := cmd.updateBirthday(b)
if err != nil {
return err
}

if b == before {
if b.IsEqual(before) {
var age string
if b.Year > 0 {
age = fmt.Sprintf(" (%d)", b.Age()+1)
Expand Down
21 changes: 21 additions & 0 deletions util/discord.go
Original file line number Diff line number Diff line change
Expand Up @@ -437,3 +437,24 @@ func MessageComplexWebhookEdit(src any) *discordgo.WebhookEdit {
panic("Given source type is not supported: " + fmt.Sprintf("%T", src))
}
}

// IsGuildMember returns the given user as a member of the given guild. If the
// user is not a member of the guild IsGuildMember returns nil.
func IsGuildMember(s *discordgo.Session, guildID, userID string) (member *discordgo.Member) {
member, err := s.State.Member(guildID, userID)
if err == nil {
return member
} else if err != discordgo.ErrStateNotFound {
log.Printf("ERROR: Failed to get guild member from cache (G: %s, U: %s): %v\n", guildID, userID, err)
}
member, err = s.GuildMember(guildID, userID)
if err == nil {
return member
}

var restErr *discordgo.RESTError
if !errors.As(err, &restErr) || restErr.Response.StatusCode != http.StatusNotFound {
log.Printf("ERROR: Failed to get guild member from API (G: %s, U: %s): %v\n", guildID, userID, err)
}
return nil
}

0 comments on commit b02090c

Please sign in to comment.