refactor to use transaction parameters record object when creating a wallet transaction

This commit is contained in:
Craig Raw 2025-10-21 12:05:34 +02:00
parent ad90ea0d38
commit e975cbe6f8
2 changed files with 80 additions and 38 deletions

View file

@ -0,0 +1,41 @@
package com.sparrowwallet.drongo.wallet;
import com.sparrowwallet.drongo.address.Address;
import com.sparrowwallet.drongo.protocol.Transaction;
import java.util.List;
import java.util.Optional;
import java.util.Set;
public record TransactionParameters(List<UtxoSelector> utxoSelectors, List<TxoFilter> txoFilters, List<Payment> payments, List<byte[]> opReturns,
Set<WalletNode> excludedChangeNodes, double feeRate, double longTermFeeRate, double minRelayFeeRate, Long fee,
Integer currentBlockHeight, boolean groupByAddress, boolean includeMempoolOutputs, boolean allowRbf) {
public boolean containsSendMaxPayment() {
return payments.stream().anyMatch(Payment::isSendMax);
}
public Optional<Payment> getFirstSendMaxPayment() {
return payments.stream().filter(Payment::isSendMax).findFirst();
}
public List<Address> getPaymentAddresses() {
return payments.stream().map(Payment::getAddress).toList();
}
public long getTotalPaymentAmount() {
return payments.stream().mapToLong(Payment::getAmount).sum();
}
public long getTotalPaymentAmountLessExcluded(Payment excludedPayment) {
return payments.stream().filter(payment -> !excludedPayment.equals(payment)).mapToLong(Payment::getAmount).sum();
}
public boolean isMinRelayRate() {
return ((feeRate == minRelayFeeRate && minRelayFeeRate > 0d) || feeRate == Transaction.DEFAULT_MIN_RELAY_FEE) && fee == null;
}
public long getRequiredFeeAmount(double virtualSize) {
return fee == null ? (long)Math.floor(feeRate * virtualSize) : fee;
}
}

View file

