Skip to content

Commit

Permalink
Use transaction
Browse files Browse the repository at this point in the history
  • Loading branch information
lukedirtwalker committed Dec 18, 2018
1 parent f1e4771 commit 4285db0
Show file tree
Hide file tree
Showing 5 changed files with 116 additions and 61 deletions.
45 changes: 18 additions & 27 deletions go/cert_srv/internal/csconfig/customers.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,11 @@
package csconfig

import (
"bytes"
"context"
"fmt"
"path/filepath"
"regexp"
"strconv"
"sync"
"time"

"github.com/scionproto/scion/go/lib/addr"
Expand All @@ -43,7 +41,6 @@ var reCustVerKey = regexp.MustCompile(`^(ISD\S+-AS\S+)-V(\d+)\.key$`)
// Customers is a mapping from non-core ASes assigned to this core AS to their public
// verifying key.
type Customers struct {
m sync.RWMutex
// trustDB is the trust database.
trustDB trustdb.TrustDB
}
Expand Down Expand Up @@ -86,7 +83,15 @@ func (c *Customers) loadCustomers(stateDir string) error {
if err != nil {
return common.NewBasicError("Unable to load key", err, "file", file)
}
err = c.trustDB.InsertCustKey(ctx, ia, activeVers[ia], key)
_, dbV, err := c.trustDB.GetCustKey(ctx, ia)
if err != nil {
return common.NewBasicError("Failed to check DB cust key", err, "ia", ia)
}
if dbV >= activeVers[ia] {
// db already contains a newer key.
continue
}
err = c.trustDB.InsertCustKey(ctx, ia, activeVers[ia], key, dbV)
if err != nil {
return common.NewBasicError("Failed to save customer key", err, "file", file)
}
Expand All @@ -96,36 +101,22 @@ func (c *Customers) loadCustomers(stateDir string) error {

// GetVerifyingKey returns the verifying key from the requested AS and nil if it is in the mapping.
// Otherwise, nil and an error.
func (c *Customers) GetVerifyingKey(ctx context.Context, ia addr.IA) (common.RawBytes, error) {
c.m.RLock()
defer c.m.RUnlock()
return c.getVerifyingKey(ctx, ia)
}
func (c *Customers) GetVerifyingKey(ctx context.Context,
ia addr.IA) (common.RawBytes, uint64, error) {

func (c *Customers) getVerifyingKey(ctx context.Context, ia addr.IA) (common.RawBytes, error) {
k, err := c.trustDB.GetCustKey(ctx, ia)
k, v, err := c.trustDB.GetCustKey(ctx, ia)
if err != nil {
return nil, err
return nil, 0, err
}
if k == nil {
return nil, common.NewBasicError(NotACustomer, nil, "ISD-AS", ia)
return nil, 0, common.NewBasicError(NotACustomer, nil, "ISD-AS", ia)
}
return k, nil
return k, v, nil
}

// SetVerifyingKey sets the verifying key for a specified AS. The key is written to the file system.
func (c *Customers) SetVerifyingKey(ctx context.Context, ia addr.IA, ver uint64,
newKey, oldKey common.RawBytes) error {
func (c *Customers) SetVerifyingKey(ctx context.Context, tx trustdb.Transaction,
ia addr.IA, newVer, oldVer uint64, newKey, oldKey common.RawBytes) error {

c.m.Lock()
defer c.m.Unlock()
currKey, err := c.getVerifyingKey(ctx, ia)
if err != nil {
return err
}
// Check that the key has not changed in the mean time
if !bytes.Equal(currKey, oldKey) {
return common.NewBasicError(KeyChanged, nil, "ISD-AS", ia)
}
return c.trustDB.InsertCustKey(ctx, ia, ver, newKey)
return tx.InsertCustKey(ctx, ia, newVer, newKey, oldVer)
}
23 changes: 17 additions & 6 deletions go/cert_srv/internal/reiss/handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,7 @@ func (h *Handler) handle(r *infra.Request, addr *snet.Addr, req *cert_mgmt.Chain
return h.sendRep(ctx, addr, maxChain, r.ID)
}
// Get the verifying key from the customer mapping
verKey, err := h.State.Customers.GetVerifyingKey(ctx, addr.IA)
verKey, verVersion, err := h.State.Customers.GetVerifyingKey(ctx, addr.IA)
if err != nil {
return common.NewBasicError("Unable to get verifying key", err)
}
Expand All @@ -94,7 +94,7 @@ func (h *Handler) handle(r *infra.Request, addr *snet.Addr, req *cert_mgmt.Chain
return common.NewBasicError("Unable to verify request", err)
}
// Issue certificate chain
newChain, err := h.issueChain(ctx, crt, verKey)
newChain, err := h.issueChain(ctx, crt, verKey, verVersion)
if err != nil {
return common.NewBasicError("Unable to reissue certificate chain", err)
}
Expand Down Expand Up @@ -143,7 +143,7 @@ func (h *Handler) validateReq(c *cert.Certificate, vKey common.RawBytes,
vChain.Leaf.Subject, "sub", c.Subject)
}
if maxChain.Leaf.Version+1 != c.Version {
return common.NewBasicError("Invalid version", nil, "expected", maxChain.Leaf.Version,
return common.NewBasicError("Invalid version", nil, "expected", maxChain.Leaf.Version+1,
"actual", c.Version)
}
if !c.Issuer.Eq(h.IA) {
Expand All @@ -162,7 +162,7 @@ func (h *Handler) validateReq(c *cert.Certificate, vKey common.RawBytes,
// issueChain creates a certificate chain for the certificate and adds it to the
// trust store.
func (h *Handler) issueChain(ctx context.Context, c *cert.Certificate,
vKey common.RawBytes) (*cert.Chain, error) {
vKey common.RawBytes, verVersion uint64) (*cert.Chain, error) {

issCert, err := h.getIssuerCert(ctx)
if err != nil {
Expand All @@ -184,15 +184,26 @@ func (h *Handler) issueChain(ctx context.Context, c *cert.Certificate,
if err != nil {
return nil, err
}
tx, err := h.State.TrustDB.BeginTransaction(ctx, nil)
if err != nil {
return nil, common.NewBasicError("Failed to create transaction", err)
}
// Set verifying key.
err = h.State.Customers.SetVerifyingKey(ctx, c.Subject, c.Version, c.SubjectSignKey, vKey)
err = h.State.Customers.SetVerifyingKey(ctx, tx, c.Subject, c.Version, verVersion,
c.SubjectSignKey, vKey)
if err != nil {
tx.Rollback()
return nil, err
}
if _, err = h.State.TrustDB.InsertChain(ctx, chain); err != nil {
if _, err = tx.InsertChain(ctx, chain); err != nil {
tx.Rollback()
log.Error("[ReissHandler] Unable to write reissued certificate chain to disk", "err", err)
return nil, err
}
err = tx.Commit()
if err != nil {
return nil, common.NewBasicError("Failed to commit transaction", err)
}
return chain, nil
}

Expand Down
16 changes: 10 additions & 6 deletions go/lib/infra/modules/trust/trustdb/trustdb.go
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,9 @@ type Read interface {
GetTRCMaxVersion(ctx context.Context, isd addr.ISD) (*trc.TRC, error)
// GetAllTRCs fetches all TRCs from the database.
GetAllTRCs(ctx context.Context) ([]*trc.TRC, error)
// GetCustKey gets the customer signing key for the given AS in the latest version.
GetCustKey(ctx context.Context, ia addr.IA) (common.RawBytes, error)
// GetCustKey gets the customer signing key and the version
// for the given AS in the latest version.
GetCustKey(ctx context.Context, ia addr.IA) (common.RawBytes, uint64, error)
}

// Write contains all write operations fo the trust DB.
Expand All @@ -82,10 +83,13 @@ type Write interface {
// InsertTRC inserts trcobj into the database. The first return value is the
// number of rows affected.
InsertTRC(ctx context.Context, trcobj *trc.TRC) (int64, error)
// InsertCustKey inserts the given customer key.
// If a key with same ia and version is already stored this is a no-op,
// i.e. it does not change the contents.
InsertCustKey(ctx context.Context, ia addr.IA, version uint64, key common.RawBytes) error
// InsertCustKey inserts or updates the given customer key.
// If there has been a concurrent insert, i.e. the version in the DB is no longer oldVersion
// this operation should return an error.
// If there is no previous version 0 should be passed for the oldVersion argument.
// If oldVersion == version an error is returned.
InsertCustKey(ctx context.Context, ia addr.IA, version uint64,
key common.RawBytes, oldVersion uint64) error
}

// Transaction represents a trust DB transaction with an ongoing transaction.
Expand Down
58 changes: 48 additions & 10 deletions go/lib/infra/modules/trust/trustdb/trustdbsqlite/db.go
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,14 @@ const (
);
CREATE TABLE CustKeys (
IsdID INTEGER NOT NULL,
AsID INTEGER NOT NULL,
Version INTEGER NOT NULL,
Key DATA NOT NULL,
PRIMARY KEY (IsdID, AsID)
);
CREATE TABLE CustKeysLog (
IsdID INTEGER NOT NULL,
AsID INTEGER NOT NULL,
Version INTEGER NOT NULL,
Expand Down Expand Up @@ -159,10 +167,18 @@ const (
SELECT Data FROM TRCs
`
getCustKeyStr = `
SELECT Key, MAX(Version) FROM CustKeys WHERE IsdID=? AND AsID=? GROUP BY IsdID, AsID
SELECT Key, Version FROM CustKeys WHERE IsdID=? AND AsID=?
`
insertCustKeyStr = `
INSERT OR IGNORE INTO CustKeys (IsdID, AsID, Version, Key) VALUES (?, ?, ?, ?)
INSERT INTO CustKeys (IsdID, AsID, Version, Key) VALUES (?, ?, ?, ?)
`

insertCustKeyLogStr = `
INSERT OR IGNORE INTO CustKeysLog (IsdID, AsID, Version, Key) VALUES (?, ?, ?, ?)
`

updateCustKeyStr = `
UPDATE CustKeys SET Version = ?, Key = ? WHERE IsdID = ? AND AsID = ? AND Version = ?
`
)

Expand Down Expand Up @@ -476,31 +492,53 @@ func (db *executor) GetAllTRCs(ctx context.Context) ([]*trc.TRC, error) {
}

// GetCustKey gets the customer signing key for the given AS in the latest version.
func (db *executor) GetCustKey(ctx context.Context, ia addr.IA) (common.RawBytes, error) {
func (db *executor) GetCustKey(ctx context.Context, ia addr.IA) (common.RawBytes, uint64, error) {
db.RLock()
defer db.RUnlock()
var key common.RawBytes
var version uint64
err := db.db.QueryRowContext(ctx, getCustKeyStr, ia.I, ia.A).Scan(&key, &version)
if err == sql.ErrNoRows {
return nil, nil
return nil, 0, nil
}
if err != nil {
return nil, common.NewBasicError("Failed to look up cust key", err)
return nil, 0, common.NewBasicError("Failed to look up cust key", err)
}
return key, nil
return key, version, nil
}

// InsertCustKey inserts the given customer key.
func (db *executor) InsertCustKey(ctx context.Context, ia addr.IA,
version uint64, key common.RawBytes) error {
version uint64, key common.RawBytes, oldVersion uint64) error {

if version == oldVersion {
return common.NewBasicError("Same version as oldVersion not allowed",
nil, "version", version)
}
db.Lock()
defer db.Unlock()
if _, err := db.db.ExecContext(ctx, insertCustKeyStr, ia.I, ia.A, version, key); err != nil {
return common.NewBasicError("Failed to insert cust key", err, "ia", ia, "ver", version)
if oldVersion == 0 {
_, err := db.db.ExecContext(ctx, insertCustKeyStr, ia.I, ia.A, version, key)
if err != nil {
return common.NewBasicError("Failed to insert cust key", err, "ia", ia, "ver", version)
}
} else {
res, err := db.db.ExecContext(ctx, updateCustKeyStr, version, key, ia.I, ia.A, oldVersion)
if err != nil {
return common.NewBasicError("Failed to update cust key", err, "ia", ia, "ver", version)
}
n, err := res.RowsAffected()
if err != nil {
return common.NewBasicError("Unable to determine affected rows", err)
}
if n == 0 {
return common.NewBasicError("Cust keys has been modified", nil, "ia", ia,
"newVersion", version, "oldVersion", oldVersion)
}
}
return nil
// Insert in the log table.
_, err := db.db.ExecContext(ctx, insertCustKeyLogStr, ia.I, ia.A, version, key)
return err
}

type transaction struct {
Expand Down
35 changes: 23 additions & 12 deletions go/lib/infra/modules/trust/trustdb/trustdbtest/trustdbtest.go
Original file line number Diff line number Diff line change
Expand Up @@ -350,35 +350,46 @@ func testCustKey(t *testing.T, db rwTrustDB) {
key_110_2 := common.RawBytes("ddddddaa")

Convey("GetCustKey should return nil and no error", func() {
key, err := db.GetCustKey(ctx, ia1_110)
key, ver, err := db.GetCustKey(ctx, ia1_110)
SoMsg("No error expected", err, ShouldBeNil)
SoMsg("Empty result expected", key, ShouldBeNil)
SoMsg("0 version expected", ver, ShouldEqual, 0)
})
Convey("Insertion should work without error", func() {
err := db.InsertCustKey(ctx, ia1_110, 1, key_110_1)
var ver uint64 = 1
err := db.InsertCustKey(ctx, ia1_110, ver, key_110_1, 0)
SoMsg("No error expected", err, ShouldBeNil)
Convey("Inserted entry should be returned", func() {
key, err := db.GetCustKey(ctx, ia1_110)
key, dbVer, err := db.GetCustKey(ctx, ia1_110)
SoMsg("No error expected", err, ShouldBeNil)
SoMsg("Inserted key expected", key, ShouldResemble, key_110_1)
SoMsg("Inserted version expected", dbVer, ShouldEqual, ver)
})
Convey("Inserting a newer version should work", func() {
err := db.InsertCustKey(ctx, ia1_110, 2, key_110_2)
var newVer uint64 = 2
err := db.InsertCustKey(ctx, ia1_110, newVer, key_110_2, ver)
SoMsg("No error expected", err, ShouldBeNil)
Convey("New version should be returned", func() {
key, err := db.GetCustKey(ctx, ia1_110)
key, dbVer, err := db.GetCustKey(ctx, ia1_110)
SoMsg("No error expected", err, ShouldBeNil)
SoMsg("Inserted key expected", key, ShouldResemble, key_110_2)
SoMsg("Inserted version expected", dbVer, ShouldEqual, newVer)
})
})
Convey("Inserting the same version again should be ignored", func() {
err := db.InsertCustKey(ctx, ia1_110, 1, key_110_2)
Convey("Inserting the same version again should error", func() {
err := db.InsertCustKey(ctx, ia1_110, ver, key_110_2, ver)
SoMsg("Error expected", err, ShouldNotBeNil)
})
Convey("Inserting with 0 version should fail if there is an entry", func() {
err := db.InsertCustKey(ctx, ia1_110, ver, key_110_1, 0)
SoMsg("Error expected", err, ShouldNotBeNil)
})
Convey("Updating with outdated old version should fail", func() {
var newVer uint64 = 2
err := db.InsertCustKey(ctx, ia1_110, newVer, key_110_2, ver)
SoMsg("No error expected", err, ShouldBeNil)
Convey("The existing version should not be overridden", func() {
key, err := db.GetCustKey(ctx, ia1_110)
SoMsg("No error expected", err, ShouldBeNil)
SoMsg("Inserted key expected", key, ShouldResemble, key_110_1)
})
err = db.InsertCustKey(ctx, ia1_110, newVer, key_110_2, ver)
SoMsg("Error expected", err, ShouldNotBeNil)
})
})
})
Expand Down

0 comments on commit 4285db0

Please sign in to comment.