Skip to content

Commit

Permalink
refactor: pass mixing wallet to CoinJoin utils by reference
Browse files Browse the repository at this point in the history
  • Loading branch information
UdjinM6 committed Dec 1, 2024
1 parent 8a05f0c commit 0f29aaf
Show file tree
Hide file tree
Showing 4 changed files with 34 additions and 53 deletions.
25 changes: 3 additions & 22 deletions src/coinjoin/client.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -1395,12 +1395,7 @@ bool CCoinJoinClientSession::PrepareDenominate(int nMinRounds, int nMaxRounds, s
++nSteps;
continue;
}
const auto pwallet = GetWallet(m_wallet.GetName());
if (!pwallet) {
strErrorRet ="Couldn't get wallet pointer";
return false;
}
scriptDenom = keyHolderStorage.AddKey(pwallet.get());
scriptDenom = keyHolderStorage.AddKey(m_wallet);
}
vecPSInOutPairsRet.emplace_back(entry, CTxOut(nDenomAmount, scriptDenom));
// step is complete
Expand Down Expand Up @@ -1484,14 +1479,7 @@ bool CCoinJoinClientSession::MakeCollateralAmounts(const CompactTallyItem& tally
return false;
}

const auto pwallet = GetWallet(m_wallet.GetName());

if (!pwallet) {
WalletCJLogPrint(m_wallet, "CCoinJoinClientSession::%s -- Couldn't get wallet pointer\n", __func__);
return false;
}

CTransactionBuilder txBuilder(pwallet, tallyItem);
CTransactionBuilder txBuilder(m_wallet, tallyItem);

WalletCJLogPrint(m_wallet, "CCoinJoinClientSession::%s -- Start %s\n", __func__, txBuilder.ToString());

Expand Down Expand Up @@ -1654,14 +1642,7 @@ bool CCoinJoinClientSession::CreateDenominated(CAmount nBalanceToDenominate, con
return false;
}

const auto pwallet = GetWallet(m_wallet.GetName());

if (!pwallet) {
WalletCJLogPrint(m_wallet, "CCoinJoinClientSession::%s -- Couldn't get wallet pointer\n", __func__);
return false;
}

CTransactionBuilder txBuilder(pwallet, tallyItem);
CTransactionBuilder txBuilder(m_wallet, tallyItem);

WalletCJLogPrint(m_wallet, "CCoinJoinClientSession::%s -- Start %s\n", __func__, txBuilder.ToString());

Expand Down
48 changes: 24 additions & 24 deletions src/coinjoin/util.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,8 @@ inline unsigned int GetSizeOfCompactSizeDiff(uint64_t nSizePrev, uint64_t nSizeN
return ::GetSizeOfCompactSize(nSizeNew) - ::GetSizeOfCompactSize(nSizePrev);
}

CKeyHolder::CKeyHolder(CWallet* pwallet) :
reserveDestination(pwallet)
CKeyHolder::CKeyHolder(CWallet& wallet) :
reserveDestination(&wallet)
{
reserveDestination.GetReservedDestination(dest, false);
}
Expand All @@ -42,9 +42,9 @@ CScript CKeyHolder::GetScriptForDestination() const
}


CScript CKeyHolderStorage::AddKey(CWallet* pwallet)
CScript CKeyHolderStorage::AddKey(CWallet& wallet)
{
auto keyHolderPtr = std::make_unique<CKeyHolder>(pwallet);
auto keyHolderPtr = std::make_unique<CKeyHolder>(wallet);
auto script = keyHolderPtr->GetScriptForDestination();

LOCK(cs_storage);
Expand Down Expand Up @@ -87,14 +87,14 @@ void CKeyHolderStorage::ReturnAll()
}
}

CTransactionBuilderOutput::CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, std::shared_ptr<CWallet> pwalletIn, CAmount nAmountIn) :
CTransactionBuilderOutput::CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, CWallet& wallet, CAmount nAmountIn) :
pTxBuilder(pTxBuilderIn),
dest(pwalletIn.get()),
dest(&wallet),
nAmount(nAmountIn)
{
assert(pTxBuilder);
CTxDestination txdest;
LOCK(pwalletIn->cs_wallet);
LOCK(wallet.cs_wallet);
dest.GetReservedDestination(txdest, false);
script = ::GetScriptForDestination(txdest);
}
Expand All @@ -108,15 +108,15 @@ bool CTransactionBuilderOutput::UpdateAmount(const CAmount nNewAmount)
return true;
}