@ -1034,45 +1034,42 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
return getFee(changeOutput, feeRate, longTermFeeRate);
}
public WalletTransaction createWalletTransaction(List<UtxoSelector> utxoSelectors, List<TxoFilter> txoFilters, List<Payment> payments, List<byte[]> opReturns,
Set<WalletNode> excludedChangeNodes, double feeRate, double longTermFeeRate, double minRelayFeeRate, Long fee,
Integer currentBlockHeight, boolean groupByAddress, boolean includeMempoolOutputs, boolean allowRbf) throws InsufficientFundsException {
boolean sendMax = payments.stream().anyMatch(Payment::isSendMax);
long totalPaymentAmount = payments.stream().map(Payment::getAmount).mapToLong(v -> v).sum();
Map<BlockTransactionHashIndex, WalletNode> availableTxos = getWalletTxos(txoFilters);
public WalletTransaction createWalletTransaction(TransactionParameters params) throws InsufficientFundsException {
long totalPaymentAmount = params.getTotalPaymentAmount();
Map<BlockTransactionHashIndex, WalletNode> availableTxos = getWalletTxos(params.txoFilters());
long totalAvailableValue = availableTxos.keySet().stream().mapToLong(BlockTransactionHashIndex::getValue).sum();
if(fee != null && feeRate != minRelayFeeRate) {
if(params.fee() != null && params.feeRate() != params.minRelayFeeRate()) {
throw new IllegalArgumentException("Use an input fee rate equal to the min relay rate when using a defined fee amount so UTXO selectors overestimate effective value");
}
long maxSpendableAmt = getMaxSpendable(payments.stream().map(Payment::getAddress).collect(Collectors.toList()), feeRate, availableTxos);
long maxSpendableAmt = getMaxSpendable(params.getPaymentAddresses(), params.feeRate(), availableTxos);
if(maxSpendableAmt < 0) {
throw new InsufficientFundsException("Not enough combined value in all available UTXOs to send a transaction to the provided addresses at this fee rate");
}
//When a user fee is set, we can calculate the fees to spend all UTXOs because we assume all UTXOs are spendable at a fee rate of 1 sat/vB
//We can then add the user set fee less this amount as a "phantom payment amount" to the value required to find (which cannot include transaction fees)
long valueRequiredAmt = totalPaymentAmount + (fee != null ? fee - (totalAvailableValue - maxSpendableAmt) : 0);
long valueRequiredAmt = totalPaymentAmount + (params.fee() != null ? params.fee() - (totalAvailableValue - maxSpendableAmt) : 0);
if(maxSpendableAmt < valueRequiredAmt) {
throw new InsufficientFundsException("Not enough combined value in all available UTXOs to send a transaction to send the provided payments at the user set fee" + (fee == null ? " rate" : ""));
throw new InsufficientFundsException("Not enough combined value in all available UTXOs to send a transaction to send the provided payments at the user set fee" + (params.fee() == null ? " rate" : ""));
}
while(true) {
List<Map<BlockTransactionHashIndex, WalletNode>> selectedUtxoSets = selectInputSets(availableTxos, utxoSelectors, txoFilters, valueRequiredAmt, feeRate, longTermFeeRate, groupByAddress, includeMempoolOutputs, sendMax);
List<Map<BlockTransactionHashIndex, WalletNode>> selectedUtxoSets = selectInputSets(params, availableTxos, valueRequiredAmt);
Map<BlockTransactionHashIndex, WalletNode> selectedUtxos = new LinkedHashMap<>();
selectedUtxoSets.forEach(selectedUtxos::putAll);
long totalSelectedAmt = selectedUtxos.keySet().stream().mapToLong(BlockTransactionHashIndex::getValue).sum();
int numSets = selectedUtxoSets.size();
List<Payment> txPayments = new ArrayList<>(payments);
List<Payment> txPayments = new ArrayList<>(params.payments());
List<WalletTransaction.Output> outputs = new ArrayList<>();
Set<WalletNode> txExcludedChangeNodes = new HashSet<>(excludedChangeNodes);
long sequence = allowRbf ? TransactionInput.SEQUENCE_RBF_ENABLED : TransactionInput.SEQUENCE_RBF_DISABLED;
Set<WalletNode> txExcludedChangeNodes = new HashSet<>(params.excludedChangeNodes());
long sequence = params.allowRbf() ? TransactionInput.SEQUENCE_RBF_ENABLED : TransactionInput.SEQUENCE_RBF_DISABLED;
Transaction transaction = new Transaction();
transaction.setVersion(2);
if(currentBlockHeight != null) {
transaction.setLocktime(currentBlockHeight.longValue());
if(params.currentBlockHeight() != null) {
transaction.setLocktime(params.currentBlockHeight().longValue());
}
//Add inputs
@ -1083,8 +1080,8 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
txInput.setSequenceNumber(sequence);
}
if(getScriptType() == P2TR && currentBlockHeight != null && sequence != TransactionInput.SEQUENCE_RBF_DISABLED) {
applySequenceAntiFeeSniping(transaction, selectedUtxos, currentBlockHeight);
if(getScriptType() == P2TR && params.currentBlockHeight() != null && sequence != TransactionInput.SEQUENCE_RBF_DISABLED) {
applySequenceAntiFeeSniping(transaction, selectedUtxos, params.currentBlockHeight());
}
for(int i = 1; i < numSets; i+=2) {
@ -1110,30 +1107,29 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
}
//Add OP_RETURNs
for(byte[] opReturn : opReturns) {
for(byte[] opReturn : params.opReturns()) {
TransactionOutput output = transaction.addOutput(0L, new Script(List.of(ScriptChunk.fromOpcode(ScriptOpCodes.OP_RETURN), ScriptChunk.fromData(opReturn))));
outputs.add(new WalletTransaction.NonAddressOutput(output));
}
double noChangeVSize = transaction.getVirtualSize();
long noChangeFeeRequiredAmt = (fee == null ? (long)Math.floor(feeRate * noChangeVSize) : fee);
long noChangeFeeRequiredAmt = params.getRequiredFeeAmount(noChangeVSize);
//Add 1 satoshi to accommodate longer signatures when feeRate equals the current or common min relay fee to ensure fee is sufficient for maximum "relayability"
boolean isMinRelayRate = ((feeRate == minRelayFeeRate && minRelayFeeRate > 0d) || feeRate == Transaction.DEFAULT_MIN_RELAY_FEE) && fee == null;
if(isMinRelayRate) {
if(params.isMinRelayRate()) {
noChangeFeeRequiredAmt++;
}
//If sending all selected utxos, set the recipient amount to equal to total of those utxos less the no change fee
long maxSendAmt = totalSelectedAmt - noChangeFeeRequiredAmt;
Optional<Payment> optMaxPayment = payments.stream().filter(Payment::isSendMax).findFirst();
Optional<Payment> optMaxPayment = params.getFirstSendMaxPayment();
if(optMaxPayment.isPresent()) {
Payment maxPayment = optMaxPayment.get();
maxSendAmt = maxSendAmt - payments.stream().filter(payment -> !maxPayment.equals(payment)).map(Payment::getAmount).mapToLong(v -> v).sum();
maxSendAmt = maxSendAmt - params.getTotalPaymentAmountLessExcluded(maxPayment);
if(maxSendAmt > 0 && maxPayment.getAmount() != maxSendAmt) {
maxPayment.setAmount(maxSendAmt);
totalPaymentAmount = payments.stream().map(Payment::getAmount).mapToLong(v -> v).sum();
totalPaymentAmount = params.getTotalPaymentAmount();
continue;
}
}
@ -1154,18 +1150,18 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
//Determine if a change output is required by checking if its value is greater than its dust threshold
List<Long> setChangeAmts = getSetChangeAmounts(selectedUtxoSets, totalPaymentAmount, noChangeFeeRequiredAmt);
double noChangeFeeRate = (fee == null ? feeRate : noChangeFeeRequiredAmt / transaction.getVirtualSize());
long costOfChangeAmt = getCostOfChange(noChangeFeeRate, longTermFeeRate);
double noChangeFeeRate = (params.fee() == null ? params.feeRate() : noChangeFeeRequiredAmt / transaction.getVirtualSize());
long costOfChangeAmt = getCostOfChange(noChangeFeeRate, params.longTermFeeRate());
if(setChangeAmts.stream().allMatch(amt -> amt > costOfChangeAmt) || (numSets > 1 && differenceAmt / transaction.getVirtualSize() > noChangeFeeRate * 2)) {
//Change output is required, determine new fee once change output has been added
WalletNode changeNode = getFreshNode(getChangeKeyPurpose());
while(txExcludedChangeNodes.contains(changeNode)) {
changeNode = getFreshNode(getChangeKeyPurpose(), changeNode);
}
TransactionOutput changeOutput = new TransactionOutput(transaction, setChangeAmts.iterator().next(), changeNode.getOutputScript());
TransactionOutput changeOutput = new TransactionOutput(transaction, setChangeAmts.getFirst(), changeNode.getOutputScript());
double changeVSize = noChangeVSize + changeOutput.getLength() * numSets;
long changeFeeRequiredAmt = (fee == null ? (long)Math.floor(feeRate * changeVSize) : fee);
if(isMinRelayRate) {
long changeFeeRequiredAmt = params.getRequiredFeeAmount(changeVSize);
if(params.isMinRelayRate()) {
changeFeeRequiredAmt++;
}
while(changeFeeRequiredAmt % numSets > 0) {
@ -1186,7 +1182,7 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
//The new fee has meant that one of the change outputs is now dust. We pay too high a fee without change, but change is dust when added.
if(numSets > 1 && differenceAmt / transaction.getVirtualSize() < noChangeFeeRate * 2) {
//Maximize privacy. Pay a higher fee to keep multiple output sets.
return new WalletTransaction(this, transaction, utxoSelectors, selectedUtxoSets, txPayments, outputs, differenceAmt);
return new WalletTransaction(this, transaction, params.utxoSelectors(), selectedUtxoSets, txPayments, outputs, differenceAmt);
} else {
//Maxmize efficiency. Increase value required from inputs and try again.
valueRequiredAmt = totalSelectedAmt + 1;
@ -1194,10 +1190,10 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
}
}
return new WalletTransaction(this, transaction, utxoSelectors, selectedUtxoSets, txPayments, outputs, changeMap, changeFeeRequiredAmt);
return new WalletTransaction(this, transaction, params.utxoSelectors(), selectedUtxoSets, txPayments, outputs, changeMap, changeFeeRequiredAmt);
}
return new WalletTransaction(this, transaction, utxoSelectors, selectedUtxoSets, txPayments, outputs, differenceAmt);
return new WalletTransaction(this, transaction, params.utxoSelectors(), selectedUtxoSets, txPayments, outputs, differenceAmt);
}
}
@ -1252,27 +1248,28 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
}
}
private List<Map<BlockTransactionHashIndex, WalletNode>> selectInputSets(Map<BlockTransactionHashIndex, WalletNode> availableTxos, List<UtxoSelector> utxoSelectors, List<TxoFilter> txoFilters, Long targetValue, double feeRate, double longTermFeeRate, boolean groupByAddress, boolean includeMempoolOutputs, boolean sendMax) throws InsufficientFundsException {
List<OutputGroup> utxoPool = getGroupedUtxos(txoFilters, feeRate, longTermFeeRate, groupByAddress);
private List<Map<BlockTransactionHashIndex, WalletNode>> selectInputSets(TransactionParameters params, Map<BlockTransactionHashIndex, WalletNode> availableTxos,
Long targetValue) throws InsufficientFundsException {
List<OutputGroup> utxoPool = getGroupedUtxos(params);
List<OutputGroup.Filter> filters = new ArrayList<>();
filters.add(new OutputGroup.Filter(1, 6, false));
filters.add(new OutputGroup.Filter(1, 1, false));
if(includeMempoolOutputs) {
if(params.includeMempoolOutputs()) {
filters.add(new OutputGroup.Filter(0, 0, false));
filters.add(new OutputGroup.Filter(0, 0, true));
} else {
filters.add(new OutputGroup.Filter(1, 1, true));
}
if(sendMax) {
if(params.containsSendMaxPayment()) {
Collections.reverse(filters);
}
for(OutputGroup.Filter filter : filters) {
List<OutputGroup> filteredPool = utxoPool.stream().filter(filter::isEligible).collect(Collectors.toList());
for(UtxoSelector utxoSelector : utxoSelectors) {
for(UtxoSelector utxoSelector : params.utxoSelectors()) {
List<Collection<BlockTransactionHashIndex>> selectedInputSets = utxoSelector.selectSets(targetValue, filteredPool);
List<Map<BlockTransactionHashIndex, WalletNode>> selectedInputSetsList = new ArrayList<>();
long total = 0;
@ -1298,6 +1295,10 @@ public class Wallet extends Persistable implements Comparable<Wallet> {
throw new InsufficientFundsException("Not enough combined value in UTXOs for output value " + targetValue, targetValue);
}
public List<OutputGroup> getGroupedUtxos(TransactionParameters params) {
return getGroupedUtxos(params.txoFilters(), params.feeRate(), params.longTermFeeRate(), params.groupByAddress());
}
public List<OutputGroup> getGroupedUtxos(List<TxoFilter> txoFilters, double feeRate, double longTermFeeRate, boolean groupByAddress) {
List<OutputGroup> outputGroups = new ArrayList<>();
Map<Sha256Hash, BlockTransaction> walletTransactions = getWalletTransactions();