Skip to content

Commit

Permalink
login: add initial support for logging in as slack app
Browse files Browse the repository at this point in the history
  • Loading branch information
tulir committed Aug 3, 2024
1 parent 110dbcf commit 3a8d518
Show file tree
Hide file tree
Showing 12 changed files with 372 additions and 83 deletions.
4 changes: 2 additions & 2 deletions go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ require (
go.mau.fi/util v0.6.1-0.20240802175451-b430ebbffc98
golang.org/x/net v0.27.0
gopkg.in/yaml.v3 v3.0.1
maunium.net/go/mautrix v0.19.1-0.20240803150944-b71b32d0d6d6
maunium.net/go/mautrix v0.19.1-0.20240803190639-956c13761ebb
)

require (
Expand All @@ -38,4 +38,4 @@ require (
maunium.net/go/mauflag v1.0.0 // indirect
)

replace github.com/slack-go/slack => github.com/beeper/slackgo v0.0.0-20240803155237-c586cd47cbb5
replace github.com/slack-go/slack => github.com/beeper/slackgo v0.0.0-20240803190730-362662f1319d
8 changes: 4 additions & 4 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
github.com/DATA-DOG/go-sqlmock v1.5.2 h1:OcvFkGmslmlZibjAjaHm3L//6LiuBgolP7OputlJIzU=
github.com/DATA-DOG/go-sqlmock v1.5.2/go.mod h1:88MAG/4G7SMwSE3CeA0ZKzrT5CiOU3OJ+JlNzwDqpNU=
github.com/beeper/slackgo v0.0.0-20240803155237-c586cd47cbb5 h1:UgZR7wNGPl+kp54v6LktkRgThPFcrzIX3+Slpx5dy5E=
github.com/beeper/slackgo v0.0.0-20240803155237-c586cd47cbb5/go.mod h1:K+6JA6FP9/mILahVr6VH67l83p0sWkayPiDOBhzKWlo=
github.com/beeper/slackgo v0.0.0-20240803190730-362662f1319d h1:8wvWykAoc2uJX9MmYueb/vCSUcg3KYX3QvMknUed8xE=
github.com/beeper/slackgo v0.0.0-20240803190730-362662f1319d/go.mod h1:K+6JA6FP9/mILahVr6VH67l83p0sWkayPiDOBhzKWlo=
github.com/coreos/go-systemd/v22 v22.5.0 h1:RrqgGjYQKalulkV8NGVIfkXQf6YYmOyiJKk8iXXhfZs=
github.com/coreos/go-systemd/v22 v22.5.0/go.mod h1:Y58oyj3AT4RCenI/lSvhwexgC+NSVTIJ3seZv2GcEnc=
github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c=
Expand Down Expand Up @@ -71,5 +71,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/mauflag v1.0.0 h1:YiaRc0tEI3toYtJMRIfjP+jklH45uDHtT80nUamyD4M=
maunium.net/go/mauflag v1.0.0/go.mod h1:nLivPOpTpHnpzEh8jEdSL9UqO9+/KBJFmNRlwKfkPeA=
maunium.net/go/mautrix v0.19.1-0.20240803150944-b71b32d0d6d6 h1:ZXMvdnZ/oZk4kFACRaApV4BpbUqueJCgalHZXHKpu0I=
maunium.net/go/mautrix v0.19.1-0.20240803150944-b71b32d0d6d6/go.mod h1:ZWyxoQxRTBxzWIMs0kQCVogZIY0clTu33h102veCT/Q=
maunium.net/go/mautrix v0.19.1-0.20240803190639-956c13761ebb h1:Vk9NX4DjDXbBjxw/A9DGKaCsI1VcQbATl2XY0LQC/gw=
maunium.net/go/mautrix v0.19.1-0.20240803190639-956c13761ebb/go.mod h1:ZWyxoQxRTBxzWIMs0kQCVogZIY0clTu33h102veCT/Q=
2 changes: 1 addition & 1 deletion pkg/connector/backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ var _ bridgev2.BackfillingNetworkAPI = (*SlackClient)(nil)

func (s *SlackClient) FetchMessages(ctx context.Context, params bridgev2.FetchMessagesParams) (*bridgev2.FetchMessagesResponse, error) {
if s.Client == nil {
return nil, fmt.Errorf("not logged in")
return nil, bridgev2.ErrNotLoggedIn
}
_, channelID := slackid.ParsePortalID(params.Portal.ID)
if channelID == "" {
Expand Down
5 changes: 3 additions & 2 deletions pkg/connector/chatinfo.go
Original file line number Diff line number Diff line change
Expand Up @@ -355,8 +355,7 @@ func (s *SlackClient) fetchUserInfo(ctx context.Context, userID string, lastUpda
botInfo, err = s.Client.GetBotInfoContext(ctx, slack.GetBotInfoParameters{
Bot: userID,
})
} else {
//info, err = s.Client.GetUserInfoContext(ctx, userID)
} else if s.IsRealUser {
var infos map[string]*slack.User
infos, err = s.Client.GetUsersCacheContext(ctx, s.TeamID, slack.GetCachedUsersParameters{
CheckInteraction: true,
Expand All @@ -372,6 +371,8 @@ func (s *SlackClient) fetchUserInfo(ctx context.Context, userID string, lastUpda
return nil, nil
}
}
} else {
info, err = s.Client.GetUserInfoContext(ctx, userID)
}
if err != nil {
return nil, fmt.Errorf("failed to get user info for %q: %w", userID, err)
Expand Down
190 changes: 152 additions & 38 deletions pkg/connector/client.go
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ import (

"github.com/rs/zerolog"
"github.com/slack-go/slack"
"github.com/slack-go/slack/socketmode"
"maunium.net/go/mautrix/bridge/status"
"maunium.net/go/mautrix/bridgev2"
"maunium.net/go/mautrix/bridgev2/networkid"
Expand All @@ -44,13 +45,15 @@ func init() {
})
}

func makeSlackClient(log *zerolog.Logger, token, cookieToken string) *slack.Client {
func makeSlackClient(log *zerolog.Logger, token, cookieToken, appToken string) *slack.Client {
options := []slack.Option{
slack.OptionLog(slackgoZerolog{Logger: log.With().Str("component", "slackgo").Logger()}),
slack.OptionDebug(log.GetLevel() == zerolog.TraceLevel),
}
if cookieToken != "" {
options = append(options, slack.OptionCookie("d", cookieToken))
} else if appToken != "" {
options = append(options, slack.OptionAppLevelToken(appToken))
}
return slack.New(token, options...)
}
Expand All @@ -62,18 +65,28 @@ func (s *SlackConnector) LoadUserLogin(ctx context.Context, login *bridgev2.User
if meta.Token == "" {
sc = &SlackClient{Main: s, UserLogin: login, UserID: userID, TeamID: teamID}
} else {
client := makeSlackClient(&login.Log, meta.Token, meta.CookieToken)
client := makeSlackClient(&login.Log, meta.Token, meta.CookieToken, meta.AppToken)
sc = &SlackClient{
Main: s,
UserLogin: login,
Client: client,
RTM: client.NewRTM(),
UserID: userID,
TeamID: teamID,
Main: s,
UserLogin: login,
Client: client,
UserID: userID,
TeamID: teamID,
IsRealUser: strings.HasPrefix(meta.Token, "xoxs-") || strings.HasPrefix(meta.Token, "xoxc-"),

chatInfoCache: make(map[string]chatInfoCacheEntry),
lastReadCache: make(map[string]string),
}
if sc.IsRealUser {
sc.RTM = client.NewRTM()
} else {
log := login.Log.With().Str("component", "slackgo socketmode").Logger()
sc.SocketMode = socketmode.New(
sc.Client,
socketmode.OptionLog(slackgoZerolog{Logger: log}),
socketmode.OptionDebug(log.GetLevel() == zerolog.TraceLevel),
)
}
}
teamPortalKey := networkid.PortalKey{ID: slackid.MakeTeamPortalID(teamID)}
var err error
Expand All @@ -95,10 +108,14 @@ type SlackClient struct {
UserLogin *bridgev2.UserLogin
Client *slack.Client
RTM *slack.RTM
SocketMode *socketmode.Client
UserID string
TeamID string
BootResp *slack.ClientBootResponse
TeamPortal *bridgev2.Portal
IsRealUser bool

stopSocketMode context.CancelFunc

chatInfoCache map[string]chatInfoCacheEntry
chatInfoCacheLock sync.Mutex
Expand All @@ -116,6 +133,21 @@ func (s *SlackClient) GetClient() *slack.Client {
return s.Client
}

func (s *SlackClient) handleBootError(ctx context.Context, err error) {
if err.Error() == "user_removed_from_team" || err.Error() == "invalid_auth" {
s.invalidateSession(ctx, status.BridgeState{
StateEvent: status.StateBadCredentials,
Error: status.BridgeStateErrorCode(fmt.Sprintf("slack-%s", strings.ReplaceAll(err.Error(), "_", "-"))),
})
} else {
s.UserLogin.BridgeState.Send(status.BridgeState{
StateEvent: status.StateUnknownError,
Error: "slack-unknown-fetch-error",
Message: fmt.Sprintf("Unknown error from Slack: %s", err.Error()),
})
}
}

func (s *SlackClient) Connect(ctx context.Context) error {
if s.Client == nil {
s.UserLogin.BridgeState.Send(status.BridgeState{
Expand All @@ -124,21 +156,29 @@ func (s *SlackClient) Connect(ctx context.Context) error {
})
return nil
}
bootResp, err := s.Client.ClientBootContext(ctx)
if err != nil {
if err.Error() == "user_removed_from_team" || err.Error() == "invalid_auth" {
s.invalidateSession(ctx, status.BridgeState{
StateEvent: status.StateBadCredentials,
Error: status.BridgeStateErrorCode(fmt.Sprintf("slack-%s", strings.ReplaceAll(err.Error(), "_", "-"))),
})
} else {
s.UserLogin.BridgeState.Send(status.BridgeState{
StateEvent: status.StateUnknownError,
Error: "slack-unknown-fetch-error",
Message: fmt.Sprintf("Unknown error from Slack: %s", err.Error()),
})
var bootResp *slack.ClientBootResponse
if s.IsRealUser {
var err error
bootResp, err = s.Client.ClientBootContext(ctx)
if err != nil {
s.handleBootError(ctx, err)
return err
}
} else {
teamResp, err := s.Client.GetTeamInfoContext(ctx)
if err != nil {
s.handleBootError(ctx, err)
return fmt.Errorf("failed to fetch team info: %w", err)
}
userResp, err := s.Client.GetUserInfoContext(ctx, s.UserID)
if err != nil {
s.handleBootError(ctx, err)
return fmt.Errorf("failed to fetch user info: %w", err)
}
bootResp = &slack.ClientBootResponse{
Self: *userResp,
Team: *teamResp,
}
return err
}
return s.connect(ctx, bootResp)
}
Expand All @@ -149,13 +189,53 @@ func (s *SlackClient) connect(ctx context.Context, bootResp *slack.ClientBootRes
if err != nil {
return err
}
go s.consumeEvents()
go s.RTM.ManageConnection()
if s.IsRealUser {
go s.consumeRTMEvents()
go s.RTM.ManageConnection()
} else {
go s.consumeSocketModeEvents()
go s.runSocketMode(ctx)
}
go s.SyncEmojis(ctx)
go s.SyncChannels(ctx)
return nil
}

func (s *SlackClient) consumeRTMEvents() {
for evt := range s.RTM.IncomingEvents {
s.HandleSlackEvent(evt.Data)
}
}

func (s *SlackClient) consumeSocketModeEvents() {
for evt := range s.SocketMode.Events {
s.HandleSocketModeEvent(evt)
}
}

func (s *SlackClient) runSocketMode(ctx context.Context) {
var cancel context.CancelFunc
ctx, cancel = context.WithCancel(ctx)
defer cancel()
s.stopSocketMode = cancel
log := zerolog.Ctx(ctx)
for ctx.Err() == nil {
err := s.SocketMode.RunContext(ctx)
if err != nil {
log.Err(err).Msg("Error in socket mode connection")
s.UserLogin.BridgeState.Send(status.BridgeState{
StateEvent: status.StateTransientDisconnect,
Error: "slack-socketmode-error",
Message: err.Error(),
})
time.Sleep(10 * time.Second)
} else {
log.Info().Msg("Socket disconnected without error")
return
}
}
}

func (s *SlackClient) syncTeamPortal(ctx context.Context) error {
info := s.getTeamInfo()
if s.TeamPortal.MXID == "" {
Expand All @@ -181,7 +261,10 @@ func (s *SlackClient) getLastReadCache(channelID string) string {
return s.lastReadCache[channelID]
}

func (s *SlackClient) SyncChannels(ctx context.Context) {
func (s *SlackClient) getLatestMessageIDs(ctx context.Context) map[string]string {
if !s.IsRealUser {
return nil
}
log := zerolog.Ctx(ctx)
clientCounts, err := s.Client.ClientCountsContext(ctx, &slack.ClientCountsParams{
ThreadCountsByChannel: true,
Expand All @@ -190,7 +273,7 @@ func (s *SlackClient) SyncChannels(ctx context.Context) {
})
if err != nil {
log.Err(err).Msg("Failed to fetch client counts")
return
return nil
}
latestMessageIDs := make(map[string]string, len(clientCounts.Channels)+len(clientCounts.MpIMs)+len(clientCounts.IMs))
lastReadCache := make(map[string]string, len(clientCounts.Channels)+len(clientCounts.MpIMs)+len(clientCounts.IMs))
Expand All @@ -209,6 +292,12 @@ func (s *SlackClient) SyncChannels(ctx context.Context) {
s.lastReadCacheLock.Lock()
s.lastReadCache = lastReadCache
s.lastReadCacheLock.Unlock()
return latestMessageIDs
}

func (s *SlackClient) SyncChannels(ctx context.Context) {
log := zerolog.Ctx(ctx)
latestMessageIDs := s.getLatestMessageIDs(ctx)
userPortals, err := s.UserLogin.Bridge.DB.UserPortal.GetAllForLogin(ctx, s.UserLogin.UserLogin)
if err != nil {
log.Err(err).Msg("Failed to fetch user portals")
Expand All @@ -220,7 +309,7 @@ func (s *SlackClient) SyncChannels(ctx context.Context) {
}
var channels []*slack.Channel
token := s.UserLogin.Metadata.(*slackid.UserLoginMetadata).Token
if strings.HasPrefix(token, "xoxs") || s.Main.Config.Backfill.ConversationCount == -1 {
if s.IsRealUser && (strings.HasPrefix(token, "xoxs-") || s.Main.Config.Backfill.ConversationCount == -1) {
for _, ch := range s.BootResp.Channels {
ch.IsMember = true
channels = append(channels, &ch.Channel)
Expand All @@ -232,6 +321,9 @@ func (s *SlackClient) SyncChannels(ctx context.Context) {
log.Debug().Int("channel_count", len(channels)).Msg("Using channels from boot response for sync")
} else {
totalLimit := s.Main.Config.Backfill.ConversationCount
if totalLimit < 0 {
totalLimit = 50
}
var cursor string
log.Debug().Int("total_limit", totalLimit).Msg("Fetching conversation list for sync")
for totalLimit > 0 {
Expand Down Expand Up @@ -259,18 +351,39 @@ func (s *SlackClient) SyncChannels(ctx context.Context) {
cursor = nextCursor
}
}
slices.SortFunc(channels, func(a, b *slack.Channel) int {
return cmp.Compare(latestMessageIDs[a.ID], latestMessageIDs[b.ID])
})
if latestMessageIDs != nil {
slices.SortFunc(channels, func(a, b *slack.Channel) int {
return cmp.Compare(latestMessageIDs[a.ID], latestMessageIDs[b.ID])
})
}
for _, ch := range channels {
portalKey := s.makePortalKey(ch)
delete(existingPortals, portalKey)
latestMessageID, hasCounts := latestMessageIDs[ch.ID]
var latestMessageID string
var hasCounts bool
if !s.IsRealUser {
ch, err = s.Client.GetConversationInfoContext(ctx, &slack.GetConversationInfoInput{
ChannelID: ch.ID,
IncludeLocale: true,
IncludeNumMembers: true,
})
if err != nil {
log.Err(err).Str("channel_id", ch.ID).Msg("Failed to fetch channel info")
continue
}
hasCounts = ch.Latest != nil
if hasCounts {
latestMessageID = ch.Latest.Timestamp
}
} else {
latestMessageID, hasCounts = latestMessageIDs[ch.ID]
}
// TODO fetch latest message from channel info when using bot account?
s.Main.br.QueueRemoteEvent(s.UserLogin, &SlackChatResync{
SlackEventMeta: &SlackEventMeta{
Type: bridgev2.RemoteEventChatResync,
PortalKey: portalKey,
CreatePortal: hasCounts || !(ch.IsIM || ch.IsMpIM),
CreatePortal: hasCounts || (!ch.IsIM && !ch.IsMpIM),
LogContext: func(c zerolog.Context) zerolog.Context {
return c.
Object("portal_key", portalKey).
Expand Down Expand Up @@ -303,17 +416,18 @@ func (s *SlackClient) SyncChannels(ctx context.Context) {
}
}

func (s *SlackClient) consumeEvents() {
for evt := range s.RTM.IncomingEvents {
s.HandleSlackEvent(evt.Data)
}
}

func (s *SlackClient) Disconnect() {
s.disconnectRTM()
s.disconnectSocketMode()
s.Client = nil
}

func (s *SlackClient) disconnectSocketMode() {
if stop := s.stopSocketMode; stop != nil {
stop()
}
}

func (s *SlackClient) disconnectRTM() {
if rtm := s.RTM; rtm != nil {
err := rtm.Disconnect()
Expand Down
Loading

0 comments on commit 3a8d518

Please sign in to comment.