Skip to content

Commit

Permalink
Signing identities for FederationClient
Browse files Browse the repository at this point in the history
  • Loading branch information
neilalexander committed Nov 15, 2022
1 parent 715dc88 commit 900369e
Show file tree
Hide file tree
Showing 7 changed files with 117 additions and 91 deletions.
9 changes: 5 additions & 4 deletions authstate.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,17 +17,18 @@ type StateProvider interface {

type FederatedStateClient interface {
LookupState(
ctx context.Context, s ServerName, roomID, eventID string, roomVersion RoomVersion,
ctx context.Context, origin, s ServerName, roomID, eventID string, roomVersion RoomVersion,
) (res RespState, err error)
LookupStateIDs(
ctx context.Context, s ServerName, roomID, eventID string,
ctx context.Context, origin, s ServerName, roomID, eventID string,
) (res RespStateIDs, err error)
}

// FederatedStateProvider is an implementation of StateProvider which solely uses federation requests to retrieve events.
type FederatedStateProvider struct {
FedClient FederatedStateClient
// The remote server to ask.
Origin ServerName
Server ServerName
// Set to true to remember the auth event IDs for the room at various states
RememberAuthEvents bool
Expand All @@ -38,7 +39,7 @@ type FederatedStateProvider struct {

// StateIDsBeforeEvent implements StateProvider
func (p *FederatedStateProvider) StateIDsBeforeEvent(ctx context.Context, event *HeaderedEvent) ([]string, error) {
res, err := p.FedClient.LookupStateIDs(ctx, p.Server, event.RoomID(), event.EventID())
res, err := p.FedClient.LookupStateIDs(ctx, p.Origin, p.Server, event.RoomID(), event.EventID())
if err != nil {
return nil, err
}
Expand All @@ -50,7 +51,7 @@ func (p *FederatedStateProvider) StateIDsBeforeEvent(ctx context.Context, event

// StateBeforeEvent implements StateProvider
func (p *FederatedStateProvider) StateBeforeEvent(ctx context.Context, roomVer RoomVersion, event *HeaderedEvent, eventIDs []string) (map[string]*Event, error) {
res, err := p.FedClient.LookupState(ctx, p.Server, event.RoomID(), event.EventID(), roomVer)
res, err := p.FedClient.LookupState(ctx, p.Origin, p.Server, event.RoomID(), event.EventID(), roomVer)
if err != nil {
return nil, err
}
Expand Down
6 changes: 3 additions & 3 deletions backfill.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ import (
type BackfillClient interface {
// Backfill performs a backfill request to the given server.
// https://matrix.org/docs/spec/server_server/latest#get-matrix-federation-v1-backfill-roomid
Backfill(ctx context.Context, server ServerName, roomID string, limit int, fromEventIDs []string) (Transaction, error)
Backfill(ctx context.Context, origin, server ServerName, roomID string, limit int, fromEventIDs []string) (Transaction, error)
}

// BackfillRequester contains the necessary functions to perform backfill requests from one server to another.
Expand Down Expand Up @@ -45,7 +45,7 @@ type BackfillRequester interface {
// but to verify it we need to know the prev_events of fromEventIDs.
//
// TODO: When does it make sense to return errors?
func RequestBackfill(ctx context.Context, b BackfillRequester, keyRing JSONVerifier,
func RequestBackfill(ctx context.Context, origin ServerName, b BackfillRequester, keyRing JSONVerifier,
roomID string, ver RoomVersion, fromEventIDs []string, limit int) ([]*HeaderedEvent, error) {

if len(fromEventIDs) == 0 {
Expand All @@ -67,7 +67,7 @@ func RequestBackfill(ctx context.Context, b BackfillRequester, keyRing JSONVerif
return nil, fmt.Errorf("gomatrixserverlib: RequestBackfill context cancelled %w", ctx.Err())
}
// fetch some events, and try a different server if it fails
txn, err := b.Backfill(ctx, s, roomID, limit, fromEventIDs)
txn, err := b.Backfill(ctx, origin, s, roomID, limit, fromEventIDs)
if err != nil {
continue // try the next server
}
Expand Down
20 changes: 10 additions & 10 deletions backfill_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ import (

type testBackfillRequester struct {
servers []ServerName
backfillFn func(server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error)
backfillFn func(origin, server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error)
authEventsToProvide [][]byte
stateIDsAtEvent map[string][]string
callOrderForStateIDsBeforeEvent []string // event IDs called
Expand All @@ -28,8 +28,8 @@ func (t *testBackfillRequester) StateBeforeEvent(ctx context.Context, roomVer Ro
func (t *testBackfillRequester) ServersAtEvent(ctx context.Context, roomID, eventID string) []ServerName {
return t.servers
}
func (t *testBackfillRequester) Backfill(ctx context.Context, server ServerName, roomID string, limit int, fromEventIDs []string) (Transaction, error) {
txn, err := t.backfillFn(server, roomID, fromEventIDs, limit)
func (t *testBackfillRequester) Backfill(ctx context.Context, origin, server ServerName, roomID string, limit int, fromEventIDs []string) (Transaction, error) {
txn, err := t.backfillFn(origin, server, roomID, fromEventIDs, limit)
if err != nil {
return Transaction{}, err
}
Expand Down Expand Up @@ -92,14 +92,14 @@ func TestRequestBackfillMultipleServers(t *testing.T) {
"$fnwGrQEpiOIUoDU2:baba.is.you": {"$WCraVpPZe5TtHAqs:baba.is.you"},
"$WCraVpPZe5TtHAqs:baba.is.you": nil,
},
backfillFn: func(server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error) {
backfillFn: func(origin, server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error) {
if roomID != testRoomID {
return nil, fmt.Errorf("bad room id: %s", roomID)
}
if server == serverA {
// server A returns events 1 and 3.
return &Transaction{
Origin: serverA,
Origin: origin,
OriginServerTS: AsTimestamp(time.Now()),
PDUs: []json.RawMessage{
testBackfillEvents[1], testBackfillEvents[3],
Expand All @@ -108,7 +108,7 @@ func TestRequestBackfillMultipleServers(t *testing.T) {
} else if server == serverB {
// server B returns events 0 and 2 and 3.
return &Transaction{
Origin: serverB,
Origin: origin,
OriginServerTS: AsTimestamp(time.Now()),
PDUs: []json.RawMessage{
testBackfillEvents[0], testBackfillEvents[2], testBackfillEvents[3],
Expand All @@ -118,7 +118,7 @@ func TestRequestBackfillMultipleServers(t *testing.T) {
return nil, fmt.Errorf("bad server name: %s", server)
},
}
result, err := RequestBackfill(ctx, tbr, keyRing, testRoomID, RoomVersionV1, testFromEventIDs, testLimit)
result, err := RequestBackfill(ctx, serverA, tbr, keyRing, testRoomID, RoomVersionV1, testFromEventIDs, testLimit)
if err != nil {
t.Fatalf("RequestBackfill got error: %s", err)
}
Expand Down Expand Up @@ -157,13 +157,13 @@ func TestRequestBackfillTopologicalSort(t *testing.T) {
"$fnwGrQEpiOIUoDU2:baba.is.you": {"$WCraVpPZe5TtHAqs:baba.is.you"},
"$WCraVpPZe5TtHAqs:baba.is.you": nil,
},
backfillFn: func(server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error) {
backfillFn: func(origin, server ServerName, roomID string, fromEventIDs []string, limit int) (*Transaction, error) {
if roomID != testRoomID {
return nil, fmt.Errorf("bad room id: %s", roomID)
}
if server == serverA {
return &Transaction{
Origin: serverA,
Origin: origin,
OriginServerTS: AsTimestamp(time.Now()),
PDUs: []json.RawMessage{
testBackfillEvents[0], testBackfillEvents[1], testBackfillEvents[2], testBackfillEvents[3],
Expand All @@ -173,7 +173,7 @@ func TestRequestBackfillTopologicalSort(t *testing.T) {
return nil, fmt.Errorf("bad server name: %s", server)
},
}
result, err := RequestBackfill(ctx, tbr, keyRing, testRoomID, RoomVersionV1, testFromEventIDs, testLimit)
result, err := RequestBackfill(ctx, serverA, tbr, keyRing, testRoomID, RoomVersionV1, testFromEventIDs, testLimit)
if err != nil {
t.Fatalf("RequestBackfill got error: %s", err)
}
Expand Down
Loading

0 comments on commit 900369e

Please sign in to comment.