Skip to content
This repository has been archived by the owner on Aug 31, 2021. It is now read-only.

Commit

Permalink
Update address repo methods to not need a receiver
Browse files Browse the repository at this point in the history
  • Loading branch information
elizabethengelman committed Sep 11, 2019
1 parent 664ab66 commit 1119a7f
Show file tree
Hide file tree
Showing 4 changed files with 22 additions and 26 deletions.
8 changes: 3 additions & 5 deletions pkg/datastore/postgres/repositories/address_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -23,16 +23,14 @@ import (
"github.com/vulcanize/vulcanizedb/pkg/datastore/postgres"
)

type AddressRepository struct{}

const getOrCreateAddressQuery = `WITH addressId AS (
INSERT INTO addresses (address) VALUES ($1) ON CONFLICT DO NOTHING RETURNING id
)
SELECT id FROM addresses WHERE address = $1
UNION
SELECT id FROM addressId`

func (AddressRepository) GetOrCreateAddress(db *postgres.DB, address string) (int64, error) {
func GetOrCreateAddress(db *postgres.DB, address string) (int64, error) {
checksumAddress := getChecksumAddress(address)

var addressId int64
Expand All @@ -41,7 +39,7 @@ func (AddressRepository) GetOrCreateAddress(db *postgres.DB, address string) (in
return addressId, getOrCreateErr
}

func (AddressRepository) GetOrCreateAddressInTransaction(tx *sqlx.Tx, address string) (int64, error) {
func GetOrCreateAddressInTransaction(tx *sqlx.Tx, address string) (int64, error) {
checksumAddress := getChecksumAddress(address)

var addressId int64
Expand All @@ -50,7 +48,7 @@ func (AddressRepository) GetOrCreateAddressInTransaction(tx *sqlx.Tx, address st
return addressId, getOrCreateErr
}

func (AddressRepository) GetAddressById(db *postgres.DB, id int64) (string, error) {
func GetAddressById(db *postgres.DB, id int64) (string, error) {
var address string
getErr := db.Get(&address, `SELECT address FROM public.addresses WHERE id = $1`, id)
return address, getErr
Expand Down
36 changes: 17 additions & 19 deletions pkg/datastore/postgres/repositories/address_repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -32,13 +32,11 @@ import (
var _ = Describe("address lookup", func() {
var (
db *postgres.DB
repo repositories.AddressRepository
address = fakes.FakeAddress.Hex()
)
BeforeEach(func() {
db = test_config.NewTestDB(test_config.NewTestNode())
test_config.CleanTestDB(db)
repo = repositories.AddressRepository{}
})

AfterEach(func() {
Expand All @@ -52,7 +50,7 @@ var _ = Describe("address lookup", func() {

Describe("GetOrCreateAddress", func() {
It("creates an address record", func() {
addressId, createErr := repo.GetOrCreateAddress(db, address)
addressId, createErr := repositories.GetOrCreateAddress(db, address)
Expect(createErr).NotTo(HaveOccurred())

var actualAddress dbAddress
Expand All @@ -63,10 +61,10 @@ var _ = Describe("address lookup", func() {
})

It("returns the existing record id if the address already exists", func() {
createId, createErr := repo.GetOrCreateAddress(db, address)
createId, createErr := repositories.GetOrCreateAddress(db, address)
Expect(createErr).NotTo(HaveOccurred())

getId, getErr := repo.GetOrCreateAddress(db, address)
getId, getErr := repositories.GetOrCreateAddress(db, address)
Expect(getErr).NotTo(HaveOccurred())

var addressCount int
Expand All @@ -78,20 +76,20 @@ var _ = Describe("address lookup", func() {

It("gets upper-cased addresses", func() {
upperAddress := strings.ToUpper(address)
upperAddressId, createErr := repo.GetOrCreateAddress(db, upperAddress)
upperAddressId, createErr := repositories.GetOrCreateAddress(db, upperAddress)
Expect(createErr).NotTo(HaveOccurred())

mixedCaseAddressId, getErr := repo.GetOrCreateAddress(db, address)
mixedCaseAddressId, getErr := repositories.GetOrCreateAddress(db, address)
Expect(getErr).NotTo(HaveOccurred())
Expect(upperAddressId).To(Equal(mixedCaseAddressId))
})

It("gets lower-cased addresses", func() {
lowerAddress := strings.ToLower(address)
upperAddressId, createErr := repo.GetOrCreateAddress(db, lowerAddress)
upperAddressId, createErr := repositories.GetOrCreateAddress(db, lowerAddress)
Expect(createErr).NotTo(HaveOccurred())

mixedCaseAddressId, getErr := repo.GetOrCreateAddress(db, address)
mixedCaseAddressId, getErr := repositories.GetOrCreateAddress(db, address)
Expect(getErr).NotTo(HaveOccurred())
Expect(upperAddressId).To(Equal(mixedCaseAddressId))
})
Expand All @@ -112,7 +110,7 @@ var _ = Describe("address lookup", func() {
})

It("creates an address record", func() {
addressId, createErr := repo.GetOrCreateAddressInTransaction(tx, address)
addressId, createErr := repositories.GetOrCreateAddressInTransaction(tx, address)
Expect(createErr).NotTo(HaveOccurred())
commitErr := tx.Commit()
Expect(commitErr).NotTo(HaveOccurred())
Expand All @@ -125,10 +123,10 @@ var _ = Describe("address lookup", func() {
})

It("returns the existing record id if the address already exists", func() {
_, createErr := repo.GetOrCreateAddressInTransaction(tx, address)
_, createErr := repositories.GetOrCreateAddressInTransaction(tx, address)
Expect(createErr).NotTo(HaveOccurred())

_, getErr := repo.GetOrCreateAddressInTransaction(tx, address)
_, getErr := repositories.GetOrCreateAddressInTransaction(tx, address)
Expect(getErr).NotTo(HaveOccurred())
tx.Commit()

Expand All @@ -139,10 +137,10 @@ var _ = Describe("address lookup", func() {

It("gets upper-cased addresses", func() {
upperAddress := strings.ToUpper(address)
upperAddressId, createErr := repo.GetOrCreateAddressInTransaction(tx, upperAddress)
upperAddressId, createErr := repositories.GetOrCreateAddressInTransaction(tx, upperAddress)
Expect(createErr).NotTo(HaveOccurred())

mixedCaseAddressId, getErr := repo.GetOrCreateAddressInTransaction(tx, address)
mixedCaseAddressId, getErr := repositories.GetOrCreateAddressInTransaction(tx, address)
Expect(getErr).NotTo(HaveOccurred())
tx.Commit()

Expand All @@ -151,10 +149,10 @@ var _ = Describe("address lookup", func() {

It("gets lower-cased addresses", func() {
lowerAddress := strings.ToLower(address)
upperAddressId, createErr := repo.GetOrCreateAddressInTransaction(tx, lowerAddress)
upperAddressId, createErr := repositories.GetOrCreateAddressInTransaction(tx, lowerAddress)
Expect(createErr).NotTo(HaveOccurred())

mixedCaseAddressId, getErr := repo.GetOrCreateAddressInTransaction(tx, address)
mixedCaseAddressId, getErr := repositories.GetOrCreateAddressInTransaction(tx, address)
Expect(getErr).NotTo(HaveOccurred())
tx.Commit()

Expand All @@ -164,16 +162,16 @@ var _ = Describe("address lookup", func() {

Describe("GetAddressById", func() {
It("gets and address by it's id", func() {
addressId, createErr := repo.GetOrCreateAddress(db, address)
addressId, createErr := repositories.GetOrCreateAddress(db, address)
Expect(createErr).NotTo(HaveOccurred())

actualAddress, getErr := repo.GetAddressById(db, addressId)
actualAddress, getErr := repositories.GetAddressById(db, addressId)
Expect(getErr).NotTo(HaveOccurred())
Expect(actualAddress).To(Equal(address))
})

It("returns an error if the id doesn't exist", func() {
_, getErr := repo.GetAddressById(db, 0)
_, getErr := repositories.GetAddressById(db, 0)
Expect(getErr).To(HaveOccurred())
Expect(getErr).To(MatchError("sql: no rows in result set"))
})
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ func createLogs(logs []core.Log, receiptId int64, tx *sqlx.Tx) error {

func (FullSyncReceiptRepository) CreateFullSyncReceiptInTx(blockId int64, receipt core.Receipt, tx *sqlx.Tx) (int64, error) {
var receiptId int64
addressId, getAddressErr := AddressRepository{}.GetOrCreateAddressInTransaction(tx, receipt.ContractAddress)
addressId, getAddressErr := GetOrCreateAddressInTransaction(tx, receipt.ContractAddress)
if getAddressErr != nil {
logrus.Error("createReceipt: Error getting address id: ", getAddressErr)
return receiptId, getAddressErr
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ type HeaderSyncReceiptRepository struct{}

func (HeaderSyncReceiptRepository) CreateHeaderSyncReceiptInTx(headerID, transactionID int64, receipt core.Receipt, tx *sqlx.Tx) (int64, error) {
var receiptId int64
addressId, getAddressErr := AddressRepository{}.GetOrCreateAddressInTransaction(tx, receipt.ContractAddress)
addressId, getAddressErr := GetOrCreateAddressInTransaction(tx, receipt.ContractAddress)
if getAddressErr != nil {
log.Error("createReceipt: Error getting address id: ", getAddressErr)
return receiptId, getAddressErr
Expand Down

0 comments on commit 1119a7f

Please sign in to comment.