Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,8 @@ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/)
- [MINOR] Added `async_util.gather_batched` function to `async_util`
- [MINOR] Added more exception classes
- [MAJOR] Added `apiPathPattern` to `logging.api`
- [MAJOR] Added `chainId` to `EthClientInterface`
- [MINOR] Added `call` and `multicall` to `EthClientInterface`

### Changed

Expand Down
18 changes: 17 additions & 1 deletion core/util/chain_util.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
import typing

import eth_utils
from eth_abi.exceptions import DecodingError
from eth_typing import ABI
from eth_typing import ABIFunction
from eth_typing import HexStr
from eth_utils import add_0x_prefix
from eth_utils.abi import get_abi_output_types
from web3 import Web3
from web3._utils.contracts import encode_abi as web3_encode_abi
from web3._utils.contracts import encode_transaction_data as web3_encode_transaction_data
Expand Down Expand Up @@ -115,7 +117,7 @@ def encode_transaction_data_by_name( # type: ignore[explicit-any]
)


# NOTE(krishan711): not sure how difference this is from `encode_transaction_data`
# NOTE(krishan711): not sure how different this is from `encode_transaction_data`
def encode_function_params(functionAbi: ABIFunction, arguments: list[typing.Any]) -> str: # type: ignore[explicit-any]
return add_0x_prefix(
web3_encode_abi(
Expand All @@ -124,3 +126,17 @@ def encode_function_params(functionAbi: ABIFunction, arguments: list[typing.Any]
arguments=arguments,
)
)


def decode_function_result(functionAbi: ABIFunction, resultData: bytes) -> list[typing.Any]: # type: ignore[explicit-any]
if resultData == b'' or resultData.hex() == '0x':
outputs = functionAbi.get('outputs', [])
if len(outputs) == 0:
return []
raise BadRequestException(message=f'Empty response. Maybe the method does not exist on this contract.')
outputTypes = get_abi_output_types(abi_element=functionAbi)
try:
outputData = _w3.codec.decode(types=outputTypes, data=bytes.fromhex(resultData.hex()))
except DecodingError as exception:
raise BadRequestException(message=str(exception))
return list(outputData)
123 changes: 84 additions & 39 deletions core/web3/eth_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,11 +2,8 @@
import typing
from typing import Any

from eth_abi.exceptions import DecodingError
from eth_abi.exceptions import InsufficientDataBytes
from eth_typing import ABI
from eth_typing import ABIFunction
from eth_utils.abi import get_abi_output_types
from pydantic import BaseModel
from web3 import Web3
from web3._utils import method_formatters
Expand All @@ -30,6 +27,8 @@
from core.util import chain_util
from core.util import json_util
from core.util.typing_util import JsonObject
from core.web3.multicall3 import CHAIN_ID_MULTICALL3_ADDRESS_MAP
from core.web3.multicall3 import MULTICALL3_ABI

ListAny = list[Any] # type: ignore[explicit-any]
DictStrAny = dict[str, Any] # type: ignore[explicit-any]
Expand All @@ -41,6 +40,14 @@ class EncodedCall(BaseModel):
data: str = '0x'


class ContractCall(BaseModel):
toAddress: str
functionName: str
contractAbi: ABI
arguments: DictStrAny | None = None
fromAddress: str | None = None


class TransactionFailedException(KibaException):
def __init__(self, transactionReceipt: TxReceipt) -> None:
super().__init__(message='Transaction failed')
Expand All @@ -53,8 +60,9 @@ def to_dict(self) -> JsonObject:


class EthClientInterface:
def __init__(self, web3Connection: Web3, isTestnet: bool = False) -> None:
def __init__(self, web3Connection: Web3, chainId: int, isTestnet: bool = False) -> None:
self.w3 = web3Connection
self.chainId = chainId
self.isTestnet = isTestnet

async def get_latest_block_number(self) -> int:
Expand All @@ -75,7 +83,7 @@ async def get_transaction_receipt(self, transactionHash: str) -> TxReceipt:
async def get_log_entries(self, topics: list[str] | None = None, startBlockNumber: int | None = None, endBlockNumber: int | None = None, address: str | None = None) -> list[LogReceipt]:
raise NotImplementedError

async def call_function(self, toAddress: str, contractAbi: ABI, functionAbi: ABIFunction, fromAddress: str | None = None, arguments: DictStrAny | None = None, blockNumber: int | None = None) -> ListAny:
async def call(self, contractCall: ContractCall, blockNumber: int | None = None) -> ListAny:
raise NotImplementedError

async def fill_transaction_params(
Expand All @@ -102,13 +110,34 @@ async def call_function_by_name(
arguments: DictStrAny | None = None,
blockNumber: int | None = None,
) -> ListAny:
functionAbi = chain_util.find_abi_by_name_args(contractAbi=contractAbi, functionName=functionName, arguments=arguments)
return await self.call_function(
toAddress=toAddress,
contractAbi=contractAbi,
functionAbi=functionAbi,
fromAddress=fromAddress,
arguments=arguments,
return await self.call(
contractCall=ContractCall(
toAddress=toAddress,
functionName=functionName,
contractAbi=contractAbi,
arguments=arguments,
fromAddress=fromAddress,
),
blockNumber=blockNumber,
)

async def call_function(
self,
toAddress: str,
contractAbi: ABI,
functionAbi: ABIFunction,
fromAddress: str | None = None,
arguments: DictStrAny | None = None,
blockNumber: int | None = None,
) -> ListAny:
return await self.call(
contractCall=ContractCall(
toAddress=toAddress,
functionName=functionAbi['name'],
contractAbi=contractAbi,
arguments=arguments,
fromAddress=fromAddress,
),
blockNumber=blockNumber,
)

Expand Down Expand Up @@ -139,7 +168,7 @@ async def send_transaction(
gas=gas,
maxFeePerGas=maxFeePerGas,
maxPriorityFeePerGas=maxPriorityFeePerGas,
chainId=chainId,
chainId=chainId or self.chainId,
)
signedParams = self.w3.eth.account.sign_transaction(transaction_dict=params, private_key=privateKey)
output = await self.send_raw_transaction(transactionData=signedParams.raw_transaction.hex())
Expand Down Expand Up @@ -171,7 +200,7 @@ async def send_transaction_by_name(
maxFeePerGas=maxFeePerGas,
maxPriorityFeePerGas=maxPriorityFeePerGas,
arguments=arguments,
chainId=chainId,
chainId=chainId or self.chainId,
)

async def wait_for_transaction_receipt(self, transactionHash: str, sleepSeconds: int = 2, maxWaitSeconds: int = 120, raiseOnFailure: bool = True) -> TxReceipt:
Expand All @@ -191,8 +220,38 @@ async def wait_for_transaction_receipt(self, transactionHash: str, sleepSeconds:
raise TransactionFailedException(transactionReceipt=transactionReceipt)
return transactionReceipt

async def multicall(self, contractCalls: list[ContractCall], shouldUseMulticall3: bool = True) -> list[ListAny]:
multicall3Address = CHAIN_ID_MULTICALL3_ADDRESS_MAP.get(self.chainId) if shouldUseMulticall3 else None
if not multicall3Address or not shouldUseMulticall3:
results = await asyncio.gather(*[self.call(contractCall=contractCall) for contractCall in contractCalls])
return results
multicallResponse = await self.call_function_by_name(
toAddress=multicall3Address,
contractAbi=MULTICALL3_ABI,
functionName='aggregate3',
arguments={
'calls': [
{'target': contractCall.toAddress, 'allowFailure': False, 'callData': chain_util.encode_transaction_data_by_name(contractAbi=contractCall.contractAbi, functionName=contractCall.functionName, arguments=contractCall.arguments)}
for contractCall in contractCalls
]
},
)
multicallResults = multicallResponse[0]
decodedResults: list[ListAny] = []
for contractCall, (success, returnData) in zip(contractCalls, multicallResults, strict=False):
if not success:
decodedResults.append([None])
continue
functionAbi = chain_util.find_abi_by_name_args(contractAbi=contractCall.contractAbi, functionName=contractCall.functionName, arguments=contractCall.arguments)
decodedValue = chain_util.decode_function_result(functionAbi=functionAbi, resultData=returnData)
decodedResults.append(decodedValue)
return decodedResults


class Web3EthClient(EthClientInterface):
def __init__(self, web3Connection: Web3, chainId: int, isTestnet: bool = False) -> None:
super().__init__(web3Connection=web3Connection, chainId=chainId, isTestnet=isTestnet)

async def get_latest_block_number(self) -> int:
return self.w3.eth.block_number

Expand Down Expand Up @@ -231,11 +290,10 @@ async def get_log_entries(

class RestEthClient(EthClientInterface):
# NOTE(krishan711): find docs at https://eth.wiki/json-rpc/API
def __init__(self, url: str, requester: Requester, isTestnet: bool = False, shouldBackoffRetryOnRateLimit: bool = True, retryLimit: int = 10) -> None:
super().__init__(web3Connection=Web3(), isTestnet=isTestnet)
def __init__(self, url: str, requester: Requester, chainId: int, isTestnet: bool = False, shouldBackoffRetryOnRateLimit: bool = True, retryLimit: int = 10) -> None:
super().__init__(web3Connection=Web3(), chainId=chainId, isTestnet=isTestnet)
self.url = url
self.requester = requester
self.isTestnet = isTestnet
self.shouldBackoffRetryOnRateLimit = shouldBackoffRetryOnRateLimit
self.retryLimit = retryLimit
self.w3 = Web3()
Expand Down Expand Up @@ -350,33 +408,20 @@ async def get_log_entries(
response = await self._make_request(method='eth_getLogs', params=[params])
return typing.cast(list[LogReceipt], method_formatters.PYTHONIC_RESULT_FORMATTERS[RPC.eth_getLogs](response['result']))

async def call_function(
async def call(
self,
toAddress: str,
# TODO(krishan711): remove on major bump
contractAbi: ABI, # noqa: ARG002
functionAbi: ABIFunction,
fromAddress: str | None = None,
arguments: DictStrAny | None = None,
contractCall: ContractCall,
blockNumber: int | None = None,
) -> ListAny:
data = chain_util.encode_transaction_data(functionAbi=functionAbi, arguments=arguments)
params = {
'from': fromAddress or '0x0000000000000000000000000000000000000000',
'to': toAddress,
'data': data,
'from': contractCall.fromAddress or '0x0000000000000000000000000000000000000000',
'to': contractCall.toAddress,
'data': chain_util.encode_transaction_data_by_name(contractAbi=contractCall.contractAbi, functionName=contractCall.functionName, arguments=contractCall.arguments),
}
functionAbi = chain_util.find_abi_by_name_args(contractAbi=contractCall.contractAbi, functionName=contractCall.functionName, arguments=contractCall.arguments)
response = await self._make_request(method='eth_call', params=[params, hex(blockNumber) if blockNumber is not None else 'latest'])
outputTypes = get_abi_output_types(abi_element=functionAbi)
try:
outputData = self.w3.codec.decode(types=outputTypes, data=HexBytes(response['result']))
except InsufficientDataBytes as exception:
if response['result'] == '0x':
raise BadRequestException(message=f'Empty response: {exception!s}. Maybe the method does not exist on this contract.')
raise
except DecodingError as exception:
raise BadRequestException(message=str(exception))
return list(outputData)
decodedResponse = chain_util.decode_function_result(functionAbi=functionAbi, resultData=HexBytes(response['result']))
return decodedResponse

async def get_max_priority_fee_per_gas(self) -> int:
response = await self._make_request(method='eth_maxPriorityFeePerGas')
Expand All @@ -402,7 +447,7 @@ async def fill_transaction_params(
) -> TxParams:
if 'chainId' not in params:
if chainId is None:
raise BadRequestException(message='chainId is required')
chainId = self.chainId
params['chainId'] = hex(chainId) # type: ignore[typeddict-item]
if 'nonce' not in params:
if nonce is None:
Expand Down
39 changes: 39 additions & 0 deletions core/web3/multicall3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
# mypy: disable-error-code="typeddict-unknown-key, misc, list-item, typeddict-item"

from eth_typing import ABI

# Get from https://www.multicall3.com/deployments
CHAIN_ID_MULTICALL3_ADDRESS_MAP: dict[int, str] = {
1: '0xcA11bde05977b3631167028862bE2a173976CA11',
8453: '0xcA11bde05977b3631167028862bE2a173976CA11',
84532: '0xcA11bde05977b3631167028862bE2a173976CA11',
}

MULTICALL3_ABI: ABI = [
{
'inputs': [
{
'components': [
{'name': 'target', 'type': 'address'},
{'name': 'allowFailure', 'type': 'bool'},
{'name': 'callData', 'type': 'bytes'},
],
'name': 'calls',
'type': 'tuple[]',
}
],
'name': 'aggregate3',
'outputs': [
{
'components': [
{'name': 'success', 'type': 'bool'},
{'name': 'returnData', 'type': 'bytes'},
],
'name': 'returnData',
'type': 'tuple[]',
}
],
'stateMutability': 'payable',
'type': 'function',
},
]
Loading
Loading