Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

refactor: pass mixing wallet to CoinJoin utils by reference #6440

Closed
wants to merge 1 commit into from
Closed
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
25 changes: 3 additions & 22 deletions src/coinjoin/client.cpp
Original file line number Diff line number Diff line change
@@ -1395,12 +1395,7 @@ bool CCoinJoinClientSession::PrepareDenominate(int nMinRounds, int nMaxRounds, s
++nSteps;
continue;
}
const auto pwallet = GetWallet(m_wallet.GetName());
if (!pwallet) {
strErrorRet ="Couldn't get wallet pointer";
return false;
}
scriptDenom = keyHolderStorage.AddKey(pwallet.get());
scriptDenom = keyHolderStorage.AddKey(m_wallet);
}
vecPSInOutPairsRet.emplace_back(entry, CTxOut(nDenomAmount, scriptDenom));
// step is complete
@@ -1484,14 +1479,7 @@ bool CCoinJoinClientSession::MakeCollateralAmounts(const CompactTallyItem& tally
return false;
}

const auto pwallet = GetWallet(m_wallet.GetName());

if (!pwallet) {
WalletCJLogPrint(m_wallet, "CCoinJoinClientSession::%s -- Couldn't get wallet pointer\n", __func__);
return false;
}

CTransactionBuilder txBuilder(pwallet, tallyItem);
CTransactionBuilder txBuilder(m_wallet, tallyItem);

WalletCJLogPrint(m_wallet, "CCoinJoinClientSession::%s -- Start %s\n", __func__, txBuilder.ToString());

@@ -1654,14 +1642,7 @@ bool CCoinJoinClientSession::CreateDenominated(CAmount nBalanceToDenominate, con
return false;
}

const auto pwallet = GetWallet(m_wallet.GetName());

if (!pwallet) {
WalletCJLogPrint(m_wallet, "CCoinJoinClientSession::%s -- Couldn't get wallet pointer\n", __func__);
return false;
}

CTransactionBuilder txBuilder(pwallet, tallyItem);
CTransactionBuilder txBuilder(m_wallet, tallyItem);

WalletCJLogPrint(m_wallet, "CCoinJoinClientSession::%s -- Start %s\n", __func__, txBuilder.ToString());

48 changes: 24 additions & 24 deletions src/coinjoin/util.cpp
Original file line number Diff line number Diff line change
@@ -20,8 +20,8 @@ inline unsigned int GetSizeOfCompactSizeDiff(uint64_t nSizePrev, uint64_t nSizeN
return ::GetSizeOfCompactSize(nSizeNew) - ::GetSizeOfCompactSize(nSizePrev);
}

CKeyHolder::CKeyHolder(CWallet* pwallet) :
reserveDestination(pwallet)
CKeyHolder::CKeyHolder(CWallet& wallet) :
reserveDestination(&wallet)
{
reserveDestination.GetReservedDestination(dest, false);
}
@@ -42,9 +42,9 @@ CScript CKeyHolder::GetScriptForDestination() const
}


CScript CKeyHolderStorage::AddKey(CWallet* pwallet)
CScript CKeyHolderStorage::AddKey(CWallet& wallet)
{
auto keyHolderPtr = std::make_unique<CKeyHolder>(pwallet);
auto keyHolderPtr = std::make_unique<CKeyHolder>(wallet);
auto script = keyHolderPtr->GetScriptForDestination();

LOCK(cs_storage);
@@ -87,14 +87,14 @@ void CKeyHolderStorage::ReturnAll()
}
}

CTransactionBuilderOutput::CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, std::shared_ptr<CWallet> pwalletIn, CAmount nAmountIn) :
CTransactionBuilderOutput::CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, CWallet& wallet, CAmount nAmountIn) :
pTxBuilder(pTxBuilderIn),
dest(pwalletIn.get()),
dest(&wallet),
nAmount(nAmountIn)
{
assert(pTxBuilder);
CTxDestination txdest;
LOCK(pwalletIn->cs_wallet);
LOCK(wallet.cs_wallet);
dest.GetReservedDestination(txdest, false);
script = ::GetScriptForDestination(txdest);
}
@@ -108,15 +108,15 @@ bool CTransactionBuilderOutput::UpdateAmount(const CAmount nNewAmount)
return true;
}

CTransactionBuilder::CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, const CompactTallyItem& tallyItemIn) :
pwallet(pwalletIn),
dummyReserveDestination(pwalletIn.get()),
CTransactionBuilder::CTransactionBuilder(CWallet& wallet, const CompactTallyItem& tallyItemIn) :
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

