diff --git a/src/main/java/com/sparrowwallet/drongo/protocol/ScriptType.java b/src/main/java/com/sparrowwallet/drongo/protocol/ScriptType.java index 1d45f46..7512517 100644 --- a/src/main/java/com/sparrowwallet/drongo/protocol/ScriptType.java +++ b/src/main/java/com/sparrowwallet/drongo/protocol/ScriptType.java @@ -8,8 +8,6 @@ import com.sparrowwallet.drongo.crypto.ChildNumber; import com.sparrowwallet.drongo.crypto.ECKey; import com.sparrowwallet.drongo.policy.PolicyType; -import java.time.LocalDate; -import java.time.Month; import java.util.*; import java.util.stream.Collectors; @@ -1174,10 +1172,31 @@ public enum ScriptType { return Collections.unmodifiableList(copy); } + public static boolean derivationMatchesAnotherNetwork(String derivationPath) { + if(KeyDerivation.isValid(derivationPath)) { + List derivation = new ArrayList<>(KeyDerivation.parsePath(derivationPath)); + if(derivation.size() > 1) { + int networkIndex = derivation.get(1).num(); + return Network.get() == Network.MAINNET ? (networkIndex == 1) : (networkIndex == 0); + } + } + + return false; + } + public int getAccount(String derivationPath) { + return getAccount(derivationPath, false); + } + + public int getAccount(String derivationPath, boolean ignoreNetwork) { if(KeyDerivation.isValid(derivationPath)) { List derivation = new ArrayList<>(KeyDerivation.parsePath(derivationPath)); if(derivation.size() > 2) { + if(ignoreNetwork) { + ChildNumber networkChildNumber = new ChildNumber(Network.get() == Network.MAINNET ? 0 : 1, true); + derivation.set(1, networkChildNumber); + } + int account = derivation.get(2).num(); List defaultDerivation = getDefaultDerivation(account); if(defaultDerivation.equals(derivation)) { diff --git a/src/main/java/com/sparrowwallet/drongo/wallet/Wallet.java b/src/main/java/com/sparrowwallet/drongo/wallet/Wallet.java index 68e7968..bfcc768 100644 --- a/src/main/java/com/sparrowwallet/drongo/wallet/Wallet.java +++ b/src/main/java/com/sparrowwallet/drongo/wallet/Wallet.java @@ -26,6 +26,7 @@ public class Wallet extends Persistable implements Comparable { public static final int DEFAULT_LOOKAHEAD = 20; public static final int SEARCH_LOOKAHEAD = 4000; public static final String ALLOW_DERIVATIONS_MATCHING_OTHER_SCRIPT_TYPES_PROPERTY = "com.sparrowwallet.allowDerivationsMatchingOtherScriptTypes"; + public static final String ALLOW_DERIVATIONS_MATCHING_OTHER_NETWORKS_PROPERTY = "com.sparrowwallet.allowDerivationsMatchingOtherNetworks"; private String name; private String label; @@ -1773,6 +1774,10 @@ public class Wallet extends Persistable implements Comparable { if(derivationMatchesAnotherScriptType(keystore.getKeyDerivation().getDerivationPath())) { throw new InvalidWalletException("Keystore " + keystore.getLabel() + " derivation of " + keystore.getKeyDerivation().getDerivationPath() + " in " + scriptType.getName() + " wallet matches another default script type."); } + + if(derivationMatchesAnotherNetwork(keystore.getKeyDerivation().getDerivationPath())) { + throw new InvalidWalletException("Keystore " + keystore.getLabel() + " derivation of " + keystore.getKeyDerivation().getDerivationPath() + " in " + scriptType.getName() + " wallet matches another network."); + } } if(containsDuplicateExtendedKeys()) { @@ -1789,7 +1794,19 @@ public class Wallet extends Persistable implements Comparable { return false; } - return Arrays.stream(ScriptType.values()).anyMatch(scriptType -> !scriptType.equals(this.scriptType) && scriptType.getAccount(derivationPath) > -1); + return Arrays.stream(ScriptType.values()).anyMatch(scriptType -> !scriptType.equals(this.scriptType) && scriptType.getAccount(derivationPath, true) > -1); + } + + public boolean derivationMatchesAnotherNetwork(String derivationPath) { + if(Boolean.TRUE.toString().equals(System.getProperty(ALLOW_DERIVATIONS_MATCHING_OTHER_NETWORKS_PROPERTY))) { + return false; + } + + if(scriptType != null && scriptType.getAccount(derivationPath, true) > -1) { + return ScriptType.derivationMatchesAnotherNetwork(derivationPath); + } + + return false; } public boolean containsDuplicateKeystoreLabels() {