refactor outputgroup, add knapsackselector

This commit is contained in:
Craig Raw 2020-07-14 09:21:13 +02:00
parent 0a6e247163
commit 9d272c0eb2
8 changed files with 229 additions and 75 deletions

View file

@ -1,44 +1,33 @@
package com.sparrowwallet.drongo.wallet;
import com.sparrowwallet.drongo.KeyPurpose;
import com.sparrowwallet.drongo.protocol.Transaction;
import com.sparrowwallet.drongo.protocol.TransactionOutput;
import java.util.*;
import java.util.stream.Collectors;
import static com.sparrowwallet.drongo.protocol.Transaction.WITNESS_SCALE_FACTOR;
public class BnBUtxoSelector implements UtxoSelector {
private static final int TOTAL_TRIES = 100000;
private final Wallet wallet;
private final int noInputsWeightUnits;
private final Double feeRate;
private final Double longTermFeeRate;
private final int inputWeightUnits;
private final long costOfChangeValue;
public BnBUtxoSelector(Wallet wallet, int noInputsWeightUnits, Double feeRate, Double longTermFeeRate) {
this.wallet = wallet;
this.noInputsWeightUnits = noInputsWeightUnits;
this.feeRate = feeRate;
this.longTermFeeRate = longTermFeeRate;
this.inputWeightUnits = wallet.getInputWeightUnits();
this.costOfChangeValue = getCostOfChange();
this.costOfChangeValue = wallet.getCostOfChange(feeRate, longTermFeeRate);
}
@Override
public Collection<BlockTransactionHashIndex> select(long targetValue, Collection<BlockTransactionHashIndex> candidates) {
List<OutputGroup> utxoPool = candidates.stream().map(OutputGroup::new).collect(Collectors.toList());
public Collection<BlockTransactionHashIndex> select(long targetValue, Collection<OutputGroup> candidates) {
List<OutputGroup> utxoPool = new ArrayList<>(candidates);
long currentValue = 0;
ArrayDeque<Boolean> currentSelection = new ArrayDeque<>(utxoPool.size());
long actualTargetValue = targetValue + (long)(noInputsWeightUnits * feeRate / WITNESS_SCALE_FACTOR);
System.out.println("Actual target: " + actualTargetValue);
System.out.println("Cost of change: " + costOfChangeValue);
System.out.println("Selected must be less than: " + (actualTargetValue + costOfChangeValue));
System.out.println("Selected must be: " + actualTargetValue + " < x < " + (actualTargetValue + costOfChangeValue));
long currentAvailableValue = utxoPool.stream().mapToLong(OutputGroup::getEffectiveValue).sum();
if(currentAvailableValue < targetValue) {
@ -74,7 +63,6 @@ public class BnBUtxoSelector implements UtxoSelector {
}
if(backtrack) {
System.out.println("Backtracking");
// Walk backwards to find the last included UTXO that still needs to have its omission branch traversed
while(!currentSelection.isEmpty() && !currentSelection.getLast()) {
currentSelection.removeLast();
@ -132,12 +120,6 @@ public class BnBUtxoSelector implements UtxoSelector {
return outList;
}
private long getCostOfChange() {
WalletNode changeNode = wallet.getFreshNode(KeyPurpose.CHANGE);
TransactionOutput changeOutput = new TransactionOutput(new Transaction(), 1L, wallet.getOutputScript(changeNode));
return wallet.getFee(changeOutput, feeRate, longTermFeeRate);
}
private ArrayDeque<Boolean> resize(ArrayDeque<Boolean> deque, int size) {
Boolean[] arr = new Boolean[size];
Arrays.fill(arr, Boolean.FALSE);
@ -162,46 +144,4 @@ public class BnBUtxoSelector implements UtxoSelector {
long noChangeFeeRequiredAmt = noInputsFee + inputsFee;
System.out.println(joiner.toString() + " = " + currentValue + " (plus fee of " + noChangeFeeRequiredAmt + ")");
}
private class OutputGroup {
private final List<BlockTransactionHashIndex> utxos = new ArrayList<>();
private long effectiveValue = 0;
private long fee = 0;
private long longTermFee = 0;
public OutputGroup(BlockTransactionHashIndex utxo) {
add(utxo);
}
public void add(BlockTransactionHashIndex utxo) {
utxos.add(utxo);
effectiveValue += utxo.getValue() - (long)(inputWeightUnits * feeRate / WITNESS_SCALE_FACTOR);
fee += (long)(inputWeightUnits * feeRate / WITNESS_SCALE_FACTOR);
longTermFee += (long)(inputWeightUnits * longTermFeeRate / WITNESS_SCALE_FACTOR);
}
public void remove(BlockTransactionHashIndex utxo) {
if(utxos.remove(utxo)) {
effectiveValue -= (utxo.getValue() - (long)(inputWeightUnits * feeRate / WITNESS_SCALE_FACTOR));
fee -= (long)(inputWeightUnits * feeRate / WITNESS_SCALE_FACTOR);
longTermFee -= (long)(inputWeightUnits * longTermFeeRate / WITNESS_SCALE_FACTOR);
}
}
public List<BlockTransactionHashIndex> getUtxos() {
return utxos;
}
public long getEffectiveValue() {
return effectiveValue;
}
public long getFee() {
return fee;
}
public long getLongTermFee() {
return longTermFee;
}
}
}

View file

@ -0,0 +1,113 @@
package com.sparrowwallet.drongo.wallet;
import com.sparrowwallet.drongo.protocol.Transaction;
import java.util.*;
import java.util.stream.Collectors;
public class KnapsackUtxoSelector implements UtxoSelector {
private static final long MIN_CHANGE = Transaction.SATOSHIS_PER_BITCOIN / 100;
@Override
public Collection<BlockTransactionHashIndex> select(long targetValue, Collection<OutputGroup> candidates) {
List<OutputGroup> shuffled = new ArrayList<>(candidates);
Collections.shuffle(shuffled);
OutputGroup lowestLarger = null;
List<OutputGroup> applicableGroups = new ArrayList<>();
long totalLower = 0;
for(OutputGroup outputGroup : shuffled) {
if(outputGroup.getEffectiveValue() == targetValue) {
return new ArrayList<>(outputGroup.getUtxos());
} else if(outputGroup.getEffectiveValue() < targetValue + MIN_CHANGE) {
applicableGroups.add(outputGroup);
totalLower += outputGroup.getEffectiveValue();
} else if(lowestLarger == null || outputGroup.getEffectiveValue() < lowestLarger.getEffectiveValue()) {
lowestLarger = outputGroup;
}
}
if(totalLower == targetValue) {
return applicableGroups.stream().flatMap(outputGroup -> outputGroup.getUtxos().stream()).collect(Collectors.toList());
}
if(totalLower < targetValue) {
if(lowestLarger == null) {
return Collections.emptyList();
}
return lowestLarger.getUtxos();
}
//We now have a list of UTXOs that are all smaller than the target + MIN_CHANGE, but together sum to greater than targetValue
// Solve subset sum by stochastic approximation
applicableGroups.sort((a, b) -> (int)(b.getEffectiveValue() - a.getEffectiveValue()));
boolean[] bestSelection = new boolean[applicableGroups.size()];
long bestValue = findApproximateBestSubset(applicableGroups, totalLower, targetValue, bestSelection);
if(bestValue != targetValue && totalLower >= targetValue + MIN_CHANGE) {
bestValue = findApproximateBestSubset(applicableGroups, totalLower, targetValue + MIN_CHANGE, bestSelection);
}
// If we have a bigger coin and (either the stochastic approximation didn't find a good solution,
// or the next bigger coin is closer), return the bigger coin
if(lowestLarger != null && ((bestValue != targetValue && bestValue < targetValue + MIN_CHANGE) || lowestLarger.getEffectiveValue() <= bestValue)) {
return lowestLarger.getUtxos();
} else {
List<BlockTransactionHashIndex> utxos = new ArrayList<>();
for(int i = 0; i < applicableGroups.size(); i++) {
if(bestSelection[i]) {
utxos.addAll(applicableGroups.get(i).getUtxos());
}
}
return utxos;
}
}
private long findApproximateBestSubset(List<OutputGroup> groups, long totalLower, long targetValue, boolean[] bestSelection) {
int iterations = 1000;
boolean[] includedSelection;
Arrays.fill(bestSelection, true);
long bestValue = totalLower;
Random random = new Random();
for(int rep = 0; rep < iterations && bestValue != targetValue; rep++) {
includedSelection = new boolean[groups.size()];
Arrays.fill(includedSelection, false);
long total = 0;
boolean reachedTarget = false;
for(int pass = 0; pass < 2 && !reachedTarget; pass++) {
for(int i = 0; i < groups.size(); i++) {
//The solver here uses a randomized algorithm,
//the randomness serves no real security purpose but is just
//needed to prevent degenerate behavior and it is important
//that the rng is fast. We do not use a constant random sequence,
//because there may be some privacy improvement by making
//the selection random.
if(pass == 0 ? random.nextBoolean() : !includedSelection[i]) {
total += groups.get(i).getEffectiveValue();
includedSelection[i] = true;
if(total >= targetValue) {
reachedTarget = true;
if(total < bestValue) {
bestValue = total;
System.arraycopy(includedSelection, 0, bestSelection, 0, groups.size());
}
total -= groups.get(i).getEffectiveValue();
includedSelection[i] = false;
}
}
}
}
}
return bestValue;
}
}

View file

@ -0,0 +1,11 @@
package com.sparrowwallet.drongo.wallet;
import java.util.Collection;
import java.util.stream.Collectors;
public class MaxUtxoSelector implements UtxoSelector {
@Override
public Collection<BlockTransactionHashIndex> select(long targetValue, Collection<OutputGroup> candidates) {
return candidates.stream().flatMap(outputGroup -> outputGroup.getUtxos().stream()).collect(Collectors.toUnmodifiableList());
}
}

View file

@ -0,0 +1,67 @@
package com.sparrowwallet.drongo.wallet;
import java.util.ArrayList;
import java.util.List;
import static com.sparrowwallet.drongo.protocol.Transaction.WITNESS_SCALE_FACTOR;
public class OutputGroup {
private final List<BlockTransactionHashIndex> utxos = new ArrayList<>();
private final long inputWeightUnits;
private final double feeRate;
private final double longTermFeeRate;
private long value = 0;
private long effectiveValue = 0;
private long fee = 0;
private long longTermFee = 0;
public OutputGroup(long inputWeightUnits, double feeRate, double longTermFeeRate) {
this.inputWeightUnits = inputWeightUnits;
this.feeRate = feeRate;
this.longTermFeeRate = longTermFeeRate;
}
public OutputGroup(long inputWeightUnits, double feeRate, double longTermFeeRate, BlockTransactionHashIndex utxo) {
this.inputWeightUnits = inputWeightUnits;
this.feeRate = feeRate;
this.longTermFeeRate = longTermFeeRate;
add(utxo);
}
public void add(BlockTransactionHashIndex utxo) {
utxos.add(utxo);
value += utxo.getValue();
effectiveValue += utxo.getValue() - (long)(inputWeightUnits * feeRate / WITNESS_SCALE_FACTOR);
fee += (long)(inputWeightUnits * feeRate / WITNESS_SCALE_FACTOR);
longTermFee += (long)(inputWeightUnits * longTermFeeRate / WITNESS_SCALE_FACTOR);
}
public void remove(BlockTransactionHashIndex utxo) {
if(utxos.remove(utxo)) {
value -= utxo.getValue();
effectiveValue -= (utxo.getValue() - (long)(inputWeightUnits * feeRate / WITNESS_SCALE_FACTOR));
fee -= (long)(inputWeightUnits * feeRate / WITNESS_SCALE_FACTOR);
longTermFee -= (long)(inputWeightUnits * longTermFeeRate / WITNESS_SCALE_FACTOR);
}
}
public List<BlockTransactionHashIndex> getUtxos() {
return utxos;
}
public long getValue() {
return value;
}
public long getEffectiveValue() {
return effectiveValue;
}
public long getFee() {
return fee;
}
public long getLongTermFee() {
return longTermFee;
}
}

View file

@ -3,6 +3,7 @@ package com.sparrowwallet.drongo.wallet;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.stream.Collectors;
public class PresetUtxoSelector implements UtxoSelector {
private final Collection<BlockTransactionHashIndex> presetUtxos;
@ -12,9 +13,9 @@ public class PresetUtxoSelector implements UtxoSelector {
}
@Override
public Collection<BlockTransactionHashIndex> select(long targetValue, Collection<BlockTransactionHashIndex> candidates) {
public Collection<BlockTransactionHashIndex> select(long targetValue, Collection<OutputGroup> candidates) {
List<BlockTransactionHashIndex> utxos = new ArrayList<>(presetUtxos);
utxos.retainAll(candidates);
utxos.retainAll(candidates.stream().flatMap(outputGroup -> outputGroup.getUtxos().stream()).collect(Collectors.toList()));
return utxos;
}

View file

@ -12,10 +12,10 @@ public class PriorityUtxoSelector implements UtxoSelector {
}
@Override
public Collection<BlockTransactionHashIndex> select(long targetValue, Collection<BlockTransactionHashIndex> candidates) {
public Collection<BlockTransactionHashIndex> select(long targetValue, Collection<OutputGroup> candidates) {
List<BlockTransactionHashIndex> selected = new ArrayList<>();
List<BlockTransactionHashIndex> sorted = candidates.stream().filter(ref -> ref.getHeight() != 0).collect(Collectors.toList());
List<BlockTransactionHashIndex> sorted = candidates.stream().flatMap(outputGroup -> outputGroup.getUtxos().stream()).filter(ref -> ref.getHeight() != 0).collect(Collectors.toList());
sort(sorted);
//Testing only: remove

View file

@ -3,5 +3,5 @@ package com.sparrowwallet.drongo.wallet;
import java.util.Collection;
public interface UtxoSelector {
Collection<BlockTransactionHashIndex> select(long targetValue, Collection<BlockTransactionHashIndex> candidates);
Collection<BlockTransactionHashIndex> select(long targetValue, Collection<OutputGroup> candidates);
}

View file

@ -281,6 +281,21 @@ public class Wallet {
return (long)(feeRate * outputVbytes + longTermFeeRate * inputVbytes);
}
/**
* Determines the weight units for a transaction from this wallet that has one output and no inputs
*
* @param recipientAddress The address to create the output to send to
* @return The determined weight units
*/
public int getNoInputsWeightUnits(Address recipientAddress) {
Transaction transaction = new Transaction();
if(Arrays.asList(ScriptType.WITNESS_TYPES).contains(getScriptType())) {
transaction.setSegwitVersion(0);
}
transaction.addOutput(1L, recipientAddress);
return transaction.getWeightUnits();
}
/**
* Return the number of vBytes required for an input created by this wallet.
*
@ -326,11 +341,17 @@ public class Wallet {
return wu;
}
public long getCostOfChange(double feeRate, double longTermFeeRate) {
WalletNode changeNode = getFreshNode(KeyPurpose.CHANGE);
TransactionOutput changeOutput = new TransactionOutput(new Transaction(), 1L, getOutputScript(changeNode));
return getFee(changeOutput, feeRate, longTermFeeRate);
}
public WalletTransaction createWalletTransaction(List<UtxoSelector> utxoSelectors, Address recipientAddress, long recipientAmount, double feeRate, double longTermFeeRate, Long fee, boolean sendAll) throws InsufficientFundsException {
long valueRequiredAmt = recipientAmount;
while(true) {
Map<BlockTransactionHashIndex, WalletNode> selectedUtxos = selectInputs(utxoSelectors, valueRequiredAmt);
Map<BlockTransactionHashIndex, WalletNode> selectedUtxos = selectInputs(utxoSelectors, valueRequiredAmt, feeRate, longTermFeeRate);
long totalSelectedAmt = selectedUtxos.keySet().stream().mapToLong(BlockTransactionHashIndex::getValue).sum();
//Add inputs
@ -377,11 +398,11 @@ public class Wallet {
//Determine if a change output is required by checking if its value is greater than its dust threshold
long changeAmt = differenceAmt - noChangeFeeRequiredAmt;
WalletNode changeNode = getFreshNode(KeyPurpose.CHANGE);
TransactionOutput changeOutput = new TransactionOutput(transaction, changeAmt, getOutputScript(changeNode));
long costOfChangeAmt = getFee(changeOutput, feeRate, longTermFeeRate);
long costOfChangeAmt = getCostOfChange(feeRate, longTermFeeRate);
if(changeAmt > costOfChangeAmt) {
//Change output is required, determine new fee once change output has been added
WalletNode changeNode = getFreshNode(KeyPurpose.CHANGE);
TransactionOutput changeOutput = new TransactionOutput(transaction, changeAmt, getOutputScript(changeNode));
int changeVSize = noChangeVSize + changeOutput.getLength();
long changeFeeRequiredAmt = (fee == null ? (long)(feeRate * changeVSize) : fee);
@ -403,11 +424,12 @@ public class Wallet {
}
}
private Map<BlockTransactionHashIndex, WalletNode> selectInputs(List<UtxoSelector> utxoSelectors, Long targetValue) throws InsufficientFundsException {
private Map<BlockTransactionHashIndex, WalletNode> selectInputs(List<UtxoSelector> utxoSelectors, Long targetValue, double feeRate, double longTermFeeRate) throws InsufficientFundsException {
Map<BlockTransactionHashIndex, WalletNode> utxos = getWalletUtxos();
for(UtxoSelector utxoSelector : utxoSelectors) {
Collection<BlockTransactionHashIndex> selectedInputs = utxoSelector.select(targetValue, utxos.keySet());
List<OutputGroup> utxoPool = utxos.keySet().stream().map(utxo -> new OutputGroup(getInputWeightUnits(), feeRate, longTermFeeRate, utxo)).collect(Collectors.toList());
Collection<BlockTransactionHashIndex> selectedInputs = utxoSelector.select(targetValue, utxoPool);
long total = selectedInputs.stream().mapToLong(BlockTransactionHashIndex::getValue).sum();
if(total > targetValue) {
utxos.keySet().retainAll(selectedInputs);