Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
74 changes: 51 additions & 23 deletions src/embit/bech32.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,10 @@ class Encoding:
BECH32M = 2


class Bech32DecodeError(Exception):
pass


def bech32_polymod(values):
"""Internal function that computes the Bech32 checksum."""
generator = [0x3B6A57B2, 0x26508E6D, 0x1EA119FA, 0x3D4233DD, 0x2A1462B3]
Expand Down Expand Up @@ -77,21 +81,28 @@ def bech32_encode(encoding, hrp, data):

def bech32_decode(bech):
"""Validate a Bech32/Bech32m string, and determine HRP and data."""
if (any(ord(x) < 33 or ord(x) > 126 for x in bech)) or (
bech.lower() != bech and bech.upper() != bech
):
return (None, None, None)
if any(ord(x) < 33 or ord(x) > 126 for x in bech):
raise Bech32DecodeError("Invalid character in input")
if bech.lower() != bech and bech.upper() != bech:
raise Bech32DecodeError("Mixed case strings not allowed")
bech = bech.lower()
pos = bech.rfind("1")
if pos < 1 or pos + 7 > len(bech) or len(bech) > 90:
return (None, None, None)
if not all(x in CHARSET for x in bech[pos + 1 :]):
return (None, None, None)
if pos < 1:
raise Bech32DecodeError("Separator '1' not found or misplaced")
if pos > 83:
raise Bech32DecodeError("HRP too long (max 83 characters)")
if pos + 7 > len(bech):
raise Bech32DecodeError("Data part too short")
if len(bech) > 118:
raise Bech32DecodeError("String too long for SP address")
hrp = bech[:pos]
data = [CHARSET.find(x) for x in bech[pos + 1 :]]
data_part = bech[pos + 1 :]
if not all(x in CHARSET for x in data_part):
raise Bech32DecodeError("Data part contains invalid characters")
data = [CHARSET.find(x) for x in data_part]
encoding = bech32_verify_checksum(hrp, data)
if encoding is None:
return (None, None, None)
raise Bech32DecodeError("Checksum verification failed")
return (encoding, hrp, data[:-6])


Expand All @@ -104,7 +115,7 @@ def convertbits(data, frombits, tobits, pad=True):
max_acc = (1 << (frombits + tobits - 1)) - 1
for value in data:
if value < 0 or (value >> frombits):
return None
raise Bech32DecodeError("Invalid input value for bit conversion")
acc = ((acc << frombits) | value) & max_acc
bits += frombits
while bits >= tobits:
Expand All @@ -114,33 +125,50 @@ def convertbits(data, frombits, tobits, pad=True):
if bits:
ret.append((acc << (tobits - bits)) & maxv)
elif bits >= frombits or ((acc << (tobits - bits)) & maxv):
return None
raise Bech32DecodeError("Invalid padding in bit conversion")
return ret