CTransactionBuilder::CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, const CompactTallyItem& tallyItemIn) :
pwallet(pwalletIn),
dummyReserveDestination(pwalletIn.get()),
CTransactionBuilder::CTransactionBuilder(CWallet& wallet, const CompactTallyItem& tallyItemIn) :
m_wallet(wallet),
dummyReserveDestination(&m_wallet),
tallyItem(tallyItemIn)
{
// Generate a feerate which will be used to consider if the remainder is dust and will go into fees or not
coinControl.m_discard_feerate = ::GetDiscardRate(*pwallet);
coinControl.m_discard_feerate = ::GetDiscardRate(m_wallet);
// Generate a feerate which will be used by calculations of this class and also by CWallet::CreateTransaction
coinControl.m_feerate = std::max(GetRequiredFeeRate(*pwallet), pwallet->m_pay_tx_fee);
coinControl.m_feerate = std::max(GetRequiredFeeRate(m_wallet), m_wallet.m_pay_tx_fee);
// Change always goes back to origin
coinControl.destChange = tallyItemIn.txdest;
// Only allow tallyItems inputs for tx creation
Expand All @@ -131,16 +131,16 @@ CTransactionBuilder::CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, con
// Get a comparable dummy scriptPubKey, avoid writing/flushing to the actual wallet db
CScript dummyScript;
{
LOCK(pwallet->cs_wallet);
WalletBatch dummyBatch(pwallet->GetDatabase(), false);
LOCK(m_wallet.cs_wallet);
WalletBatch dummyBatch(m_wallet.GetDatabase(), false);
dummyBatch.TxnBegin();
CKey secret;
secret.MakeNewKey(pwallet->CanSupportFeature(FEATURE_COMPRPUBKEY));
secret.MakeNewKey(m_wallet.CanSupportFeature(FEATURE_COMPRPUBKEY));
CPubKey dummyPubkey = secret.GetPubKey();
dummyBatch.TxnAbort();
dummyScript = ::GetScriptForDestination(PKHash(dummyPubkey));
// Calculate required bytes for the dummy signed tx with tallyItem's inputs only
nBytesBase = CalculateMaximumSignedTxSize(CTransaction(dummyTx), pwallet.get(), false);
nBytesBase = CalculateMaximumSignedTxSize(CTransaction(dummyTx), &m_wallet, false);
}
// Calculate the output size
nBytesOutput = ::GetSerializeSize(CTxOut(0, dummyScript), PROTOCOL_VERSION);
Expand Down Expand Up @@ -204,7 +204,7 @@ CTransactionBuilderOutput* CTransactionBuilder::AddOutput(CAmount nAmountOutput)
{
if (CouldAddOutput(nAmountOutput)) {
LOCK(cs_outputs);
vecOutputs.push_back(std::make_unique<CTransactionBuilderOutput>(this, pwallet, nAmountOutput));
vecOutputs.push_back(std::make_unique<CTransactionBuilderOutput>(this, m_wallet, nAmountOutput));
return vecOutputs.back().get();
}
return nullptr;
Expand Down Expand Up @@ -233,12 +233,12 @@ CAmount CTransactionBuilder::GetAmountUsed() const
CAmount CTransactionBuilder::GetFee(unsigned int nBytes) const
{
CAmount nFeeCalc = coinControl.m_feerate->GetFee(nBytes);
CAmount nRequiredFee = GetRequiredFee(*pwallet, nBytes);
CAmount nRequiredFee = GetRequiredFee(m_wallet, nBytes);
if (nRequiredFee > nFeeCalc) {
nFeeCalc = nRequiredFee;
}
if (nFeeCalc > pwallet->m_default_max_tx_fee) {
nFeeCalc = pwallet->m_default_max_tx_fee;
if (nFeeCalc > m_wallet.m_default_max_tx_fee) {
nFeeCalc = m_wallet.m_default_max_tx_fee;
}
return nFeeCalc;
}
Expand Down Expand Up @@ -273,9 +273,9 @@ bool CTransactionBuilder::Commit(bilingual_str& strResult)

CTransactionRef tx;
{
LOCK2(pwallet->cs_wallet, cs_main);
LOCK2(m_wallet.cs_wallet, cs_main);
FeeCalculation fee_calc_out;
if (!pwallet->CreateTransaction(vecSend, tx, nFeeRet, nChangePosRet, strResult, coinControl, fee_calc_out)) {
if (!m_wallet.CreateTransaction(vecSend, tx, nFeeRet, nChangePosRet, strResult, coinControl, fee_calc_out)) {
return false;
}
}
Expand Down Expand Up @@ -312,8 +312,8 @@ bool CTransactionBuilder::Commit(bilingual_str& strResult)
}

{
LOCK2(pwallet->cs_wallet, cs_main);
pwallet->CommitTransaction(tx, {}, {});
LOCK2(m_wallet.cs_wallet, cs_main);
m_wallet.CommitTransaction(tx, {}, {});
}

fKeepKeys = true;
Expand Down
10 changes: 5 additions & 5 deletions src/coinjoin/util.h
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ class CKeyHolder
CTxDestination dest;

public:
explicit CKeyHolder(CWallet* pwalletIn);
explicit CKeyHolder(CWallet& wallet);
CKeyHolder(CKeyHolder&&) = delete;
CKeyHolder& operator=(CKeyHolder&&) = delete;
void KeepKey();
Expand All @@ -33,7 +33,7 @@ class CKeyHolderStorage
std::vector<std::unique_ptr<CKeyHolder> > storage GUARDED_BY(cs_storage);

public:
CScript AddKey(CWallet* pwalletIn) EXCLUSIVE_LOCKS_REQUIRED(!cs_storage);
CScript AddKey(CWallet& wallet) EXCLUSIVE_LOCKS_REQUIRED(!cs_storage);
void KeepAll() EXCLUSIVE_LOCKS_REQUIRED(!cs_storage);
void ReturnAll() EXCLUSIVE_LOCKS_REQUIRED(!cs_storage);
};
Expand All @@ -54,7 +54,7 @@ class CTransactionBuilderOutput
CScript script;

public:
CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, std::shared_ptr<CWallet> pwalletIn, CAmount nAmountIn);
CTransactionBuilderOutput(CTransactionBuilder* pTxBuilderIn, CWallet& wallet, CAmount nAmountIn);
CTransactionBuilderOutput(CTransactionBuilderOutput&&) = delete;
CTransactionBuilderOutput& operator=(CTransactionBuilderOutput&&) = delete;
/// Get the scriptPubKey of this output
Expand All @@ -77,7 +77,7 @@ class CTransactionBuilderOutput
class CTransactionBuilder
{
/// Wallet the transaction will be build for
std::shared_ptr<CWallet> pwallet;
CWallet& m_wallet;
/// See CTransactionBuilder() for initialization
CCoinControl coinControl;
/// Dummy since we anyway use tallyItem's destination as change destination in coincontrol.
Expand All @@ -100,7 +100,7 @@ class CTransactionBuilder
friend class CTransactionBuilderOutput;

public:
CTransactionBuilder(std::shared_ptr<CWallet> pwalletIn, const CompactTallyItem& tallyItemIn);
CTransactionBuilder(CWallet& wallet, const CompactTallyItem& tallyItemIn);
~CTransactionBuilder();
/// Check it would be possible to add a single output with the amount nAmount. Returns true if its possible and false if not.
bool CouldAddOutput(CAmount nAmountOutput) const EXCLUSIVE_LOCKS_REQUIRED(!cs_outputs);
Expand Down
4 changes: 2 additions & 2 deletions src/wallet/test/coinjoin_tests.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -231,7 +231,7 @@ BOOST_FIXTURE_TEST_CASE(CTransactionBuilderTest, CTransactionBuilderTestSetup)
// Tests with single outpoint tallyItem
{
CompactTallyItem tallyItem = GetTallyItem({4999});
CTransactionBuilder txBuilder(wallet, tallyItem);
CTransactionBuilder txBuilder(*wallet, tallyItem);

BOOST_CHECK_EQUAL(txBuilder.CountOutputs(), 0);
BOOST_CHECK_EQUAL(txBuilder.GetAmountInitial(), tallyItem.nAmount);
Expand Down Expand Up @@ -268,7 +268,7 @@ BOOST_FIXTURE_TEST_CASE(CTransactionBuilderTest, CTransactionBuilderTestSetup)
// Tests with multiple outpoint tallyItem
{
CompactTallyItem tallyItem = GetTallyItem({10000, 20000, 30000, 40000, 50000});
CTransactionBuilder txBuilder(wallet, tallyItem);
CTransactionBuilder txBuilder(*wallet, tallyItem);
std::vector<CTransactionBuilderOutput*> vecOutputs;
bilingual_str strResult;

Expand Down

0 comments on commit 0f29aaf

Please sign in to comment.