hm, CTransactionBuilder doesn't actually need shared_ptr to wallet, because it doesn't own a copy of wallet in memory... That's a good finding.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we justify this? It doesn't seem correct to me. Are we sure that a CTransactionBuilder object won't outlive the wallet creating it? it seems this is the reason to use shared_ptr no? Please help clarify

Copy link
Collaborator

@knst knst Dec 3, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

So far as I looked to code, instance of CTransactionBuilder is created in only 2 functions:

  • CCoinJoinClientSession::MakeCollateralAmounts
  • CCoinJoinClientSession::CreateDenominated

In both cases CCoinJoinClientSession works with CWallet object like with something that definitely exist, because both functions starts from AssertLockHeld(m_wallet.cs_wallet);

And in both cases instance of CTransactionBuilder is a local variable which is removed after usage, so, life of CWallet object is not extended long enough to provide any extra safety. Either CTransactionBuilder is misused and all CJ code is broken (because intensive usage of CWallet everywhere), either no need to keep shared_ptr to CWallet inside CTransactionBuilder at all.

Please, revive this PR, it seems as everything is fine with this changes.

UPD: #6441 seems as gives all good guarantees about CWallet memory ownership during client code

m_wallet(wallet),
dummyReserveDestination(&m_wallet),
tallyItem(tallyItemIn)
{
// Generate a feerate which will be used to consider if the remainder is dust and will go into fees or not
coinControl.m_discard_feerate = ::GetDiscardRate(*pwallet);
coinControl.m_discard_feerate = ::GetDiscardRate(m_wallet);
// Generate a feerate which will be used by calculations of this class and also by CWallet::CreateTransaction
coinControl.m_feerate = std::max(GetRequiredFeeRate(*pwallet), pwallet->m_pay_tx_fee);
coinControl.m_feerate = std::max(GetRequiredFeeRate(m_wallet), m_wallet.m_pay_tx_fee);
// Change always goes back to origin
coinControl.destChange = tallyItemIn.txdest;
// Only allow tallyItems inputs for tx creation
@@ -131,16 +131,16 @@ CTransactionBuilder::CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, con
// Get a comparable dummy scriptPubKey, avoid writing/flushing to the actual wallet db
CScript dummyScript;
{
LOCK(pwallet->cs_wallet);
WalletBatch dummyBatch(pwallet->GetDatabase(), false);
LOCK(m_wallet.cs_wallet);
WalletBatch dummyBatch(m_wallet.GetDatabase(), false);
dummyBatch.TxnBegin();
CKey secret;
secret.MakeNewKey(pwallet->CanSupportFeature(FEATURE_COMPRPUBKEY));
secret.MakeNewKey(m_wallet.CanSupportFeature(FEATURE_COMPRPUBKEY));
CPubKey dummyPubkey = secret.GetPubKey();
dummyBatch.TxnAbort();
dummyScript = ::GetScriptForDestination(PKHash(dummyPubkey));
// Calculate required bytes for the dummy signed tx with tallyItem's inputs only
nBytesBase = CalculateMaximumSignedTxSize(CTransaction(dummyTx), pwallet.get(), false);
nBytesBase = CalculateMaximumSignedTxSize(CTransaction(dummyTx), &m_wallet, false);
}
// Calculate the output size
nBytesOutput = ::GetSerializeSize(CTxOut(0, dummyScript), PROTOCOL_VERSION);
@@ -204,7 +204,7 @@ CTransactionBuilderOutput* CTransactionBuilder::AddOutput(CAmount nAmountOutput)
{
if (CouldAddOutput(nAmountOutput)) {
LOCK(cs_outputs);
vecOutputs.push_back(std::make_unique<CTransactionBuilderOutput>(this, pwallet, nAmountOutput));
vecOutputs.push_back(std::make_unique<CTransactionBuilderOutput>(this, m_wallet, nAmountOutput));
return vecOutputs.back().get();
}
return nullptr;
@@ -233,12 +233,12 @@ CAmount CTransactionBuilder::GetAmountUsed() const
CAmount CTransactionBuilder::GetFee(unsigned int nBytes) const
{
CAmount nFeeCalc = coinControl.m_feerate->GetFee(nBytes);
CAmount nRequiredFee = GetRequiredFee(*pwallet, nBytes);
CAmount nRequiredFee = GetRequiredFee(m_wallet, nBytes);
if (nRequiredFee > nFeeCalc) {
nFeeCalc = nRequiredFee;
}
if (nFeeCalc > pwallet->m_default_max_tx_fee) {
nFeeCalc = pwallet->m_default_max_tx_fee;
if (nFeeCalc > m_wallet.m_default_max_tx_fee) {
nFeeCalc = m_wallet.m_default_max_tx_fee;
}
return nFeeCalc;
}
@@ -273,9 +273,9 @@ bool CTransactionBuilder::Commit(bilingual_str& strResult)

CTransactionRef tx;
{
LOCK2(pwallet->cs_wallet, cs_main);
LOCK2(m_wallet.cs_wallet, cs_main);
FeeCalculation fee_calc_out;
if (!pwallet->CreateTransaction(vecSend, tx, nFeeRet, nChangePosRet, strResult, coinControl, fee_calc_out)) {
if (!m_wallet.CreateTransaction(vecSend, tx, nFeeRet, nChangePosRet, strResult, coinControl, fee_calc_out)) {
return false;
}
}
@@ -312,8 +312,8 @@ bool CTransactionBuilder::Commit(bilingual_str& strResult)
}

{
LOCK2(pwallet->cs_wallet, cs_main);
pwallet->CommitTransaction(tx, {}, {});
LOCK2(m_wallet.cs_wallet, cs_main);
m_wallet.CommitTransaction(tx, {}, {});
}

fKeepKeys = true;
10 changes: 5 additions & 5 deletions src/coinjoin/util.h
Original file line number Diff line number Diff line change
@@ -17,7 +17,7 @@ class CKeyHolder
CTxDestination dest;

public:
explicit CKeyHolder(CWallet* pwalletIn);
explicit CKeyHolder(CWallet& wallet);
CKeyHolder(CKeyHolder&&) = delete;
CKeyHolder& operator=(CKeyHolder&&) = delete;
void KeepKey();
@@ -33,7 +33,7 @@ class CKeyHolderStorage
std::vector<std::unique_ptr<CKeyHolder> > storage GUARDED_BY(cs_storage);

public:
CScript AddKey(CWallet* pwalletIn) EXCLUSIVE_LOCKS_REQUIRED(!cs_storage);
CScript AddKey(CWallet& wallet) EXCLUSIVE_LOCKS_REQUIRED(!cs_storage);
void KeepAll() EXCLUSIVE_LOCKS_REQUIRED(!cs_storage);
void ReturnAll() EXCLUSIVE_LOCKS_REQUIRED(!cs_storage);
};
@@ -54,7 +54,7 @@ class CTransactionBuilderOutput
CScript script;

public:
CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, std::shared_ptr<CWallet> pwalletIn, CAmount nAmountIn);
CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, CWallet& wallet, CAmount nAmountIn);
CTransactionBuilderOutput(CTransactionBuilderOutput&&) = delete;
CTransactionBuilderOutput& operator=(CTransactionBuilderOutput&&) = delete;
/// Get the scriptPubKey of this output
@@ -77,7 +77,7 @@ class CTransactionBuilderOutput
class CTransactionBuilder
{
/// Wallet the transaction will be build for
std::shared_ptr<CWallet> pwallet;
CWallet& m_wallet;
/// See CTransactionBuilder() for initialization
CCoinControl coinControl;
/// Dummy since we anyway use tallyItem's destination as change destination in coincontrol.
@@ -100,7 +100,7 @@ class CTransactionBuilder
friend class CTransactionBuilderOutput;

public:
CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, const CompactTallyItem& tallyItemIn);
CTransactionBuilder(CWallet& wallet, const CompactTallyItem& tallyItemIn);
~CTransactionBuilder();
/// Check it would be possible to add a single output with the amount nAmount. Returns true if its possible and false if not.
bool CouldAddOutput(CAmount nAmountOutput) const EXCLUSIVE_LOCKS_REQUIRED(!cs_outputs);
4 changes: 2 additions & 2 deletions src/wallet/test/coinjoin_tests.cpp
Original file line number Diff line number Diff line change
@@ -231,7 +231,7 @@ BOOST_FIXTURE_TEST_CASE(CTransactionBuilderTest, CTransactionBuilderTestSetup)
// Tests with single outpoint tallyItem
{
CompactTallyItem tallyItem = GetTallyItem({4999});
CTransactionBuilder txBuilder(wallet, tallyItem);
CTransactionBuilder txBuilder(*wallet, tallyItem);

BOOST_CHECK_EQUAL(txBuilder.CountOutputs(), 0);
BOOST_CHECK_EQUAL(txBuilder.GetAmountInitial(), tallyItem.nAmount);
@@ -268,7 +268,7 @@ BOOST_FIXTURE_TEST_CASE(CTransactionBuilderTest, CTransactionBuilderTestSetup)
// Tests with multiple outpoint tallyItem
{
CompactTallyItem tallyItem = GetTallyItem({10000, 20000, 30000, 40000, 50000});
CTransactionBuilder txBuilder(wallet, tallyItem);
CTransactionBuilder txBuilder(*wallet, tallyItem);
std::vector<CTransactionBuilderOutput*> vecOutputs;
bilingual_str strResult;