def decode(hrp, addr):
"""Decode a segwit address."""
encoding, hrpgot, data = bech32_decode(addr)
if hrpgot != hrp:
return (None, None)
raise Bech32DecodeError(f"HRP mismatch: expected {hrp}, got {hrpgot}")
decoded = convertbits(data[1:], 5, 8, False)
if decoded is None or len(decoded) < 2 or len(decoded) > 40:
return (None, None)
if len(decoded) < 2 or len(decoded) > 66:
raise Bech32DecodeError(f"Invalid witness program length")
if data[0] > 16:
return (None, None)
if data[0] == 0 and len(decoded) != 20 and len(decoded) != 32:
return (None, None)
if (data[0] == 0 and encoding != Encoding.BECH32) or (
data[0] != 0 and encoding != Encoding.BECH32M
raise Bech32DecodeError("Invalid witness version")
if (
hrp not in ["sp", "tsp"]
and data[0] == 0
and len(decoded) != 20
and len(decoded) != 32
):
raise Bech32DecodeError("Invalid witness program length for version 0")
if hrp not in ["sp", "tsp"] and (
(data[0] == 0 and encoding != Encoding.BECH32)
or (data[0] != 0 and encoding != Encoding.BECH32M)
):
return (None, None)
raise Bech32DecodeError("Invalid encoding for witness version")
return (data[0], decoded)


def encode(hrp, witver, witprog):
"""Encode a segwit address."""
if witver < 0 or witver > 16:
raise Bech32DecodeError("Invalid witness version")
if len(witprog) < 2 or len(witprog) > 40:
raise Bech32DecodeError("Invalid witness program length")
if witver == 0 and len(witprog) != 20 and len(witprog) != 32:
raise Bech32DecodeError("Invalid witness program length for version 0")

encoding = Encoding.BECH32 if witver == 0 else Encoding.BECH32M
ret = bech32_encode(encoding, hrp, [witver] + convertbits(witprog, 8, 5))
if decode(hrp, ret) == (None, None):
return None

try:
decode(hrp, ret)
except Bech32DecodeError:
raise Bech32DecodeError("Failed to encode valid segwit address")

return ret
4 changes: 2 additions & 2 deletions src/embit/bip39.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ def mnemonic_to_bytes(mnemonic: str, ignore_checksum: bool = False, wordlist=WOR
# this function is copied from Jimmy Song's HDPrivateKey.from_mnemonic() method

words = mnemonic.strip().split()
if len(words) % 3 != 0 or len(words) < 12:
if len(words) % 3 != 0 or not 12 <= len(words) <= 24:
raise ValueError("Invalid recovery phrase")

binary_seed = bytearray()
Expand Down Expand Up @@ -97,7 +97,7 @@ def _extract_index(bits, b, n):


def mnemonic_from_bytes(entropy, wordlist=WORDLIST):
if len(entropy) % 4 != 0:
if len(entropy) % 4 != 0 or not 16 <= len(entropy) <= 32:
raise ValueError("Byte array should be multiple of 4 long (16, 20, ..., 32)")
total_bits = len(entropy) * 8
checksum_bits = total_bits // 32
Expand Down
3 changes: 3 additions & 0 deletions src/embit/descriptor/descriptor.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ def script_len(self):

@property
def num_branches(self):
if self.miniscript is not None:
return max({k.num_branches for k in self.miniscript.keys})

return max([k.num_branches for k in self.keys])

def branch(self, branch_index=None):
Expand Down
50 changes: 43 additions & 7 deletions src/embit/descriptor/miniscript.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,11 +54,25 @@ def type(self):

@classmethod
def read_from(cls, s, taproot=False):
op, char = read_until(s, b"(")
def wrapped(m_script):
for w in reversed(wrappers):
if w not in WRAPPER_NAMES:
raise MiniscriptError("Unknown wrapper")
WrapperCls = WRAPPERS[WRAPPER_NAMES.index(w)]
m_script = WrapperCls(m_script, taproot=taproot)
return m_script

op, char = read_until(s, b"(,)")
if char in (b",", b")"):
s.seek(-1, 1)
op = op.decode()
wrappers = ""
if ":" in op:
wrappers, op = op.split(":")
# handle boolean literals: 0 or 1
if op in ("0", "1"):
miniscript = JustOne() if op == "1" else JustZero()
return wrapped(miniscript)
if char != b"(":
raise MiniscriptError("Missing operator")
if op not in OPERATOR_NAMES:
Expand All @@ -67,12 +81,7 @@ def read_from(cls, s, taproot=False):
MiniscriptCls = OPERATORS[OPERATOR_NAMES.index(op)]
args = MiniscriptCls.read_arguments(s, taproot=taproot)
miniscript = MiniscriptCls(*args, taproot=taproot)
for w in reversed(wrappers):
if w not in WRAPPER_NAMES:
raise MiniscriptError("Unknown wrapper")
WrapperCls = WRAPPERS[WRAPPER_NAMES.index(w)]
miniscript = WrapperCls(miniscript, taproot=taproot)
return miniscript
return wrapped(miniscript)

@classmethod
def read_arguments(cls, s, taproot=False):
Expand Down Expand Up @@ -119,6 +128,28 @@ def len_args(self):
########### Known fragments (miniscript operators) ##############


class JustZero(Miniscript):
TYPE = "B"
PROPS = "zud"

def inner_compile(self):
return Number(0).compile()

def __str__(self):
return "0"


class JustOne(Miniscript):
TYPE = "B"
PROPS = "zu"

def inner_compile(self):
return Number(1).compile()

def __str__(self):
return "1"


class OneArg(Miniscript):
NARGS = 1

Expand Down Expand Up @@ -870,6 +901,11 @@ def inner_compile(self):

def __len__(self):
return len(self.arg) + 1

def verify(self):
super().verify()
if self.arg.type != "V":
raise MiniscriptError("t: X must be of type V")

@property
def properties(self):
Expand Down
40 changes: 30 additions & 10 deletions src/embit/psbt.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(self, unknown: dict = {}, vin=None, compress=CompressMode.KEEP_ALL)
self.taproot_bip32_derivations = OrderedDict()
self.taproot_internal_key = None
self.taproot_merkle_root = None
self.taproot_key_sig = None
self.taproot_sigs = OrderedDict()
self.taproot_scripts = OrderedDict()

Expand Down Expand Up @@ -187,6 +188,7 @@ def update(self, other):
self.taproot_bip32_derivations.update(other.taproot_bip32_derivations)
self.taproot_internal_key = other.taproot_internal_key
self.taproot_merkle_root = other.taproot_merkle_root or self.taproot_merkle_root
self.taproot_key_sig = other.taproot_key_sig or self.taproot_key_sig
self.taproot_sigs.update(other.taproot_sigs)
self.taproot_scripts.update(other.taproot_scripts)
self.final_scriptsig = other.final_scriptsig or self.final_scriptsig
Expand Down Expand Up @@ -350,7 +352,15 @@ def read_value(self, stream, k):
elif k == b"\x10":
self.sequence = int.from_bytes(v, "little")

# TODO: 0x13 - tap key signature
# PSBT_IN_TAP_KEY_SIG
elif k[0] == 0x13:
# read the taproot key sig
if len(k) != 1:
raise PSBTError("Invalid taproot key signature key")
if self.taproot_key_sig is not None:
raise PSBTError("Duplicated taproot key signature")
self.taproot_key_sig = v

# PSBT_IN_TAP_SCRIPT_SIG
elif k[0] == 0x14:
if len(k) != 65:
Expand Down Expand Up @@ -434,6 +444,11 @@ def write_to(self, stream, skip_separator=False, version=None, **kwargs) -> int:
r += ser_string(stream, b"\x10")
r += ser_string(stream, self.sequence.to_bytes(4, "little"))

# PSBT_IN_TAP_KEY_SIG
if self.taproot_key_sig is not None:
r += ser_string(stream, b"\x13")
r += ser_string(stream, self.taproot_key_sig)

# PSBT_IN_TAP_SCRIPT_SIG
for pub, leaf in self.taproot_sigs:
r += ser_string(stream, b"\x14" + pub.xonly() + leaf)
Expand Down Expand Up @@ -881,11 +896,11 @@ def sign_input_with_tapkey(
sighash=sighash,
)
sig = pk.schnorr_sign(h)
wit = sig.serialize()
sigdata = sig.serialize()
if sighash != SIGHASH.DEFAULT:
wit += bytes([sighash])
# TODO: maybe better to put into internal key sig field
inp.final_scriptwitness = Witness([wit])
sigdata += bytes([sighash])
inp.taproot_key_sig = sigdata
inp.final_scriptwitness = Witness([sigdata])
# no need to sign anything else
return 1
counter = 0
Expand Down Expand Up @@ -977,22 +992,25 @@ def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int:
continue

# get all possible derivations with matching fingerprint
bip32_derivations = set()
bip32_derivations = OrderedDict() # OrderedDict to keep order
if fingerprint:
# if taproot derivations are present add them
for pub in inp.taproot_bip32_derivations:
(_leafs, derivation) = inp.taproot_bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.add((pub, derivation))
# Add only if not already present
if (pub, derivation) not in bip32_derivations:
bip32_derivations[(pub, derivation)] = True

# segwit and legacy derivations
for pub in inp.bip32_derivations:
derivation = inp.bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.add((pub, derivation))
if (pub, derivation) not in bip32_derivations:
bip32_derivations[(pub, derivation)] = True

# get derived keys for signing
derived_keypairs = set() # (prv, pub)
derived_keypairs = OrderedDict() # (prv, pub)
for pub, derivation in bip32_derivations:
der = derivation.derivation
# descriptor key has origin derivation that we take into account
Expand All @@ -1008,7 +1026,9 @@ def sign_with(self, root, sighash=SIGHASH.DEFAULT) -> int:

if hdkey.xonly() != pub.xonly():
raise PSBTError("Derivation path doesn't look right")
derived_keypairs.add((hdkey.key, pub))
# Insert into derived_keypairs if not present
if (hdkey.key, pub) not in derived_keypairs:
derived_keypairs[(hdkey.key, pub)] = True

# sign with taproot key
if inp.is_taproot:
Expand Down
16 changes: 11 additions & 5 deletions src/embit/psbtview.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
Makes sense to run gc.collect() after processing of each scope to free memory.
"""
# TODO: refactor, a lot of code is duplicated here from transaction.py
from collections import OrderedDict
import hashlib
from . import compact
from . import ec
Expand Down Expand Up @@ -742,22 +743,25 @@ def sign_input(
return 0

# get all possible derivations with matching fingerprint
bip32_derivations = set()
bip32_derivations = OrderedDict()
if fingerprint:
# if taproot derivations are present add them
for pub in inp.taproot_bip32_derivations:
(_leafs, derivation) = inp.taproot_bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.add((pub, derivation))
# Add only if not already present
if (pub, derivation) not in bip32_derivations:
bip32_derivations[(pub, derivation)] = True

# segwit and legacy derivations
for pub in inp.bip32_derivations:
derivation = inp.bip32_derivations[pub]
if derivation.fingerprint == fingerprint:
bip32_derivations.add((pub, derivation))
if (pub, derivation) not in bip32_derivations:
bip32_derivations[(pub, derivation)] = True

# get derived keys for signing
derived_keypairs = set() # (prv, pub)
derived_keypairs = OrderedDict() # (prv, pub)
for pub, derivation in bip32_derivations:
der = derivation.derivation
# descriptor key has origin derivation that we take into account
Expand All @@ -773,7 +777,9 @@ def sign_input(

if hdkey.xonly() != pub.xonly():
raise PSBTError("Derivation path doesn't look right")
derived_keypairs.add((hdkey.key, pub))
# Insert into derived_keypairs if not present
if (hdkey.key, pub) not in derived_keypairs:
derived_keypairs[(hdkey.key, pub)] = True

counter = 0
# sign with taproot key
Expand Down
Loading