Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

setup(
name="story_protocol_python_sdk",
version="0.3.17",
version="0.3.18",
packages=find_packages(where="src", exclude=["tests"]),
package_dir={"": "src"},
install_requires=["web3>=7.0.0", "pytest", "python-dotenv", "base58"],
Expand Down
2 changes: 1 addition & 1 deletion src/story_protocol_python_sdk/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
__version__ = "0.3.17"
__version__ = "0.3.18"

from .resources.Dispute import Dispute
from .resources.IPAccount import IPAccount
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def removeIp(self, groupIpId, ipIds):
return self.contract.functions.removeIp(groupIpId, ipIds).transact()

def build_removeIp_transaction(self, groupIpId, ipIds, tx_params):
return self.contract.functions.removeIp(
groupIpId, ipIds
).build_transaction(tx_params)
return self.contract.functions.removeIp(groupIpId, ipIds).build_transaction(
tx_params
)

def claimReward(self, groupId, token, ipIds):
return self.contract.functions.claimReward(groupId, token, ipIds).transact()
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,7 @@ def __init__(self, web3: Web3):
self.contract = self.web3.eth.contract(address=contract_address, abi=abi)

def getTotalTokensByLicensor(self, licensorIpId):
return self.contract.functions.getTotalTokensByLicensor(
licensorIpId
).call()
return self.contract.functions.getTotalTokensByLicensor(licensorIpId).call()

def ownerOf(self, tokenId):
return self.contract.functions.ownerOf(tokenId).call()
Original file line number Diff line number Diff line change
Expand Up @@ -55,3 +55,9 @@ def build_claimAllRevenue_transaction(
return self.contract.functions.claimAllRevenue(
ancestorIpId, claimer, childIpIds, royaltyPolicies, currencyTokens
).build_transaction(tx_params)

def multicall(self, data):
return self.contract.functions.multicall(data).transact()

def build_multicall_transaction(self, data, tx_params):
return self.contract.functions.multicall(data).build_transaction(tx_params)
6 changes: 4 additions & 2 deletions src/story_protocol_python_sdk/resources/IPAsset.py
Original file line number Diff line number Diff line change
Expand Up @@ -2144,7 +2144,7 @@ def is_registered(self, ip_id: str) -> bool:
"""
if not ip_id:
raise ValueError("is_registered: ip_id is required")

if not self.web3.is_address(ip_id):
raise ValueError(f"is_registered: invalid IP ID address format: {ip_id}")

Expand All @@ -2170,7 +2170,9 @@ def _parse_tx_ip_registered_event(self, tx_receipt: dict) -> list[RegisteredIP]:
)
registered_ips.append(
RegisteredIP(
ip_id=self.web3.to_checksum_address(event_result["args"]["ipId"]),
ip_id=self.web3.to_checksum_address(
event_result["args"]["ipId"]
),
token_id=event_result["args"]["tokenId"],
)
)
Expand Down
167 changes: 167 additions & 0 deletions src/story_protocol_python_sdk/resources/Royalty.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
IpRoyaltyVaultImplClient,
)
from story_protocol_python_sdk.abi.MockERC20.MockERC20_client import MockERC20Client
from story_protocol_python_sdk.abi.Multicall3.Multicall3_client import Multicall3Client
from story_protocol_python_sdk.abi.RoyaltyModule.RoyaltyModule_client import (
RoyaltyModuleClient,
)
Expand Down Expand Up @@ -56,6 +57,7 @@ def __init__(self, web3: Web3, account, chain_id: int):
self.mock_erc20_client = MockERC20Client(web3)
self.royalty_policy_lrp_client = RoyaltyPolicyLRPClient(web3)
self.wrapped_ip_client = WrappedIPClient(web3)
self.multicall3_client = Multicall3Client(web3)

def get_royalty_vault_address(self, ip_id: str) -> str:
"""
Expand Down Expand Up @@ -222,6 +224,171 @@ def claim_all_revenue(
except Exception as e:
raise ValueError(f"Failed to claim all revenue: {str(e)}")

def batch_claim_all_revenue(
self,
ancestor_ips: list[dict],
claim_options: dict | None = None,
options: dict | None = None,
tx_options: dict | None = None,
) -> dict:
"""
Batch claims all revenue from the child IPs of multiple ancestor IPs.
If multicall is disabled, it will call claim_all_revenue for each ancestor IP.
Then transfer all claimed tokens to the wallet if the wallet owns the IP or is the claimer.
If claimed token is WIP, it will also be converted back to native tokens.

Even if there are no child IPs, you must still populate `currency_tokens` in each ancestor IP
with the token addresses you wish to claim. This is required for the claim operation to know which
token balances to process.

:param ancestor_ips list[dict]: List of ancestor IP configurations, each containing:
:param ip_id str: The IP ID of the ancestor.
:param claimer str: The address of the claimer.
:param child_ip_ids list: List of child IP IDs.
:param royalty_policies list: List of royalty policy addresses.
:param currency_tokens list: List of currency token addresses.
:param claim_options dict: [Optional] Options for auto_transfer_all_claimed_tokens_from_ip and auto_unwrap_ip_tokens. Default values are True.
:param options dict: [Optional] Options for use_multicall_when_possible. Default is True.
:param tx_options dict: [Optional] Transaction options.
:return dict: Dictionary with transaction hashes, receipts, and claimed tokens.
:return tx_hashes list[str]: List of transaction hashes.
:return receipts list[dict]: List of transaction receipts.
:return claimed_tokens list[dict]: Aggregated list of claimed tokens.
"""
try:
tx_hashes = []
receipts = []
claimed_tokens = []

use_multicall = (
options.get("use_multicall_when_possible", True) if options else True
)

# If only 1 ancestor IP or multicall is disabled, call claim_all_revenue for each
if len(ancestor_ips) == 1 or not use_multicall:
for ancestor_ip in ancestor_ips:
result = self.claim_all_revenue(
ancestor_ip_id=ancestor_ip["ip_id"],
claimer=ancestor_ip["claimer"],
child_ip_ids=ancestor_ip["child_ip_ids"],
royalty_policies=ancestor_ip["royalty_policies"],
currency_tokens=ancestor_ip["currency_tokens"],
claim_options={
"auto_transfer_all_claimed_tokens_from_ip": False,
"auto_unwrap_ip_tokens": False,
},
tx_options=tx_options,
)
tx_hashes.extend(result["tx_hashes"])
receipts.append(result["receipt"])
if result.get("claimed_tokens"):
claimed_tokens.extend(result["claimed_tokens"])
else:
# Batch claimAllRevenue calls into a single multicall
encoded_txs = []
for ancestor_ip in ancestor_ips:
encoded_data = self.royalty_workflows_client.contract.functions.claimAllRevenue(
validate_address(ancestor_ip["ip_id"]),
validate_address(ancestor_ip["claimer"]),
validate_addresses(ancestor_ip["child_ip_ids"]),
validate_addresses(ancestor_ip["royalty_policies"]),
validate_addresses(ancestor_ip["currency_tokens"]),
)._encode_transaction_data()
encoded_txs.append(encoded_data)

response = build_and_send_transaction(
self.web3,
self.account,
self.royalty_workflows_client.build_multicall_transaction,
encoded_txs,
tx_options=tx_options,
)
tx_hashes.append(response["tx_hash"])
receipts.append(response["tx_receipt"])

# Parse claimed tokens from the receipt
claimed_token_logs = self._parse_tx_revenue_token_claimed_event(
response["tx_receipt"]
)
claimed_tokens.extend(claimed_token_logs)

# Aggregate claimed tokens by claimer and token address
aggregated_claimed_tokens = {}
for token in claimed_tokens:
key = f"{token['claimer']}_{token['token']}"
if key not in aggregated_claimed_tokens:
aggregated_claimed_tokens[key] = dict(token)
else:
aggregated_claimed_tokens[key]["amount"] += token["amount"]

aggregated_claimed_tokens = list(aggregated_claimed_tokens.values())

# Get unique claimers
claimers = list(set(ancestor_ip["claimer"] for ancestor_ip in ancestor_ips))

auto_transfer = (
claim_options.get("auto_transfer_all_claimed_tokens_from_ip", True)
if claim_options
else True
)
auto_unwrap = (
claim_options.get("auto_unwrap_ip_tokens", True)
if claim_options
else True
)

wip_claimable_amounts = 0

for claimer in claimers:
owns_claimer, is_claimer_ip, ip_account = self._get_claimer_info(
claimer
)

# If ownsClaimer is false, skip
if not owns_claimer:
continue

filter_claimed_tokens = [
token
for token in aggregated_claimed_tokens
if token["claimer"] == claimer
]

# Transfer claimed tokens from IP to wallet if wallet owns IP
if auto_transfer and is_claimer_ip and owns_claimer:
hashes = self._transfer_claimed_tokens_from_ip_to_wallet(
ip_account, filter_claimed_tokens
)
tx_hashes.extend(hashes)

# Sum up the amount of WIP tokens claimed
for token in filter_claimed_tokens:
if token["token"] == WIP_TOKEN_ADDRESS:
wip_claimable_amounts += token["amount"]

# Unwrap WIP tokens if needed
if wip_claimable_amounts > 0 and auto_unwrap:
hashes = self._unwrap_claimed_tokens_from_ip_to_wallet(
[
{
"token": WIP_TOKEN_ADDRESS,
"amount": wip_claimable_amounts,
"claimer": self.account.address,
}
]
)
tx_hashes.extend(hashes)

return {
"receipts": receipts,
"claimed_tokens": aggregated_claimed_tokens,
"tx_hashes": tx_hashes,
}

except Exception as e:
error_msg = str(e).replace("Failed to claim all revenue: ", "").strip()
raise ValueError(f"Failed to batch claim all revenue: {error_msg}")

def transfer_to_vault(
self,
ip_id: str,
Expand Down
13 changes: 6 additions & 7 deletions src/story_protocol_python_sdk/utils/transaction_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,9 +44,7 @@ def _get_transaction_options(
# Gas: bump for replacement, or use tx_options
if bump_gas:
try:
opts["gasPrice"] = int(
web3.eth.gas_price * REPLACEMENT_GAS_BUMP_RATIO
)
opts["gasPrice"] = int(web3.eth.gas_price * REPLACEMENT_GAS_BUMP_RATIO)
except Exception:
opts["gasPrice"] = web3.to_wei(2, "gwei")
else:
Expand All @@ -55,16 +53,17 @@ def _get_transaction_options(
if "maxFeePerGas" in tx_options:
opts["maxFeePerGas"] = tx_options["maxFeePerGas"]

# Gas limit: use explicit gas if provided to avoid estimation
if "gas" in tx_options:
opts["gas"] = tx_options["gas"]

return opts


def _is_retryable_send_error(exc: Exception) -> bool:
"""True if we should retry send (same nonce, higher gas)."""
msg = str(exc).lower()
return (
"replacement transaction underpriced" in msg
or "nonce too low" in msg
)
return "replacement transaction underpriced" in msg or "nonce too low" in msg


def _send_one(
Expand Down
54 changes: 28 additions & 26 deletions tests/integration/test_integration_group.py
Original file line number Diff line number Diff line change
Expand Up @@ -376,9 +376,7 @@ def _normalize_address(web3, addr: str) -> str:
class TestAddIpsToGroupAndRemoveIpsFromGroup:
"""Integration tests for add_ips_to_group and remove_ips_from_group with strict on-chain verification."""

def test_add_ips_to_group(
self, story_client: StoryClient, nft_collection: Address
):
def test_add_ips_to_group(self, story_client: StoryClient, nft_collection: Address):
"""Test adding IPs to an existing group; verify chain state via AddedIpToGroup event and get_claimable_reward."""
result1 = GroupTestHelper.mint_and_register_ip_asset_with_pil_terms(
story_client, nft_collection
Expand All @@ -403,17 +401,19 @@ def test_add_ips_to_group(
assert isinstance(result["tx_hash"], str)
assert len(result["tx_hash"]) > 0
# Strict: verify on-chain AddedIpToGroup event
assert "tx_receipt" in result, "add_ips_to_group must return tx_receipt for verification"
assert (
"tx_receipt" in result
), "add_ips_to_group must return tx_receipt for verification"
added_events = story_client.Group.get_added_ip_to_group_events(
result["tx_receipt"]
)
assert len(added_events) == 1
assert _normalize_address(story_client.web3, added_events[0]["groupId"]) == _normalize_address(
story_client.web3, group_ip_id
)
assert set(_normalize_address(story_client.web3, a) for a in added_events[0]["ipIds"]) == {
_normalize_address(story_client.web3, ip_id2)
}
assert _normalize_address(
story_client.web3, added_events[0]["groupId"]
) == _normalize_address(story_client.web3, group_ip_id)
assert set(
_normalize_address(story_client.web3, a) for a in added_events[0]["ipIds"]
) == {_normalize_address(story_client.web3, ip_id2)}
# Verify new member is in group: get_claimable_reward for [ip_id1, ip_id2] should succeed
claimable = story_client.Group.get_claimable_reward(
group_ip_id=group_ip_id,
Expand Down Expand Up @@ -454,12 +454,12 @@ def test_add_ips_to_group_with_max_reward_share(
result["tx_receipt"]
)
assert len(added_events) == 1
assert _normalize_address(story_client.web3, added_events[0]["groupId"]) == _normalize_address(
story_client.web3, group_ip_id
)
assert set(_normalize_address(story_client.web3, a) for a in added_events[0]["ipIds"]) == {
_normalize_address(story_client.web3, ip_id2)
}
assert _normalize_address(
story_client.web3, added_events[0]["groupId"]
) == _normalize_address(story_client.web3, group_ip_id)
assert set(
_normalize_address(story_client.web3, a) for a in added_events[0]["ipIds"]
) == {_normalize_address(story_client.web3, ip_id2)}

def test_remove_ips_from_group(
self, story_client: StoryClient, nft_collection: Address
Expand Down Expand Up @@ -492,12 +492,12 @@ def test_remove_ips_from_group(
result["tx_receipt"]
)
assert len(removed_events) == 1
assert _normalize_address(story_client.web3, removed_events[0]["groupId"]) == _normalize_address(
story_client.web3, group_ip_id
)
assert set(_normalize_address(story_client.web3, a) for a in removed_events[0]["ipIds"]) == {
_normalize_address(story_client.web3, ip_id2)
}
assert _normalize_address(
story_client.web3, removed_events[0]["groupId"]
) == _normalize_address(story_client.web3, group_ip_id)
assert set(
_normalize_address(story_client.web3, a) for a in removed_events[0]["ipIds"]
) == {_normalize_address(story_client.web3, ip_id2)}
# After remove, only ip_id1 remains; get_claimable_reward for [ip_id1] must succeed
claimable = story_client.Group.get_claimable_reward(
group_ip_id=group_ip_id,
Expand Down Expand Up @@ -540,7 +540,9 @@ def test_add_then_remove_ips_from_group(
add_result["tx_receipt"]
)
assert len(added_events) == 1
assert set(_normalize_address(story_client.web3, a) for a in added_events[0]["ipIds"]) == {
assert set(
_normalize_address(story_client.web3, a) for a in added_events[0]["ipIds"]
) == {
_normalize_address(story_client.web3, ip_id2),
_normalize_address(story_client.web3, ip_id3),
}
Expand All @@ -556,9 +558,9 @@ def test_add_then_remove_ips_from_group(
remove_result["tx_receipt"]
)
assert len(removed_events) == 1
assert set(_normalize_address(story_client.web3, a) for a in removed_events[0]["ipIds"]) == {
_normalize_address(story_client.web3, ip_id2)
}
assert set(
_normalize_address(story_client.web3, a) for a in removed_events[0]["ipIds"]
) == {_normalize_address(story_client.web3, ip_id2)}

# Final state: only ip_id1 and ip_id3 are members
claimable = story_client.Group.get_claimable_reward(
Expand Down
Loading
Loading