diff --git a/CHANGELOG.md b/CHANGELOG.md index 4eb79df..d220672 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/) - [MINOR] Added `async_util.gather_batched` function to `async_util` - [MINOR] Added more exception classes - [MAJOR] Added `apiPathPattern` to `logging.api` +- [MAJOR] Added `chainId` to `EthClientInterface` +- [MINOR] Added `call` and `multicall` to `EthClientInterface` ### Changed diff --git a/core/util/chain_util.py b/core/util/chain_util.py index 3665cd5..404ff76 100644 --- a/core/util/chain_util.py +++ b/core/util/chain_util.py @@ -1,10 +1,12 @@ import typing import eth_utils +from eth_abi.exceptions import DecodingError from eth_typing import ABI from eth_typing import ABIFunction from eth_typing import HexStr from eth_utils import add_0x_prefix +from eth_utils.abi import get_abi_output_types from web3 import Web3 from web3._utils.contracts import encode_abi as web3_encode_abi from web3._utils.contracts import encode_transaction_data as web3_encode_transaction_data @@ -115,7 +117,7 @@ def encode_transaction_data_by_name( # type: ignore[explicit-any] ) -# NOTE(krishan711): not sure how difference this is from `encode_transaction_data` +# NOTE(krishan711): not sure how different this is from `encode_transaction_data` def encode_function_params(functionAbi: ABIFunction, arguments: list[typing.Any]) -> str: # type: ignore[explicit-any] return add_0x_prefix( web3_encode_abi( @@ -124,3 +126,17 @@ def encode_function_params(functionAbi: ABIFunction, arguments: list[typing.Any] arguments=arguments, ) ) + + +def decode_function_result(functionAbi: ABIFunction, resultData: bytes) -> list[typing.Any]: # type: ignore[explicit-any] + if resultData == b'' or resultData.hex() == '0x': + outputs = functionAbi.get('outputs', []) + if len(outputs) == 0: + return [] + raise BadRequestException(message=f'Empty response. Maybe the method does not exist on this contract.') + outputTypes = get_abi_output_types(abi_element=functionAbi) + try: + outputData = _w3.codec.decode(types=outputTypes, data=bytes.fromhex(resultData.hex())) + except DecodingError as exception: + raise BadRequestException(message=str(exception)) + return list(outputData) diff --git a/core/web3/eth_client.py b/core/web3/eth_client.py index 9d6fd75..d67398f 100644 --- a/core/web3/eth_client.py +++ b/core/web3/eth_client.py @@ -2,11 +2,8 @@ import typing from typing import Any -from eth_abi.exceptions import DecodingError -from eth_abi.exceptions import InsufficientDataBytes from eth_typing import ABI from eth_typing import ABIFunction -from eth_utils.abi import get_abi_output_types from pydantic import BaseModel from web3 import Web3 from web3._utils import method_formatters @@ -30,6 +27,8 @@ from core.util import chain_util from core.util import json_util from core.util.typing_util import JsonObject +from core.web3.multicall3 import CHAIN_ID_MULTICALL3_ADDRESS_MAP +from core.web3.multicall3 import MULTICALL3_ABI ListAny = list[Any] # type: ignore[explicit-any] DictStrAny = dict[str, Any] # type: ignore[explicit-any] @@ -41,6 +40,14 @@ class EncodedCall(BaseModel): data: str = '0x' +class ContractCall(BaseModel): + toAddress: str + functionName: str + contractAbi: ABI + arguments: DictStrAny | None = None + fromAddress: str | None = None + + class TransactionFailedException(KibaException): def __init__(self, transactionReceipt: TxReceipt) -> None: super().__init__(message='Transaction failed') @@ -53,8 +60,9 @@ def to_dict(self) -> JsonObject: class EthClientInterface: - def __init__(self, web3Connection: Web3, isTestnet: bool = False) -> None: + def __init__(self, web3Connection: Web3, chainId: int, isTestnet: bool = False) -> None: self.w3 = web3Connection + self.chainId = chainId self.isTestnet = isTestnet async def get_latest_block_number(self) -> int: @@ -75,7 +83,7 @@ async def get_transaction_receipt(self, transactionHash: str) -> TxReceipt: async def get_log_entries(self, topics: list[str] | None = None, startBlockNumber: int | None = None, endBlockNumber: int | None = None, address: str | None = None) -> list[LogReceipt]: raise NotImplementedError - async def call_function(self, toAddress: str, contractAbi: ABI, functionAbi: ABIFunction, fromAddress: str | None = None, arguments: DictStrAny | None = None, blockNumber: int | None = None) -> ListAny: + async def call(self, contractCall: ContractCall, blockNumber: int | None = None) -> ListAny: raise NotImplementedError async def fill_transaction_params( @@ -102,13 +110,34 @@ async def call_function_by_name( arguments: DictStrAny | None = None, blockNumber: int | None = None, ) -> ListAny: - functionAbi = chain_util.find_abi_by_name_args(contractAbi=contractAbi, functionName=functionName, arguments=arguments) - return await self.call_function( - toAddress=toAddress, - contractAbi=contractAbi, - functionAbi=functionAbi, - fromAddress=fromAddress, - arguments=arguments, + return await self.call( + contractCall=ContractCall( + toAddress=toAddress, + functionName=functionName, + contractAbi=contractAbi, + arguments=arguments, + fromAddress=fromAddress, + ), + blockNumber=blockNumber, + ) + + async def call_function( + self, + toAddress: str, + contractAbi: ABI, + functionAbi: ABIFunction, + fromAddress: str | None = None, + arguments: DictStrAny | None = None, + blockNumber: int | None = None, + ) -> ListAny: + return await self.call( + contractCall=ContractCall( + toAddress=toAddress, + functionName=functionAbi['name'], + contractAbi=contractAbi, + arguments=arguments, + fromAddress=fromAddress, + ), blockNumber=blockNumber, ) @@ -139,7 +168,7 @@ async def send_transaction( gas=gas, maxFeePerGas=maxFeePerGas, maxPriorityFeePerGas=maxPriorityFeePerGas, - chainId=chainId, + chainId=chainId or self.chainId, ) signedParams = self.w3.eth.account.sign_transaction(transaction_dict=params, private_key=privateKey) output = await self.send_raw_transaction(transactionData=signedParams.raw_transaction.hex()) @@ -171,7 +200,7 @@ async def send_transaction_by_name( maxFeePerGas=maxFeePerGas, maxPriorityFeePerGas=maxPriorityFeePerGas, arguments=arguments, - chainId=chainId, + chainId=chainId or self.chainId, ) async def wait_for_transaction_receipt(self, transactionHash: str, sleepSeconds: int = 2, maxWaitSeconds: int = 120, raiseOnFailure: bool = True) -> TxReceipt: @@ -191,8 +220,38 @@ async def wait_for_transaction_receipt(self, transactionHash: str, sleepSeconds: raise TransactionFailedException(transactionReceipt=transactionReceipt) return transactionReceipt + async def multicall(self, contractCalls: list[ContractCall], shouldUseMulticall3: bool = True) -> list[ListAny]: + multicall3Address = CHAIN_ID_MULTICALL3_ADDRESS_MAP.get(self.chainId) if shouldUseMulticall3 else None + if not multicall3Address or not shouldUseMulticall3: + results = await asyncio.gather(*[self.call(contractCall=contractCall) for contractCall in contractCalls]) + return results + multicallResponse = await self.call_function_by_name( + toAddress=multicall3Address, + contractAbi=MULTICALL3_ABI, + functionName='aggregate3', + arguments={ + 'calls': [ + {'target': contractCall.toAddress, 'allowFailure': False, 'callData': chain_util.encode_transaction_data_by_name(contractAbi=contractCall.contractAbi, functionName=contractCall.functionName, arguments=contractCall.arguments)} + for contractCall in contractCalls + ] + }, + ) + multicallResults = multicallResponse[0] + decodedResults: list[ListAny] = [] + for contractCall, (success, returnData) in zip(contractCalls, multicallResults, strict=False): + if not success: + decodedResults.append([None]) + continue + functionAbi = chain_util.find_abi_by_name_args(contractAbi=contractCall.contractAbi, functionName=contractCall.functionName, arguments=contractCall.arguments) + decodedValue = chain_util.decode_function_result(functionAbi=functionAbi, resultData=returnData) + decodedResults.append(decodedValue) + return decodedResults + class Web3EthClient(EthClientInterface): + def __init__(self, web3Connection: Web3, chainId: int, isTestnet: bool = False) -> None: + super().__init__(web3Connection=web3Connection, chainId=chainId, isTestnet=isTestnet) + async def get_latest_block_number(self) -> int: return self.w3.eth.block_number @@ -231,11 +290,10 @@ async def get_log_entries( class RestEthClient(EthClientInterface): # NOTE(krishan711): find docs at https://eth.wiki/json-rpc/API - def __init__(self, url: str, requester: Requester, isTestnet: bool = False, shouldBackoffRetryOnRateLimit: bool = True, retryLimit: int = 10) -> None: - super().__init__(web3Connection=Web3(), isTestnet=isTestnet) + def __init__(self, url: str, requester: Requester, chainId: int, isTestnet: bool = False, shouldBackoffRetryOnRateLimit: bool = True, retryLimit: int = 10) -> None: + super().__init__(web3Connection=Web3(), chainId=chainId, isTestnet=isTestnet) self.url = url self.requester = requester - self.isTestnet = isTestnet self.shouldBackoffRetryOnRateLimit = shouldBackoffRetryOnRateLimit self.retryLimit = retryLimit self.w3 = Web3() @@ -350,33 +408,20 @@ async def get_log_entries( response = await self._make_request(method='eth_getLogs', params=[params]) return typing.cast(list[LogReceipt], method_formatters.PYTHONIC_RESULT_FORMATTERS[RPC.eth_getLogs](response['result'])) - async def call_function( + async def call( self, - toAddress: str, - # TODO(krishan711): remove on major bump - contractAbi: ABI, # noqa: ARG002 - functionAbi: ABIFunction, - fromAddress: str | None = None, - arguments: DictStrAny | None = None, + contractCall: ContractCall, blockNumber: int | None = None, ) -> ListAny: - data = chain_util.encode_transaction_data(functionAbi=functionAbi, arguments=arguments) params = { - 'from': fromAddress or '0x0000000000000000000000000000000000000000', - 'to': toAddress, - 'data': data, + 'from': contractCall.fromAddress or '0x0000000000000000000000000000000000000000', + 'to': contractCall.toAddress, + 'data': chain_util.encode_transaction_data_by_name(contractAbi=contractCall.contractAbi, functionName=contractCall.functionName, arguments=contractCall.arguments), } + functionAbi = chain_util.find_abi_by_name_args(contractAbi=contractCall.contractAbi, functionName=contractCall.functionName, arguments=contractCall.arguments) response = await self._make_request(method='eth_call', params=[params, hex(blockNumber) if blockNumber is not None else 'latest']) - outputTypes = get_abi_output_types(abi_element=functionAbi) - try: - outputData = self.w3.codec.decode(types=outputTypes, data=HexBytes(response['result'])) - except InsufficientDataBytes as exception: - if response['result'] == '0x': - raise BadRequestException(message=f'Empty response: {exception!s}. Maybe the method does not exist on this contract.') - raise - except DecodingError as exception: - raise BadRequestException(message=str(exception)) - return list(outputData) + decodedResponse = chain_util.decode_function_result(functionAbi=functionAbi, resultData=HexBytes(response['result'])) + return decodedResponse async def get_max_priority_fee_per_gas(self) -> int: response = await self._make_request(method='eth_maxPriorityFeePerGas') @@ -402,7 +447,7 @@ async def fill_transaction_params( ) -> TxParams: if 'chainId' not in params: if chainId is None: - raise BadRequestException(message='chainId is required') + chainId = self.chainId params['chainId'] = hex(chainId) # type: ignore[typeddict-item] if 'nonce' not in params: if nonce is None: diff --git a/core/web3/multicall3.py b/core/web3/multicall3.py new file mode 100644 index 0000000..f824d22 --- /dev/null +++ b/core/web3/multicall3.py @@ -0,0 +1,39 @@ +# mypy: disable-error-code="typeddict-unknown-key, misc, list-item, typeddict-item" + +from eth_typing import ABI + +# Get from https://www.multicall3.com/deployments +CHAIN_ID_MULTICALL3_ADDRESS_MAP: dict[int, str] = { + 1: '0xcA11bde05977b3631167028862bE2a173976CA11', + 8453: '0xcA11bde05977b3631167028862bE2a173976CA11', + 84532: '0xcA11bde05977b3631167028862bE2a173976CA11', +} + +MULTICALL3_ABI: ABI = [ + { + 'inputs': [ + { + 'components': [ + {'name': 'target', 'type': 'address'}, + {'name': 'allowFailure', 'type': 'bool'}, + {'name': 'callData', 'type': 'bytes'}, + ], + 'name': 'calls', + 'type': 'tuple[]', + } + ], + 'name': 'aggregate3', + 'outputs': [ + { + 'components': [ + {'name': 'success', 'type': 'bool'}, + {'name': 'returnData', 'type': 'bytes'}, + ], + 'name': 'returnData', + 'type': 'tuple[]', + } + ], + 'stateMutability': 'payable', + 'type': 'function', + }, +] diff --git a/tests/web3/test_rest_eth_client.py b/tests/web3/test_rest_eth_client.py index 8500247..3925c8a 100644 --- a/tests/web3/test_rest_eth_client.py +++ b/tests/web3/test_rest_eth_client.py @@ -69,6 +69,7 @@ def client(self, mock_requester): return RestEthClient( url='https://test-rpc-url.com', requester=mock_requester, + chainId=1, isTestnet=False ) @@ -146,6 +147,7 @@ async def test_get_block_testnet_strips_extra_data(self, mock_requester): client = RestEthClient( url='https://test-rpc-url.com', requester=mock_requester, + chainId=1, isTestnet=True ) mock_requester.responses['eth_getBlockByNumber'] = { @@ -440,6 +442,7 @@ async def test_rate_limit_retry_success(self, mock_requester): client = RestEthClient( url='https://test-rpc-url.com', requester=mock_requester, + chainId=1, shouldBackoffRetryOnRateLimit=True, retryLimit=2 ) @@ -466,6 +469,7 @@ async def test_rate_limit_retry_exhausted(self, mock_requester): client = RestEthClient( url='https://test-rpc-url.com', requester=mock_requester, + chainId=1, shouldBackoffRetryOnRateLimit=True, retryLimit=1 ) @@ -480,6 +484,7 @@ async def test_rate_limit_retry_disabled(self, mock_requester): client = RestEthClient( url='https://test-rpc-url.com', requester=mock_requester, + chainId=1, shouldBackoffRetryOnRateLimit=False ) async def mock_post_json_rate_limit(*args, **kwargs): @@ -511,19 +516,6 @@ async def test_fill_transaction_params_all_provided(self, client, mock_requester assert result == params assert len(mock_requester.requests_made) == 0 - @pytest.mark.asyncio - async def test_fill_transaction_params_missing_chain_id_raises_error(self, client, mock_requester): - params = { - 'to': '0x1234567890123456789012345678901234567890', - 'from': '0x0987654321098765432109876543210987654321' - } - with pytest.raises(BadRequestException) as exc_info: - await client.fill_transaction_params( - params=params, - fromAddress='0x0987654321098765432109876543210987654321' - ) - assert 'chainId is required' in str(exc_info.value) - @pytest.mark.asyncio async def test_fill_transaction_params_fetches_missing_values(self, client, mock_requester): mock_requester.responses = { @@ -590,3 +582,255 @@ async def test_fill_transaction_params_fetches_missing_values(self, client, mock assert 'eth_estimateGas' in methods_called assert 'eth_maxPriorityFeePerGas' in methods_called assert 'eth_getBlockByNumber' in methods_called + + @pytest.mark.asyncio + async def test_call_function_success(self, client, mock_requester): + mock_requester.responses['eth_call'] = { + 'jsonrpc': '2.0', + 'result': '0x000000000000000000000000000000000000000000000000000000000000007b', + 'id': None + } + function_abi = { + 'inputs': [], + 'name': 'getValue', + 'outputs': [{'name': '', 'type': 'uint256'}], + 'stateMutability': 'view', + 'type': 'function' + } + result = await client.call_function( + toAddress='0x1234567890123456789012345678901234567890', + contractAbi=[function_abi], + functionAbi=function_abi, + fromAddress='0x0987654321098765432109876543210987654321' + ) + assert len(result) == 1 + assert result[0] == 123 + assert len(mock_requester.requests_made) == 1 + request = mock_requester.requests_made[0]['dataDict'] + assert request['method'] == 'eth_call' + assert request['params'][0]['to'] == '0x1234567890123456789012345678901234567890' + assert request['params'][0]['from'] == '0x0987654321098765432109876543210987654321' + assert request['params'][1] == 'latest' + + @pytest.mark.asyncio + async def test_call_function_with_block_number(self, client, mock_requester): + mock_requester.responses['eth_call'] = { + 'jsonrpc': '2.0', + 'result': '0x000000000000000000000000000000000000000000000000000000000000007b', + 'id': None + } + function_abi = { + 'inputs': [], + 'name': 'getValue', + 'outputs': [{'name': '', 'type': 'uint256'}], + 'stateMutability': 'view', + 'type': 'function' + } + result = await client.call_function( + toAddress='0x1234567890123456789012345678901234567890', + contractAbi=[function_abi], + functionAbi=function_abi, + blockNumber=0x100 + ) + assert len(result) == 1 + assert result[0] == 123 + request = mock_requester.requests_made[0]['dataDict'] + assert request['params'][1] == hex(0x100) + + @pytest.mark.asyncio + async def test_call_function_with_arguments(self, client, mock_requester): + mock_requester.responses['eth_call'] = { + 'jsonrpc': '2.0', + 'result': '0x0000000000000000000000000000000000000000000000000000000000000001', + 'id': None + } + function_abi = { + 'inputs': [{'name': 'amount', 'type': 'uint256'}], + 'name': 'hasEnoughBalance', + 'outputs': [{'name': '', 'type': 'bool'}], + 'stateMutability': 'view', + 'type': 'function' + } + result = await client.call_function( + toAddress='0x1234567890123456789012345678901234567890', + contractAbi=[function_abi], + functionAbi=function_abi, + arguments={'amount': 1000} + ) + assert result[0] is True + + @pytest.mark.asyncio + async def test_call_function_by_name_success(self, client, mock_requester): + mock_requester.responses['eth_call'] = { + 'jsonrpc': '2.0', + 'result': '0x000000000000000000000000000000000000000000000000000000000000007b', + 'id': None + } + contract_abi = [{ + 'inputs': [], + 'name': 'getValue', + 'outputs': [{'name': '', 'type': 'uint256'}], + 'stateMutability': 'view', + 'type': 'function' + }] + result = await client.call_function_by_name( + toAddress='0x1234567890123456789012345678901234567890', + contractAbi=contract_abi, + functionName='getValue' + ) + assert result[0] == 123 + + @pytest.mark.asyncio + async def test_call_success(self, client, mock_requester): + from core.web3.eth_client import ContractCall + mock_requester.responses['eth_call'] = { + 'jsonrpc': '2.0', + 'result': '0x000000000000000000000000000000000000000000000000000000000000007b', + 'id': None + } + contract_abi = [{ + 'inputs': [], + 'name': 'getValue', + 'outputs': [{'name': '', 'type': 'uint256'}], + 'stateMutability': 'view', + 'type': 'function' + }] + contract_call = ContractCall( + toAddress='0x1234567890123456789012345678901234567890', + functionName='getValue', + contractAbi=contract_abi, + fromAddress='0x0987654321098765432109876543210987654321' + ) + result = await client.call(contractCall=contract_call) + assert result[0] == 123 + + @pytest.mark.asyncio + async def test_call_with_block_number(self, client, mock_requester): + from core.web3.eth_client import ContractCall + mock_requester.responses['eth_call'] = { + 'jsonrpc': '2.0', + 'result': '0x000000000000000000000000000000000000000000000000000000000000007b', + 'id': None + } + contract_abi = [{ + 'inputs': [], + 'name': 'getValue', + 'outputs': [{'name': '', 'type': 'uint256'}], + 'stateMutability': 'view', + 'type': 'function' + }] + contract_call = ContractCall( + toAddress='0x1234567890123456789012345678901234567890', + functionName='getValue', + contractAbi=contract_abi + ) + result = await client.call(contractCall=contract_call, blockNumber=436) + assert result[0] == 123 + request = mock_requester.requests_made[0]['dataDict'] + assert request['params'][1] == hex(436) + + + @pytest.mark.asyncio + async def test_multicall_fallback_to_individual_calls(self, client, mock_requester): + from core.web3.eth_client import ContractCall + unsupported_client = RestEthClient( + url='https://test-rpc-url.com', + requester=mock_requester, + chainId=999, + isTestnet=False + ) + call_responses = [ + '0x000000000000000000000000000000000000000000000000000000000000007b', + '0x00000000000000000000000000000000000000000000000000000000000000c8' + ] + call_count = 0 + async def mock_post_json(url, dataDict, timeout=None): + nonlocal call_count + response_data = call_responses[call_count] + call_count += 1 + class MockResponse: + def json(self): + return { + 'jsonrpc': '2.0', + 'result': response_data, + 'id': None + } + return MockResponse() + mock_requester.post_json = mock_post_json + contract_abi = [{ + 'inputs': [], + 'name': 'getValue', + 'outputs': [{'name': '', 'type': 'uint256'}], + 'stateMutability': 'view', + 'type': 'function' + }] + contract_calls = [ + ContractCall( + toAddress='0x1234567890123456789012345678901234567890', + functionName='getValue', + contractAbi=contract_abi + ), + ContractCall( + toAddress='0x1234567890123456789012345678901234567890', + functionName='getValue', + contractAbi=contract_abi + ) + ] + result = await unsupported_client.multicall(contractCalls=contract_calls) + assert len(result) == 2 + assert result[0][0] == 123 + assert result[1][0] == 200 + assert call_count == 2 + + @pytest.mark.asyncio + async def test_multicall_disabled_explicitly(self, client, mock_requester): + from core.web3.eth_client import ContractCall + mock_requester.responses['eth_call'] = { + 'jsonrpc': '2.0', + 'result': '0x000000000000000000000000000000000000000000000000000000000000007b', + 'id': None + } + contract_abi = [{ + 'inputs': [], + 'name': 'getValue', + 'outputs': [{'name': '', 'type': 'uint256'}], + 'stateMutability': 'view', + 'type': 'function' + }] + contract_calls = [ + ContractCall( + toAddress='0x1234567890123456789012345678901234567890', + functionName='getValue', + contractAbi=contract_abi + ) + ] + result = await client.multicall(contractCalls=contract_calls, shouldUseMulticall3=False) + assert len(result) == 1 + assert result[0][0] == 123 + assert len(mock_requester.requests_made) == 1 + request = mock_requester.requests_made[0]['dataDict'] + assert request['params'][0]['to'] == '0x1234567890123456789012345678901234567890' + + @pytest.mark.asyncio + async def test_call_function_empty_response_error(self, client, mock_requester): + # Mock empty response (0x) + mock_requester.responses['eth_call'] = { + 'jsonrpc': '2.0', + 'result': '0x', + 'id': None + } + function_abi = { + 'inputs': [], + 'name': 'getValue', + 'outputs': [{'name': '', 'type': 'uint256'}], + 'stateMutability': 'view', + 'type': 'function' + } + with pytest.raises(BadRequestException) as exc_info: + await client.call_function( + toAddress='0x1234567890123456789012345678901234567890', + contractAbi=[function_abi], + functionAbi=function_abi + ) + assert 'Empty response' in str(exc_info.value) + assert 'Maybe the method does not exist on this contract' in str(exc_info.value)