diff --git a/tm2/pkg/crypto/keys/client/add.go b/tm2/pkg/crypto/keys/client/add.go index a4089316e92..71dc6f03090 100644 --- a/tm2/pkg/crypto/keys/client/add.go +++ b/tm2/pkg/crypto/keys/client/add.go @@ -156,8 +156,7 @@ func execAdd(cfg *addCfg, args []string, io commands.IO) error { return err } - _, err = kb.GetByName(name) - if err == nil { + if has, err := kb.HasByName(name); err == nil && has { // account exists, ask for user confirmation response, err2 := io.GetConfirmation(fmt.Sprintf("Override the existing name %s", name)) if err2 != nil { diff --git a/tm2/pkg/crypto/keys/keybase.go b/tm2/pkg/crypto/keys/keybase.go index 06eaf54f503..31b0012c433 100644 --- a/tm2/pkg/crypto/keys/keybase.go +++ b/tm2/pkg/crypto/keys/keybase.go @@ -160,14 +160,32 @@ func (kb dbKeybase) List() ([]Info, error) { return res, nil } +// HasByNameOrAddress checks if a key with the name or bech32 string address is in the keybase. +func (kb dbKeybase) HasByNameOrAddress(nameOrBech32 string) (bool, error) { + address, err := crypto.AddressFromBech32(nameOrBech32) + if err != nil { + return kb.HasByName(nameOrBech32) + } + return kb.HasByAddress(address) +} + +// HasByName checks if a key with the name is in the keybase. +func (kb dbKeybase) HasByName(name string) (bool, error) { + return kb.db.Has(infoKey(name)), nil +} + +// HasByAddress checks if a key with the address is in the keybase. +func (kb dbKeybase) HasByAddress(address crypto.Address) (bool, error) { + return kb.db.Has(addrKey(address)), nil +} + // Get returns the public information about one key. func (kb dbKeybase) GetByNameOrAddress(nameOrBech32 string) (Info, error) { addr, err := crypto.AddressFromBech32(nameOrBech32) if err != nil { return kb.GetByName(nameOrBech32) - } else { - return kb.GetByAddress(addr) } + return kb.GetByAddress(addr) } func (kb dbKeybase) GetByName(name string) (Info, error) { diff --git a/tm2/pkg/crypto/keys/keybase_test.go b/tm2/pkg/crypto/keys/keybase_test.go index e80c6c3cd03..d7660ac38f1 100644 --- a/tm2/pkg/crypto/keys/keybase_test.go +++ b/tm2/pkg/crypto/keys/keybase_test.go @@ -89,8 +89,9 @@ func TestKeyManagement(t *testing.T) { assert.Empty(t, l) // create some keys - _, err = cstore.GetByName(n1) - require.Error(t, err) + has, err := cstore.HasByName(n1) + require.NoError(t, err) + require.False(t, has) i, err := cstore.CreateAccount(n1, mn1, bip39Passphrase, p1, 0, 0) require.NoError(t, err) require.Equal(t, n1, i.GetName()) @@ -100,10 +101,16 @@ func TestKeyManagement(t *testing.T) { // we can get these keys i2, err := cstore.GetByName(n2) require.NoError(t, err) - _, err = cstore.GetByName(n3) - require.NotNil(t, err) - _, err = cstore.GetByAddress(toAddr(i2)) + has, err = cstore.HasByName(n3) + require.NoError(t, err) + require.False(t, has) + has, err = cstore.HasByAddress(toAddr(i2)) + require.NoError(t, err) + require.True(t, has) + // Also check with HasByNameOrAddress + has, err = cstore.HasByNameOrAddress(crypto.AddressToBech32(toAddr(i2))) require.NoError(t, err) + require.True(t, has) addr, err := crypto.AddressFromBech32("g1frtkxv37nq7arvyz5p0mtjqq7hwuvd4dnt892p") require.NoError(t, err) _, err = cstore.GetByAddress(addr) @@ -127,8 +134,9 @@ func TestKeyManagement(t *testing.T) { keyS, err = cstore.List() require.NoError(t, err) require.Equal(t, 1, len(keyS)) - _, err = cstore.GetByName(n1) - require.Error(t, err) + has, err = cstore.HasByName(n1) + require.NoError(t, err) + require.False(t, has) // create an offline key o1 := "offline" @@ -388,8 +396,9 @@ func TestSeedPhrase(t *testing.T) { // now, let us delete this key err = cstore.Delete(n1, p1, false) require.Nil(t, err, "%+v", err) - _, err = cstore.GetByName(n1) - require.NotNil(t, err) + has, err := cstore.HasByName(n1) + require.NoError(t, err) + require.False(t, has) } func ExampleNew() { diff --git a/tm2/pkg/crypto/keys/lazy_keybase.go b/tm2/pkg/crypto/keys/lazy_keybase.go index f7f9e229980..62e88d9a8e2 100644 --- a/tm2/pkg/crypto/keys/lazy_keybase.go +++ b/tm2/pkg/crypto/keys/lazy_keybase.go @@ -37,6 +37,36 @@ func (lkb lazyKeybase) List() ([]Info, error) { return NewDBKeybase(db).List() } +func (lkb lazyKeybase) HasByNameOrAddress(nameOrBech32 string) (bool, error) { + db, err := db.NewDB(lkb.name, dbBackend, lkb.dir) + if err != nil { + return false, err + } + defer db.Close() + + return NewDBKeybase(db).HasByNameOrAddress(nameOrBech32) +} + +func (lkb lazyKeybase) HasByName(name string) (bool, error) { + db, err := db.NewDB(lkb.name, dbBackend, lkb.dir) + if err != nil { + return false, err + } + defer db.Close() + + return NewDBKeybase(db).HasByName(name) +} + +func (lkb lazyKeybase) HasByAddress(address crypto.Address) (bool, error) { + db, err := db.NewDB(lkb.name, dbBackend, lkb.dir) + if err != nil { + return false, err + } + defer db.Close() + + return NewDBKeybase(db).HasByAddress(address) +} + func (lkb lazyKeybase) GetByNameOrAddress(nameOrBech32 string) (Info, error) { db, err := db.NewDB(lkb.name, dbBackend, lkb.dir) if err != nil { diff --git a/tm2/pkg/crypto/keys/types.go b/tm2/pkg/crypto/keys/types.go index bba3a917b69..c5d33023a0a 100644 --- a/tm2/pkg/crypto/keys/types.go +++ b/tm2/pkg/crypto/keys/types.go @@ -13,6 +13,9 @@ import ( type Keybase interface { // CRUD on the keystore List() ([]Info, error) + HasByNameOrAddress(nameOrBech32 string) (bool, error) + HasByName(name string) (bool, error) + HasByAddress(address crypto.Address) (bool, error) GetByNameOrAddress(nameOrBech32 string) (Info, error) GetByName(name string) (Info, error) GetByAddress(address crypto.Address) (Info, error)