Parallelise checkUsingKeys #231

Open · wants to merge 4 commits into main
keyring.go: 91 changes (61 additions, 30 deletions)
```diff
@@ -4,6 +4,7 @@ import (
 	"context"
 	"errors"
 	"fmt"
+	"runtime"
 	"strings"
 	"sync"
 	"time"
```
```diff
@@ -325,39 +326,69 @@ func (k *KeyRing) checkUsingKeys(
 	requests []VerifyJSONRequest, results []VerifyJSONResult, keyIDs [][]KeyID,
 	keys map[PublicKeyLookupRequest]PublicKeyLookupResult,
 ) {
+	procs := runtime.NumCPU() - 1
+	if procs < 1 {
+		procs = 1
+	}
+	type job struct {
+		index   int               // the original index in the requests/results array
+		request VerifyJSONRequest // the request itself
+	}
+	jobs := make(map[int][]job)
 	for i := range requests {
 		if results[i].Error == nil {
 			// We've already checked this message and it passed the signature checks.
 			// So we can skip to the next message.
 			continue
 		}
-		for _, keyID := range keyIDs[i] {
-			serverKey, ok := keys[PublicKeyLookupRequest{requests[i].ServerName, keyID}]
-			if !ok {
-				// No key for this key ID so we continue onto the next key ID.
-				continue
-			}
-			if !serverKey.WasValidAt(requests[i].AtTS, requests[i].StrictValidityChecking) {
-				// The key wasn't valid at the timestamp we needed it to be valid at.
-				// So skip onto the next key.
-				results[i].Error = fmt.Errorf(
-					"gomatrixserverlib: key with ID %q for %q not valid at %d",
-					keyID, requests[i].ServerName, requests[i].AtTS,
-				)
-				continue
-			}
-			if err := VerifyJSON(
-				string(requests[i].ServerName), keyID, ed25519.PublicKey(serverKey.Key), requests[i].Message,
-			); err != nil {
-				// The signature wasn't valid, record the error and try the next key ID.
-				results[i].Error = err
-				continue
-			}
+		jobs[i%procs] = append(jobs[i%procs], job{i, requests[i]})
+	}
+	var wg sync.WaitGroup // tracks the workers
+	var mu sync.RWMutex   // protects results array
+	wg.Add(len(jobs))
+	for _, j := range jobs {
+		go func(jobs []job) {
+			for _, j := range jobs {
+				mu.RLock()
+				if results[j.index].Error == nil {
+					// We've already checked this message and it passed the signature checks.
+					// So we can skip to the next message.
+					mu.RUnlock()
+					continue
+				}
+				mu.RUnlock()
+				for _, keyID := range keyIDs[j.index] {
+					serverKey, ok := keys[PublicKeyLookupRequest{j.request.ServerName, keyID}]
+					if !ok {
+						// No key for this key ID so we continue onto the next key ID.
+						continue
+					}
+					if !serverKey.WasValidAt(j.request.AtTS, j.request.StrictValidityChecking) {
+						// The key wasn't valid at the timestamp we needed it to be valid at.
+						// So skip onto the next key.
+						mu.Lock()
+						results[j.index].Error = fmt.Errorf(
+							"gomatrixserverlib: key with ID %q for %q not valid at %d",
+							keyID, j.request.ServerName, j.request.AtTS,
+						)
+						mu.Unlock()
```
@kegsay (Member) commented on Oct 5, 2020:

I'm not a fan of the shared variables across goroutines; it feels like a code smell compared to sharing data via channels. I'm aware we do this to maintain the index into the results array.

I think a better solution would be to add the parallelisation around VerifyJSON only, so that each worker is passed exactly the data it uses and never needs to touch the shared slices and maps. Something like:
```go
type verifyJSONReq struct {
	i          int
	serverName ServerName
	keyID      KeyID
	pubKey     ed25519.PublicKey
	msg        []byte
}
type verifyJSONRes struct {
	i   int
	err error
}
reqCh := make(chan verifyJSONReq, 50)
resCh := make(chan verifyJSONRes, 50)
var wg sync.WaitGroup
wg.Add(procs)
for i := 0; i < procs; i++ {
	go func() {
		defer wg.Done()
		for item := range reqCh {
			err := VerifyJSON(string(item.serverName), item.keyID, item.pubKey, item.msg)
			resCh <- verifyJSONRes{item.i, err}
		}
	}()
}
for i := range requests {
	// insert code that was there previously up to VerifyJSON
	reqCh <- verifyJSONReq{i, other, func, args}
}
close(reqCh) // kill the goroutines after they process everything
// close the response channel once we've got all the responses
go func() {
	wg.Wait()
	close(resCh)
}()
for res := range resCh {
	results[res.i].Error = res.err
}
```
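
For reference, here is a self-contained, runnable sketch of that fan-out/fan-in shape. Everything in it is a hypothetical stand-in: `verify` plays the role of gomatrixserverlib's VerifyJSON and plain byte slices play the role of requests; only the channel plumbing follows the suggestion above.

```go
package main

import (
	"errors"
	"fmt"
	"runtime"
	"sync"
)

// verifyReq carries the index alongside the data one verification needs,
// so workers never touch shared slices.
type verifyReq struct {
	i   int
	msg []byte
}

// verifyRes pairs the index with the verification outcome.
type verifyRes struct {
	i   int
	err error
}

// verify is a hypothetical stand-in for VerifyJSON.
func verify(msg []byte) error {
	if len(msg)%2 == 0 {
		return errors.New("bad signature")
	}
	return nil
}

func main() {
	msgs := [][]byte{[]byte("a"), []byte("bb"), []byte("ccc")}
	results := make([]error, len(msgs))

	reqCh := make(chan verifyReq, len(msgs))
	resCh := make(chan verifyRes, len(msgs))

	var wg sync.WaitGroup
	for p := 0; p < runtime.NumCPU(); p++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			for item := range reqCh {
				resCh <- verifyRes{item.i, verify(item.msg)}
			}
		}()
	}

	for i, msg := range msgs {
		reqCh <- verifyReq{i, msg}
	}
	close(reqCh) // workers exit once the queue drains

	go func() {
		wg.Wait()
		close(resCh) // ends the collection loop below
	}()

	// Only this goroutine writes to results, so no mutex is required:
	// the index travels out with the request and back with the response.
	for res := range resCh {
		results[res.i] = res.err
	}
	fmt.Println(results)
}
```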

```diff
+						continue
+					}
+					if err := VerifyJSON(
+						string(j.request.ServerName), keyID, ed25519.PublicKey(serverKey.Key), j.request.Message,
+					); err != nil {
+						// The signature wasn't valid, record the error and try the next key ID.
+						mu.Lock()
+						results[j.index].Error = err
+						mu.Unlock()
+						continue
+					}
+					// The signature is valid, set the result to nil.
+					mu.Lock()
+					results[j.index].Error = nil
+					mu.Unlock()
+					break
+				}
+			}
-			// The signature is valid, set the result to nil.
-			results[i].Error = nil
-			break
-		}
+			wg.Done()
+		}(j)
 	}
+	wg.Wait()
 }
 
 type KeyClient interface {
```
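
For comparison, the approach the diff itself takes reduces to the following standalone sketch, under hypothetical names: bucket the work by index modulo the worker count, then give each goroutine its own bucket. Squaring an integer stands in for the signature check. Each index lands in exactly one bucket, so the toy writes never collide and need no lock; the patch above additionally guards the results slice with a sync.RWMutex.

```go
package main

import (
	"fmt"
	"runtime"
	"sync"
)

func main() {
	inputs := []int{1, 2, 3, 4, 5, 6, 7, 8}
	results := make([]int, len(inputs))

	// Leave one CPU free, as the patch does, but keep at least one worker.
	procs := runtime.NumCPU() - 1
	if procs < 1 {
		procs = 1
	}

	type job struct {
		index int // position in inputs/results
		value int
	}

	// Bucket job i into bucket i % procs, mirroring the patch's jobs map.
	buckets := make(map[int][]job)
	for i, v := range inputs {
		buckets[i%procs] = append(buckets[i%procs], job{i, v})
	}

	var wg sync.WaitGroup
	wg.Add(len(buckets))
	for _, bucket := range buckets {
		go func(jobs []job) {
			defer wg.Done()
			for _, j := range jobs {
				// Distinct indices per bucket: these writes never collide.
				results[j.index] = j.value * j.value
			}
		}(bucket)
	}
	wg.Wait()

	fmt.Println(results) // [1 4 9 16 25 36 49 64]
}
```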