From 96a023905def84c493c0957d2c8dc0a9bb5e0b43 Mon Sep 17 00:00:00 2001 From: Lukas Vogel Date: Tue, 18 Dec 2018 08:12:52 +0100 Subject: [PATCH] TrustDB: Make it possible to use transactions (#2251) --- go/lib/infra/modules/trust/trustdb/trustdb.go | 53 ++- .../modules/trust/trustdb/trustdbsqlite/db.go | 391 +++++++++--------- .../trust/trustdb/trustdbtest/trustdbtest.go | 73 +++- 3 files changed, 305 insertions(+), 212 deletions(-) diff --git a/go/lib/infra/modules/trust/trustdb/trustdb.go b/go/lib/infra/modules/trust/trustdb/trustdb.go index 045eb7b8ec..04ec273e63 100644 --- a/go/lib/infra/modules/trust/trustdb/trustdb.go +++ b/go/lib/infra/modules/trust/trustdb/trustdb.go @@ -12,10 +12,13 @@ // See the License for the specific language governing permissions and // limitations under the License. +// Package trustdb provides wrappers for SQL calls for managing a database +// containing TRCs and Certificate Chains. package trustdb import ( "context" + "database/sql" "io" "github.com/scionproto/scion/go/lib/addr" @@ -23,22 +26,31 @@ import ( "github.com/scionproto/scion/go/lib/scrypto/trc" ) +// TrustDB is a database containing Certificates, Chains and TRCs, stored in JSON format. // TrustDB is the interface that all trust databases have to implement. +// Read and Write interactions with this interface have to happen in individual transactions +// (either explicit or implicit). type TrustDB interface { + Read + Write + BeginTransaction(ctx context.Context, opts *sql.TxOptions) (Transaction, error) + io.Closer +} + +// Read contains all read operation of the trust DB. +// On errors, GetXxx methods return nil and the error. If no error occurred, +// but the database query yielded 0 results, the first returned value is nil. +type Read interface { // GetIssCertVersion returns the specified version of the issuer certificate for // ia. If version is scrypto.LatestVer, this is equivalent to GetIssCertMaxVersion. GetIssCertVersion(ctx context.Context, ia addr.IA, version uint64) (*cert.Certificate, error) // GetIssCertMaxVersion returns the max version of the issuer certificate for ia. GetIssCertMaxVersion(ctx context.Context, ia addr.IA) (*cert.Certificate, error) - // InsertIssCert inserts the issuer certificate. - InsertIssCert(ctx context.Context, crt *cert.Certificate) (int64, error) // GetLeafCertVersion returns the specified version of the leaf certificate for // ia. If version is scrypto.LatestVer, this is equivalent to GetLeafCertMaxVersion. GetLeafCertVersion(ctx context.Context, ia addr.IA, version uint64) (*cert.Certificate, error) // GetLeafCertMaxVersion returns the max version of the leaf certificate for ia. GetLeafCertMaxVersion(ctx context.Context, ia addr.IA) (*cert.Certificate, error) - // InsertLeafCert inserts the leaf certificate. - InsertLeafCert(ctx context.Context, crt *cert.Certificate) (int64, error) // GetChainVersion returns the specified version of the certificate chain for // ia. If version is scrypto.LatestVer, this is equivalent to GetChainMaxVersion. GetChainVersion(ctx context.Context, ia addr.IA, version uint64) (*cert.Chain, error) @@ -46,18 +58,39 @@ type TrustDB interface { GetChainMaxVersion(ctx context.Context, ia addr.IA) (*cert.Chain, error) // GetAllChains returns all chains in the database. GetAllChains(ctx context.Context) ([]*cert.Chain, error) - // InsertChain inserts chain into the database. The first return value is the - // number of rows affected. - InsertChain(ctx context.Context, chain *cert.Chain) (int64, error) // GetTRCVersion returns the specified version of the TRC for // isd. If version is scrypto.LatestVer, this is equivalent to GetTRCMaxVersion. GetTRCVersion(ctx context.Context, isd addr.ISD, version uint64) (*trc.TRC, error) // GetTRCMaxVersion returns the max version of the TRC for ia. GetTRCMaxVersion(ctx context.Context, isd addr.ISD) (*trc.TRC, error) + // GetAllTRCs fetches all TRCs from the database. + GetAllTRCs(ctx context.Context) ([]*trc.TRC, error) +} + +// Write contains all write operations fo the trust DB. +type Write interface { + // InsertIssCert inserts the issuer certificate. + InsertIssCert(ctx context.Context, crt *cert.Certificate) (int64, error) + // InsertLeafCert inserts the leaf certificate. + InsertLeafCert(ctx context.Context, crt *cert.Certificate) (int64, error) + // InsertChain inserts chain into the database. The first return value is the + // number of rows affected. + InsertChain(ctx context.Context, chain *cert.Chain) (int64, error) // InsertTRC inserts trcobj into the database. The first return value is the // number of rows affected. InsertTRC(ctx context.Context, trcobj *trc.TRC) (int64, error) - // GetAllTRCs fetches all TRCs from the database. - GetAllTRCs(ctx context.Context) ([]*trc.TRC, error) - io.Closer +} + +// Transaction represents a trust DB transaction with an ongoing transaction. +// To end the transaction either Rollback or Commit should be called. Calling Commit or Rollback +// multiple times will result in an error. +type Transaction interface { + Read + Write + // Commit commits the transaction. + // Returns the underlying TrustDB connection. + Commit() error + // Rollback rollbacks the transaction. + // Returns the underlying TrustDB connection. + Rollback() error } diff --git a/go/lib/infra/modules/trust/trustdb/trustdbsqlite/db.go b/go/lib/infra/modules/trust/trustdb/trustdbsqlite/db.go index ace27095af..4b0220373a 100644 --- a/go/lib/infra/modules/trust/trustdb/trustdbsqlite/db.go +++ b/go/lib/infra/modules/trust/trustdb/trustdbsqlite/db.go @@ -12,8 +12,7 @@ // See the License for the specific language governing permissions and // limitations under the License. -// Package trustdb provides wrappers for SQL calls for managing a database -// containing TRCs and Certificate Chains. +// Package trustdbsqlite implements the trustdb interface with a sqlite backed DB. // // KNOWN ISSUE: DB methods serialize to/dezerialize from JSON on each call. // For performance penalty details, check the benchmarks in the test file. @@ -105,23 +104,23 @@ const ( INSERT OR IGNORE INTO LeafCerts (IsdID, AsID, Version, Data) VALUES (?, ?, ?, ?) ` getChainVersionStr = ` - SELECT Data, 0 FROM LeafCerts WHERE IsdID=? AND AsID=? AND Version=? + SELECT Data, 0 FROM LeafCerts WHERE IsdID=?1 AND AsID=?2 AND Version=?3 UNION SELECT ic.Data, ch.OrderKey FROM IssuerCerts ic, Chains ch WHERE ic.RowID IN ( - SELECT IssCertsRowID FROM Chains WHERE IsdID=? AND AsID=? AND Version=? + SELECT IssCertsRowID FROM Chains WHERE IsdID=?1 AND AsID=?2 AND Version=?3 ) ORDER BY ch.OrderKey ` getChainMaxVersionStr = ` - SELECT Data, 0 FROM LeafCerts WHERE IsdID=? AND AsID=? AND Version=( - SELECT MAX(Version) FROM Chains WHERE IsdID=? AND AsID=? + SELECT Data, 0 FROM LeafCerts WHERE IsdID=?1 AND AsID=?2 AND Version=( + SELECT MAX(Version) FROM Chains WHERE IsdID=?1 AND AsID=?2 ) UNION SELECT ic.Data, ch.OrderKey FROM IssuerCerts ic, Chains ch WHERE ic.RowID IN ( - SELECT IssCertsRowID FROM Chains WHERE IsdID=? AND AsID=? AND Version=( - SELECT MAX(Version) FROM Chains WHERE IsdID=? AND AsID=? + SELECT IssCertsRowID FROM Chains WHERE IsdID=?1 AND AsID=?2 AND Version=( + SELECT MAX(Version) FROM Chains WHERE IsdID=?1 AND AsID=?2 ) ) ORDER BY ch.OrderKey @@ -152,99 +151,59 @@ const ( ` ) -// DB is a database containing Certificates, Chains and TRCs, stored in JSON format. -// -// On errors, GetXxx methods return nil and the error. If no error occurred, -// but the database query yielded 0 results, the first returned value is nil. -// GetXxxCtx methods are the context equivalents of GetXxx. -type DB struct { - sync.RWMutex - db *sql.DB - getIssCertVersionStmt *sql.Stmt - getIssCertMaxVersionStmt *sql.Stmt - getIssCertRowIDStmt *sql.Stmt - insertIssCertStmt *sql.Stmt - getLeafCertVersionStmt *sql.Stmt - getLeafCertMaxVersionStmt *sql.Stmt - insertLeafCertStmt *sql.Stmt - getChainVersionStmt *sql.Stmt - getChainMaxVersionStmt *sql.Stmt - getAllChainsStmt *sql.Stmt - insertChainStmt *sql.Stmt - getTRCVersionStmt *sql.Stmt - getTRCMaxVersionStmt *sql.Stmt - insertTRCStmt *sql.Stmt - getAllTRCsStmt *sql.Stmt +type sqler interface { + ExecContext(context.Context, string, ...interface{}) (sql.Result, error) + QueryContext(context.Context, string, ...interface{}) (*sql.Rows, error) + QueryRowContext(context.Context, string, ...interface{}) *sql.Row +} + +type tdb struct { + *executor + db *sql.DB } func New(path string) (trustdb.TrustDB, error) { var err error - db := &DB{} + db := &tdb{} + db.executor = &executor{} if db.db, err = sqlite.New(path, Schema, SchemaVersion); err != nil { return nil, err } - // On future errors, close the sql database before exiting - defer func() { - if err != nil { - db.db.Close() - } - }() - if db.getIssCertVersionStmt, err = db.db.Prepare(getIssCertVersionStr); err != nil { - return nil, common.NewBasicError("Unable to prepare getIssCertVersion", err) - } - if db.getIssCertMaxVersionStmt, err = db.db.Prepare(getIssCertMaxVersionStr); err != nil { - return nil, common.NewBasicError("Unable to prepare getIssCertMaxVersion", err) - } - if db.getIssCertRowIDStmt, err = db.db.Prepare(getIssCertRowIDStr); err != nil { - return nil, common.NewBasicError("Unable to prepare getIssCertRowID", err) - } - if db.insertIssCertStmt, err = db.db.Prepare(insertIssCertStr); err != nil { - return nil, common.NewBasicError("Unable to prepare insertIssCert", err) - } - if db.getLeafCertVersionStmt, err = db.db.Prepare(getLeafCertVersionStr); err != nil { - return nil, common.NewBasicError("Unable to prepare getLeafCertVersion", err) - } - if db.getLeafCertMaxVersionStmt, err = db.db.Prepare(getLeafCertMaxVersionStr); err != nil { - return nil, common.NewBasicError("Unable to prepare getLeafCertMaxVersion", err) - } - if db.insertLeafCertStmt, err = db.db.Prepare(insertLeafCertStr); err != nil { - return nil, common.NewBasicError("Unable to prepare insertLeafCert", err) - } - if db.getChainVersionStmt, err = db.db.Prepare(getChainVersionStr); err != nil { - return nil, common.NewBasicError("Unable to prepare getChainVersion", err) - } - if db.getChainMaxVersionStmt, err = db.db.Prepare(getChainMaxVersionStr); err != nil { - return nil, common.NewBasicError("Unable to prepare getChainMaxVersion", err) - } - if db.getAllChainsStmt, err = db.db.Prepare(getAllChainsStr); err != nil { - return nil, common.NewBasicError("Unable to prepare getAllChains", err) - } - if db.insertChainStmt, err = db.db.Prepare(insertChainStr); err != nil { - return nil, common.NewBasicError("Unable to prepare insertChain", err) - } - if db.getTRCVersionStmt, err = db.db.Prepare(getTRCVersionStr); err != nil { - return nil, common.NewBasicError("Unable to prepare getTRCVersion", err) - } - if db.getTRCMaxVersionStmt, err = db.db.Prepare(getTRCMaxVersionStr); err != nil { - return nil, common.NewBasicError("Unable to prepare getTRCMaxVersion", err) - } - if db.insertTRCStmt, err = db.db.Prepare(insertTRCStr); err != nil { - return nil, common.NewBasicError("Unable to prepare insertTRC", err) - } - if db.getAllTRCsStmt, err = db.db.Prepare(getAllTRCsStr); err != nil { - return nil, common.NewBasicError("Unable to prepare getAllTRCs", err) - } + db.executor.db = db.db return db, nil } // Close closes the database connection. -func (db *DB) Close() error { +func (db *tdb) Close() error { return db.db.Close() } +// BeginTransaction starts a new transaction. +func (db *tdb) BeginTransaction(ctx context.Context, + opts *sql.TxOptions) (trustdb.Transaction, error) { + + db.Lock() + defer db.Unlock() + tx, err := db.db.BeginTx(ctx, opts) + if err != nil { + return nil, common.NewBasicError("Failed to create transaction", err) + } + return &transaction{ + executor: &executor{ + db: tx, + }, + tx: tx, + }, nil +} + +type executor struct { + sync.RWMutex + db sqler +} + // GetIssCertVersion returns the specified version of the issuer certificate for // ia. If version is scrypto.LatestVer, this is equivalent to GetIssCertMaxVersion. -func (db *DB) GetIssCertVersion(ctx context.Context, ia addr.IA, +func (db *executor) GetIssCertVersion(ctx context.Context, ia addr.IA, version uint64) (*cert.Certificate, error) { if version == scrypto.LatestVer { @@ -253,38 +212,31 @@ func (db *DB) GetIssCertVersion(ctx context.Context, ia addr.IA, db.RLock() defer db.RUnlock() var raw common.RawBytes - err := db.getIssCertVersionStmt.QueryRowContext(ctx, ia.I, ia.A, version).Scan(&raw) + err := db.db.QueryRowContext(ctx, getIssCertVersionStr, ia.I, ia.A, version).Scan(&raw) return parseCert(raw, ia, version, err) } // GetIssCertMaxVersion returns the max version of the issuer certificate for ia. -func (db *DB) GetIssCertMaxVersion(ctx context.Context, ia addr.IA) (*cert.Certificate, error) { +func (db *executor) GetIssCertMaxVersion(ctx context.Context, + ia addr.IA) (*cert.Certificate, error) { + db.RLock() defer db.RUnlock() var raw common.RawBytes - err := db.getIssCertMaxVersionStmt.QueryRowContext(ctx, ia.I, ia.A).Scan(&raw) + err := db.db.QueryRowContext(ctx, getIssCertMaxVersionStr, ia.I, ia.A).Scan(&raw) return parseCert(raw, ia, scrypto.LatestVer, err) } // InsertIssCert inserts the issuer certificate. -func (db *DB) InsertIssCert(ctx context.Context, crt *cert.Certificate) (int64, error) { - raw, err := crt.JSON(false) - if err != nil { - return 0, common.NewBasicError("Unable to convert to JSON", err) - } +func (db *executor) InsertIssCert(ctx context.Context, crt *cert.Certificate) (int64, error) { db.Lock() defer db.Unlock() - res, err := db.insertIssCertStmt.ExecContext(ctx, - crt.Subject.I, crt.Subject.A, crt.Version, raw) - if err != nil { - return 0, err - } - return res.RowsAffected() + return insertIssCert(ctx, db.db, crt) } // GetLeafCertVersion returns the specified version of the leaf certificate for // ia. If version is scrypto.LatestVer, this is equivalent to GetLeafCertMaxVersion. -func (db *DB) GetLeafCertVersion(ctx context.Context, ia addr.IA, +func (db *executor) GetLeafCertVersion(ctx context.Context, ia addr.IA, version uint64) (*cert.Certificate, error) { if version == scrypto.LatestVer { @@ -293,56 +245,31 @@ func (db *DB) GetLeafCertVersion(ctx context.Context, ia addr.IA, db.RLock() defer db.RUnlock() var raw common.RawBytes - err := db.getLeafCertVersionStmt.QueryRowContext(ctx, ia.I, ia.A, version).Scan(&raw) + err := db.db.QueryRowContext(ctx, getLeafCertVersionStr, ia.I, ia.A, version).Scan(&raw) return parseCert(raw, ia, version, err) } // GetLeafCertMaxVersion returns the max version of the leaf certificate for ia. -func (db *DB) GetLeafCertMaxVersion(ctx context.Context, ia addr.IA) (*cert.Certificate, error) { +func (db *executor) GetLeafCertMaxVersion(ctx context.Context, + ia addr.IA) (*cert.Certificate, error) { + db.RLock() defer db.RUnlock() var raw common.RawBytes - err := db.getLeafCertMaxVersionStmt.QueryRowContext(ctx, ia.I, ia.A).Scan(&raw) + err := db.db.QueryRowContext(ctx, getLeafCertMaxVersionStr, ia.I, ia.A).Scan(&raw) return parseCert(raw, ia, scrypto.LatestVer, err) } -func parseCert(raw common.RawBytes, ia addr.IA, v uint64, err error) (*cert.Certificate, error) { - if err == sql.ErrNoRows { - return nil, nil - } - if err != nil { - return nil, common.NewBasicError("Database access error", err) - } - crt, err := cert.CertificateFromRaw(raw) - if err != nil { - if v == scrypto.LatestVer { - return nil, common.NewBasicError("Cert parse error", err, "ia", ia, "version", "max") - } else { - return nil, common.NewBasicError("Cert parse error", err, "ia", ia, "version", v) - } - } - return crt, nil -} - // InsertLeafCert inserts the leaf certificate. -func (db *DB) InsertLeafCert(ctx context.Context, crt *cert.Certificate) (int64, error) { - raw, err := crt.JSON(false) - if err != nil { - return 0, common.NewBasicError("Unable to convert to JSON", err) - } +func (db *executor) InsertLeafCert(ctx context.Context, crt *cert.Certificate) (int64, error) { db.Lock() defer db.Unlock() - res, err := db.insertLeafCertStmt.ExecContext(ctx, - crt.Subject.I, crt.Subject.A, crt.Version, raw) - if err != nil { - return 0, err - } - return res.RowsAffected() + return insertLeafCert(ctx, db.db, crt) } // GetChainVersion returns the specified version of the certificate chain for // ia. If version is scrypto.LatestVer, this is equivalent to GetChainMaxVersion. -func (db *DB) GetChainVersion(ctx context.Context, ia addr.IA, +func (db *executor) GetChainVersion(ctx context.Context, ia addr.IA, version uint64) (*cert.Chain, error) { if version == scrypto.LatestVer { @@ -350,7 +277,7 @@ func (db *DB) GetChainVersion(ctx context.Context, ia addr.IA, } db.RLock() defer db.RUnlock() - rows, err := db.getChainVersionStmt.QueryContext(ctx, ia.I, ia.A, version, ia.I, ia.A, version) + rows, err := db.db.QueryContext(ctx, getChainVersionStr, ia.I, ia.A, version) if err != nil { return nil, err } @@ -359,11 +286,10 @@ func (db *DB) GetChainVersion(ctx context.Context, ia addr.IA, } // GetChainMaxVersion returns the max version of the chain for ia. -func (db *DB) GetChainMaxVersion(ctx context.Context, ia addr.IA) (*cert.Chain, error) { +func (db *executor) GetChainMaxVersion(ctx context.Context, ia addr.IA) (*cert.Chain, error) { db.RLock() defer db.RUnlock() - rows, err := db.getChainMaxVersionStmt.QueryContext(ctx, ia.I, ia.A, ia.I, ia.A, ia.I, ia.A, - ia.I, ia.A) + rows, err := db.db.QueryContext(ctx, getChainMaxVersionStr, ia.I, ia.A) if err != nil { return nil, err } @@ -371,36 +297,10 @@ func (db *DB) GetChainMaxVersion(ctx context.Context, ia addr.IA) (*cert.Chain, return parseChain(rows, err) } -func parseChain(rows *sql.Rows, err error) (*cert.Chain, error) { - if err != nil { - return nil, common.NewBasicError("Database access error", err) - } - certs := make([]*cert.Certificate, 0, 2) - var raw common.RawBytes - var pos int64 - for i := 0; rows.Next(); i++ { - if err = rows.Scan(&raw, &pos); err != nil { - return nil, err - } - crt, err := cert.CertificateFromRaw(raw) - if err != nil { - return nil, err - } - certs = append(certs, crt) - } - if err = rows.Err(); err != nil { - return nil, err - } - if len(certs) == 0 { - return nil, nil - } - return cert.ChainFromSlice(certs) -} - -func (db *DB) GetAllChains(ctx context.Context) ([]*cert.Chain, error) { +func (db *executor) GetAllChains(ctx context.Context) ([]*cert.Chain, error) { db.RLock() defer db.RUnlock() - rows, err := db.getAllChainsStmt.QueryContext(ctx) + rows, err := db.db.QueryContext(ctx, getAllChainsStr) if err != nil { return nil, common.NewBasicError("Database access error", err) } @@ -453,44 +353,31 @@ func (db *DB) GetAllChains(ctx context.Context) ([]*cert.Chain, error) { // InsertChain inserts chain into the database. The first return value is the // number of rows affected. -func (db *DB) InsertChain(ctx context.Context, chain *cert.Chain) (int64, error) { - if _, err := db.InsertLeafCert(ctx, chain.Leaf); err != nil { +func (db *executor) InsertChain(ctx context.Context, chain *cert.Chain) (int64, error) { + db.Lock() + defer db.Unlock() + if _, err := insertLeafCert(ctx, db.db, chain.Leaf); err != nil { return 0, err } - if _, err := db.InsertIssCert(ctx, chain.Issuer); err != nil { + if _, err := insertIssCert(ctx, db.db, chain.Issuer); err != nil { return 0, err } - db.Lock() - defer db.Unlock() ia, ver := chain.IAVer() - rowId, err := db.getIssCertRowIDCtx(ctx, chain.Issuer.Subject, chain.Issuer.Version) + rowId, err := getIssCertRowIDCtx(ctx, db.db, chain.Issuer.Subject, chain.Issuer.Version) if err != nil { return 0, err } // NOTE(roosd): Adding multiple rows to Chains table has to be done in a transaction. - res, err := db.insertChainStmt.ExecContext(ctx, ia.I, ia.A, ver, 1, rowId) + res, err := db.db.ExecContext(ctx, insertChainStr, ia.I, ia.A, ver, 1, rowId) if err != nil { return 0, err } return res.RowsAffected() } -func (db *DB) getIssCertRowIDCtx(ctx context.Context, ia addr.IA, ver uint64) (int64, error) { - var rowId int64 - err := db.getIssCertRowIDStmt.QueryRowContext(ctx, ia.I, ia.A, ver).Scan(&rowId) - if err == sql.ErrNoRows { - return 0, common.NewBasicError("Unable to get RowID of issuer certificate", nil, - "ia", ia, "ver", ver) - } - if err != nil { - return 0, common.NewBasicError("Database access error", err) - } - return rowId, nil -} - // GetTRCVersion returns the specified version of the TRC for // isd. If version is scrypto.LatestVer, this is equivalent to GetTRCMaxVersion. -func (db *DB) GetTRCVersion(ctx context.Context, +func (db *executor) GetTRCVersion(ctx context.Context, isd addr.ISD, version uint64) (*trc.TRC, error) { if version == scrypto.LatestVer { @@ -499,7 +386,7 @@ func (db *DB) GetTRCVersion(ctx context.Context, db.RLock() defer db.RUnlock() var raw common.RawBytes - err := db.getTRCVersionStmt.QueryRowContext(ctx, isd, version).Scan(&raw) + err := db.db.QueryRowContext(ctx, getTRCVersionStr, isd, version).Scan(&raw) if err == sql.ErrNoRows { return nil, nil } @@ -514,11 +401,11 @@ func (db *DB) GetTRCVersion(ctx context.Context, } // GetTRCMaxVersion returns the max version of the TRC for ia. -func (db *DB) GetTRCMaxVersion(ctx context.Context, isd addr.ISD) (*trc.TRC, error) { +func (db *executor) GetTRCMaxVersion(ctx context.Context, isd addr.ISD) (*trc.TRC, error) { db.RLock() defer db.RUnlock() var raw common.RawBytes - err := db.getTRCMaxVersionStmt.QueryRowContext(ctx, isd).Scan(&raw) + err := db.db.QueryRowContext(ctx, getTRCMaxVersionStr, isd).Scan(&raw) if err == sql.ErrNoRows { return nil, nil } @@ -534,14 +421,14 @@ func (db *DB) GetTRCMaxVersion(ctx context.Context, isd addr.ISD) (*trc.TRC, err // InsertTRC inserts trcobj into the database. The first return value is the // number of rows affected. -func (db *DB) InsertTRC(ctx context.Context, trcobj *trc.TRC) (int64, error) { +func (db *executor) InsertTRC(ctx context.Context, trcobj *trc.TRC) (int64, error) { raw, err := trcobj.JSON(false) if err != nil { return 0, common.NewBasicError("Unable to convert to JSON", err) } db.Lock() defer db.Unlock() - res, err := db.insertTRCStmt.ExecContext(ctx, trcobj.ISD, trcobj.Version, raw) + res, err := db.db.ExecContext(ctx, insertTRCStr, trcobj.ISD, trcobj.Version, raw) if err != nil { return 0, err } @@ -549,10 +436,10 @@ func (db *DB) InsertTRC(ctx context.Context, trcobj *trc.TRC) (int64, error) { } // GetAllTRCs fetches all TRCs from the database. -func (db *DB) GetAllTRCs(ctx context.Context) ([]*trc.TRC, error) { +func (db *executor) GetAllTRCs(ctx context.Context) ([]*trc.TRC, error) { db.RLock() defer db.RUnlock() - rows, err := db.getAllTRCsStmt.QueryContext(ctx) + rows, err := db.db.QueryContext(ctx, getAllTRCsStr) if err != nil { return nil, common.NewBasicError("Database access error", err) } @@ -572,3 +459,121 @@ func (db *DB) GetAllTRCs(ctx context.Context) ([]*trc.TRC, error) { } return trcs, nil } + +type transaction struct { + *executor + tx *sql.Tx +} + +func (db *transaction) Commit() error { + db.Lock() + defer db.Unlock() + if db.tx == nil { + return common.NewBasicError("Transaction already done", nil) + } + err := db.tx.Commit() + if err != nil { + return common.NewBasicError("Failed to commit transaction", err) + } + db.tx = nil + return nil +} + +func (db *transaction) Rollback() error { + db.Lock() + defer db.Unlock() + if db.tx == nil { + return common.NewBasicError("Transaction already done", nil) + } + err := db.tx.Rollback() + db.tx = nil + if err != nil { + return common.NewBasicError("Failed to rollback transaction", err) + } + return nil +} + +func insertIssCert(ctx context.Context, db sqler, crt *cert.Certificate) (int64, error) { + raw, err := crt.JSON(false) + if err != nil { + return 0, common.NewBasicError("Unable to convert to JSON", err) + } + res, err := db.ExecContext(ctx, insertIssCertStr, + crt.Subject.I, crt.Subject.A, crt.Version, raw) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func parseCert(raw common.RawBytes, ia addr.IA, v uint64, err error) (*cert.Certificate, error) { + if err == sql.ErrNoRows { + return nil, nil + } + if err != nil { + return nil, common.NewBasicError("Database access error", err) + } + crt, err := cert.CertificateFromRaw(raw) + if err != nil { + if v == scrypto.LatestVer { + return nil, common.NewBasicError("Cert parse error", err, "ia", ia, "version", "max") + } else { + return nil, common.NewBasicError("Cert parse error", err, "ia", ia, "version", v) + } + } + return crt, nil +} + +func insertLeafCert(ctx context.Context, db sqler, crt *cert.Certificate) (int64, error) { + raw, err := crt.JSON(false) + if err != nil { + return 0, common.NewBasicError("Unable to convert to JSON", err) + } + res, err := db.ExecContext(ctx, insertLeafCertStr, + crt.Subject.I, crt.Subject.A, crt.Version, raw) + if err != nil { + return 0, err + } + return res.RowsAffected() +} + +func parseChain(rows *sql.Rows, err error) (*cert.Chain, error) { + if err != nil { + return nil, common.NewBasicError("Database access error", err) + } + certs := make([]*cert.Certificate, 0, 2) + var raw common.RawBytes + var pos int64 + for i := 0; rows.Next(); i++ { + if err = rows.Scan(&raw, &pos); err != nil { + return nil, err + } + crt, err := cert.CertificateFromRaw(raw) + if err != nil { + return nil, err + } + certs = append(certs, crt) + } + if err = rows.Err(); err != nil { + return nil, err + } + if len(certs) == 0 { + return nil, nil + } + return cert.ChainFromSlice(certs) +} + +func getIssCertRowIDCtx(ctx context.Context, db sqler, + ia addr.IA, ver uint64) (int64, error) { + + var rowId int64 + err := db.QueryRowContext(ctx, getIssCertRowIDStr, ia.I, ia.A, ver).Scan(&rowId) + if err == sql.ErrNoRows { + return 0, common.NewBasicError("Unable to get RowID of issuer certificate", nil, + "ia", ia, "ver", ver) + } + if err != nil { + return 0, common.NewBasicError("Database access error", err) + } + return rowId, nil +} diff --git a/go/lib/infra/modules/trust/trustdb/trustdbtest/trustdbtest.go b/go/lib/infra/modules/trust/trustdb/trustdbtest/trustdbtest.go index 3424520e9d..0fa41d3936 100644 --- a/go/lib/infra/modules/trust/trustdb/trustdbtest/trustdbtest.go +++ b/go/lib/infra/modules/trust/trustdb/trustdbtest/trustdbtest.go @@ -33,6 +33,11 @@ var ( Timeout = time.Second ) +type rwTrustDB interface { + trustdb.Read + trustdb.Write +} + // TestTrustDB should be used to test any implementation of the TrustDB interface. // An implementation of the TrustDB interface should at least have on test method that calls // this test-suite. The calling test code should have a top level Convey block. @@ -40,7 +45,7 @@ var ( // setup should return a TrustDB in a clean state, i.e. no entries in the DB. // cleanup can be used to release any resources that have been allocated during setup. func TestTrustDB(t *testing.T, setup func() trustdb.TrustDB, cleanup func(trustdb.TrustDB)) { - testWrapper := func(test func(*testing.T, trustdb.TrustDB)) func() { + testWrapper := func(test func(*testing.T, rwTrustDB)) func() { return func() { db := setup() test(t, db) @@ -53,9 +58,39 @@ func TestTrustDB(t *testing.T, setup func() trustdb.TrustDB, cleanup func(trustd Convey("TestLeafCert", testWrapper(testLeafCert)) Convey("TestChain", testWrapper(testChain)) Convey("TestChainGetAll", testWrapper(testChainGetAll)) + // Now test everything with a transaction as well. + txTestWrapper := func(test func(*testing.T, rwTrustDB)) func() { + return func() { + ctx, cancelF := context.WithTimeout(context.Background(), Timeout) + defer cancelF() + db := setup() + tx, err := db.BeginTransaction(ctx, nil) + xtest.FailOnErr(t, err) + test(t, tx) + err = tx.Commit() + xtest.FailOnErr(t, err) + cleanup(db) + } + } + trustDbTestWrapper := func(test func(*testing.T, trustdb.TrustDB)) func() { + return func() { + db := setup() + test(t, db) + cleanup(db) + } + } + Convey("WithTransaction", func() { + Convey("TestTRC", txTestWrapper(testTRC)) + Convey("TestTRCGetAll", txTestWrapper(testTRCGetAll)) + Convey("TestIssCert", txTestWrapper(testIssCert)) + Convey("TestLeafCert", txTestWrapper(testLeafCert)) + Convey("TestChain", txTestWrapper(testChain)) + Convey("TestChainGetAll", txTestWrapper(testChainGetAll)) + Convey("TransactionRollback", trustDbTestWrapper(testRollback)) + }) } -func testTRC(t *testing.T, db trustdb.TrustDB) { +func testTRC(t *testing.T, db rwTrustDB) { Convey("Initialize DB and load TRC", func() { ctx, cancelF := context.WithTimeout(context.Background(), Timeout) defer cancelF() @@ -100,7 +135,7 @@ func testTRC(t *testing.T, db trustdb.TrustDB) { }) } -func testTRCGetAll(t *testing.T, db trustdb.TrustDB) { +func testTRCGetAll(t *testing.T, db rwTrustDB) { Convey("Test get all TRCs", func() { ctx, cancelF := context.WithTimeout(context.Background(), time.Second) defer cancelF() @@ -126,7 +161,7 @@ func testTRCGetAll(t *testing.T, db trustdb.TrustDB) { } func insertTRCFromFile(t *testing.T, ctx context.Context, - fName string, db trustdb.TrustDB) *trc.TRC { + fName string, db rwTrustDB) *trc.TRC { trcobj, err := trc.TRCFromFile("../trustdbtest/"+fName, false) xtest.FailOnErr(t, err) @@ -135,7 +170,7 @@ func insertTRCFromFile(t *testing.T, ctx context.Context, return trcobj } -func testIssCert(t *testing.T, db trustdb.TrustDB) { +func testIssCert(t *testing.T, db rwTrustDB) { Convey("Initialize DB and load issuer Cert", func() { ctx, cancelF := context.WithTimeout(context.Background(), Timeout) defer cancelF() @@ -184,7 +219,7 @@ func testIssCert(t *testing.T, db trustdb.TrustDB) { }) } -func testLeafCert(t *testing.T, db trustdb.TrustDB) { +func testLeafCert(t *testing.T, db rwTrustDB) { Convey("Initialize DB and load leaf Cert", func() { ctx, cancelF := context.WithTimeout(context.Background(), Timeout) defer cancelF() @@ -233,7 +268,7 @@ func testLeafCert(t *testing.T, db trustdb.TrustDB) { }) } -func testChain(t *testing.T, db trustdb.TrustDB) { +func testChain(t *testing.T, db rwTrustDB) { Convey("Initialize DB and load Chain", func() { ctx, cancelF := context.WithTimeout(context.Background(), Timeout) defer cancelF() @@ -277,7 +312,7 @@ func testChain(t *testing.T, db trustdb.TrustDB) { }) } -func testChainGetAll(t *testing.T, db trustdb.TrustDB) { +func testChainGetAll(t *testing.T, db rwTrustDB) { Convey("Test get all chains", func() { ctx, cancelF := context.WithTimeout(context.Background(), time.Second) defer cancelF() @@ -302,8 +337,28 @@ func testChainGetAll(t *testing.T, db trustdb.TrustDB) { }) } +func testRollback(t *testing.T, db trustdb.TrustDB) { + Convey("Test transaction rollback", func() { + ctx, cancelF := context.WithTimeout(context.Background(), time.Second) + defer cancelF() + tx, err := db.BeginTransaction(ctx, nil) + SoMsg("Transaction begin should not fail", err, ShouldBeNil) + trcobj, err := trc.TRCFromFile("../trustdbtest/testdata/ISD1-V1.trc", false) + SoMsg("err trc", err, ShouldBeNil) + SoMsg("trc", trcobj, ShouldNotBeNil) + cnt, err := tx.InsertTRC(ctx, trcobj) + SoMsg("TRC insert should not fail", err, ShouldBeNil) + SoMsg("Insert count", cnt, ShouldEqual, 1) + err = tx.Rollback() + SoMsg("Rollback should not fail", err, ShouldBeNil) + trcs, err := db.GetAllTRCs(ctx) + SoMsg("GetAllTRCs should work", err, ShouldBeNil) + SoMsg("No TRCs expected", len(trcs), ShouldEqual, 0) + }) +} + func insertChainFromFile(t *testing.T, ctx context.Context, - fName string, db trustdb.TrustDB) *cert.Chain { + fName string, db rwTrustDB) *cert.Chain { chain, err := cert.ChainFromFile("../trustdbtest/"+fName, false) xtest.FailOnErr(t, err)