diff --git a/coinjoin/coinjoin.go b/coinjoin/coinjoin.go index 13a8c07..486776c 100644 --- a/coinjoin/coinjoin.go +++ b/coinjoin/coinjoin.go @@ -205,6 +205,23 @@ func (t *Tx) UnmarshalBinary(b []byte) error { return t.Tx.Deserialize(bytes.NewReader(b)) } +var bogusSigScript = make([]byte, 108) // worst case size to redeem a P2PKH + +func (t *Tx) estimateSerializeSize(tx *wire.MsgTx, mcount int) int { + bogusMixedOut := &wire.TxOut{ + Value: t.mixValue, + Version: t.sc.version(), + PkScript: make([]byte, t.sc.scriptSize()), + } + for _, in := range tx.TxIn { + in.SignatureScript = bogusSigScript + } + for i := 0; i < mcount; i++ { + tx.AddTxOut(bogusMixedOut) + } + return tx.SerializeSize() +} + func feeForSerializeSize(relayFeePerKb int64, txSerializeSize int) int64 { fee := relayFeePerKb * int64(txSerializeSize) / 1000 @@ -220,6 +237,17 @@ func feeForSerializeSize(relayFeePerKb int64, txSerializeSize int) int64 { return fee } +// highFeeRate is the maximum multiplier of the standard fee rate before unmixed +// data is refused for paying too high of a fee. It should not be too low such +// that mixing outputs at the smallest common mixed value errors for too high +// fees, as these mixes are performed without any change outputs. +const highFeeRate = 150 + +func paysHighFees(fee, relayFeePerKb int64, txSerializeSize int) bool { + maxFee := feeForSerializeSize(highFeeRate*relayFeePerKb, txSerializeSize) + return fee > maxFee +} + func (t *Tx) ValidateUnmixed(unmixed []byte, mcount int) error { var fee int64 other := new(wire.MsgTx) @@ -253,18 +281,15 @@ func (t *Tx) ValidateUnmixed(unmixed []byte, mcount int) error { return err } fee -= int64(mcount) * t.mixValue - bogusMixedOut := &wire.TxOut{ - Value: t.mixValue, - Version: t.sc.version(), - PkScript: make([]byte, t.sc.scriptSize()), - } - for i := 0; i < mcount; i++ { - other.AddTxOut(bogusMixedOut) - } - requiredFee := feeForSerializeSize(t.feeRate, other.SerializeSize()) + size := t.estimateSerializeSize(other, mcount) + requiredFee := feeForSerializeSize(t.feeRate, size) if fee < requiredFee { return errors.New("coinjoin: unmixed transaction does not pay enough network fees") } + if paysHighFees(fee, t.feeRate, size) { + return errors.New("coinjoin: unmixed transaction pays insanely high fees") + } + return nil } diff --git a/integration/honest_test.go b/integration/honest_test.go index 37b7f64..27f7fdd 100644 --- a/integration/honest_test.go +++ b/integration/honest_test.go @@ -196,7 +196,7 @@ func TestHonest(t *testing.T) { i := i go func() { input := &wire.TxIn{ValueIn: inputValue * 1e8} - change := &wire.TxOut{Value: 1e8 - int64(1+i)*0.001e8, PkScript: change} + change := &wire.TxOut{Value: 1e8 - 0.0001e8 + int64(i+1), PkScript: change} con := newConfirmer(input, change) conn, err := tls.Dial("tcp", s.Addr, nettest.ClientTLS) if err != nil {