Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
136 changes: 59 additions & 77 deletions indico/http/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,7 @@
from pathlib import Path
from typing import TYPE_CHECKING, cast

import aiohttp
import requests
import niquests

from indico.config import IndicoConfig
from indico.errors import (
Expand All @@ -25,6 +24,8 @@
from typing import Any, Dict, Iterator, List, Optional, Union
from urllib.request import Request

from niquests.cookies import RequestsCookieJar

from indico.client.request import HTTPRequest, ResponseType
from indico.typing import AnyDict

Expand All @@ -51,7 +52,7 @@ def __init__(self, config: "Optional[IndicoConfig]" = None):
self.config = config or IndicoConfig()
self.base_url = f"{self.config.protocol}://{self.config.host}"

self.request_session = requests.Session()
self.request_session = niquests.Session()
if isinstance(self.config.requests_params, dict):
for param in self.config.requests_params.keys():
setattr(self.request_session, param, self.config.requests_params[param])
Expand All @@ -73,11 +74,12 @@ def get(
return self._make_request("post", *args, params=params, **kwargs)

def get_short_lived_access_token(self) -> "AnyDict":
cookies = cast("RequestsCookieJar", self.request_session.cookies)
# If the cookie here is already due to _disable_cookie_domain set and we try to
# pop it later it will error out because we have two cookies with the same
# name. We just remove the old one here as we are about to refresh it.
if "auth_token" in self.request_session.cookies:
self.request_session.cookies.pop("auth_token")
if "auth_token" in cookies:
cookies.pop("auth_token")

r = self.post(
"/auth/users/refresh_token",
Expand All @@ -87,13 +89,12 @@ def get_short_lived_access_token(self) -> "AnyDict":

# Disable cookie domain in cases where the domain won't match due to using short name domains
if self.config._disable_cookie_domain:
value = self.request_session.cookies.get("auth_token", None)
value = cookies.get("auth_token", None) # type: ignore[no-untyped-call]
if not value:
raise IndicoAuthenticationFailed()
self.request_session.cookies.pop("auth_token")
cookies.pop("auth_token")
self.request_session.cookies.set_cookie(
# must ignore because untyped in typeshed
requests.cookies.create_cookie(name="auth_token", value=value) # type: ignore
niquests.cookies.create_cookie(name="auth_token", value=value) # type: ignore[no-untyped-call]
)

return cast("AnyDict", r)
Expand Down Expand Up @@ -188,12 +189,12 @@ def _make_request(
raise IndicoAuthenticationFailed()

if response.status_code == 503 and "Retry-After" in response.headers:
raise IndicoHibernationError(after=response.headers.get("Retry-After"))
raise IndicoHibernationError(after=int(response.headers["Retry-After"]))

if response.status_code >= 500:
raise IndicoRequestError(
code=response.status_code,
error=response.reason,
error=response.reason or "",
extras=repr(response.content),
)

Expand Down Expand Up @@ -221,19 +222,15 @@ def _make_request(

class AIOHTTPClient:
"""
Beta client with a minimal set of features that can execute
requests using the aiohttp library
Async client using niquests. Supports HTTP/1.1 and HTTP/2.
"""

def __init__(self, config: "Optional[IndicoConfig]" = None):
"""
Config options specific to aiohttp
unsafe - allows interacting with IP urls
"""
self.config = config or IndicoConfig()
self.base_url = f"{self.config.protocol}://{self.config.host}"

self.request_session = aiohttp.ClientSession()
self.request_session = niquests.AsyncSession()
self.request_session.verify = self.config.verify_ssl
if isinstance(self.config.requests_params, dict):
for param in self.config.requests_params.keys():
setattr(self.request_session, param, self.config.requests_params[param])
Expand Down Expand Up @@ -269,54 +266,42 @@ async def execute_request(
)

@contextmanager
def _handle_files(
self, req_kwargs: "AnyDict"
) -> "Iterator[List[aiohttp.FormData]]":
files = []
file_args = []
def _handle_files(self, req_kwargs: "AnyDict") -> "Iterator[List[Dict[str, Any]]]":
files: "List[Any]" = []
file_args: "List[Dict[str, Any]]" = []
dup_counts: "Dict[str, int]" = {}
for filepath in req_kwargs.pop("files", []) or []:
data = aiohttp.FormData()
path = Path(filepath)
fd = path.open("rb")
files.append(fd)
# follow the convention of adding (n) after a duplicate filename
_add_suffix = f".{path.suffix}" if path.suffix else ""
if path.stem in dup_counts:
data.add_field(
"file",
fd,
filename=path.stem + f"({dup_counts[path.stem]})" + _add_suffix,
)
name = path.stem + f"({dup_counts[path.stem]})" + _add_suffix
dup_counts[path.stem] += 1
else:
data.add_field("file", fd, filename=path.name)
name = path.name
dup_counts[path.stem] = 1
file_args.append(data)
file_args.append({"files": {"file": (name, fd)}})
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

weird we do this with individual nested dicts

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

requests/niquests is just funny like that


for filename, stream in (req_kwargs.pop("streams", {}) or {}).items():
# similar operation as above.
files.append(stream)
data = aiohttp.FormData()
if filename in dup_counts:
data.add_field(
"file",
stream,
filename=filename + f"({dup_counts[filename]})",
)
name = filename + f"({dup_counts[filename]})"
dup_counts[filename] += 1
else:
data.add_field("file", stream, filename=filename)
name = filename
dup_counts[filename] = 1
file_args.append(data)
file_args.append({"files": {"file": (name, stream)}})

try:
yield file_args
finally:
for f in files:
f.close()

@aioretry(aiohttp.ClientConnectionError, aiohttp.ServerDisconnectedError)
@aioretry(niquests.ConnectionError)
async def _make_request(
self,
method: str,
Expand All @@ -336,57 +321,54 @@ async def _make_request(
resps = await asyncio.gather(
*(
self._make_request(
method, path, headers, **request_kwargs, data=data
method, path, headers, **request_kwargs, **data
)
for data in file_args
)
)
return [resp for resp_set in resps for resp in resp_set]

async with getattr(self.request_session, method)(
response: niquests.Response = await getattr(self.request_session, method)(
f"{self.base_url}{path}",
headers=headers,
verify_ssl=self.config.verify_ssl,
**request_kwargs,
) as response:
# If auth expired refresh
if response.status == 401 and not _refreshed:
await self.get_short_lived_access_token()
return await self._make_request(
method, path, headers, _refreshed=True, **request_kwargs
)
elif response.status == 401 and _refreshed:
raise IndicoAuthenticationFailed()
)

if response.status == 503 and "Retry-After" in response.headers:
raise IndicoHibernationError(
after=response.headers.get("Retry-After")
)
status_code = response.status_code or 0

if response.status >= 500:
raise IndicoRequestError(
code=response.status,
error=response.reason,
extras=repr(response.content),
)
if status_code == 401 and not _refreshed:
await self.get_short_lived_access_token()
return await self._make_request(
method, path, headers, _refreshed=True, **request_kwargs
)
if status_code == 401 and _refreshed:
raise IndicoAuthenticationFailed()

content: "Any" = await aio_deserialize(
response, force_json=json, force_decompress=decompress
if status_code == 503 and "Retry-After" in response.headers:
raise IndicoHibernationError(after=int(response.headers["Retry-After"]))

if status_code >= 500:
raise IndicoRequestError(
code=status_code,
error=response.reason or "",
extras=repr(response.content),
)

if response.status >= 400:
if isinstance(content, dict):
error = (
f"{content.pop('error_type', 'Unknown Error')}, "
f"{content.pop('message', '')}"
)
extras = content
else:
error = content
extras = None
content: "Any" = await aio_deserialize(
response, force_json=json, force_decompress=decompress
)

raise IndicoRequestError(
error=error, code=response.status, extras=extras
if status_code >= 400:
if isinstance(content, dict):
error = (
f"{content.pop('error_type', 'Unknown Error')}, "
f"{content.pop('message', '')}"
)
extras = content
else:
error = content
extras = None

raise IndicoRequestError(error=error, code=status_code, extras=extras)

return content
return content
20 changes: 10 additions & 10 deletions indico/http/serialization.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,28 +16,29 @@
if TYPE_CHECKING: # pragma: no cover
from typing import Any, Callable, Dict, Mapping, Optional, Tuple

from aiohttp import ClientResponse
from requests import Response
from niquests import Response

logger = logging.getLogger(__name__)


def decompress(response: "Response") -> bytes:
response.raw.decode_content = True
value: bytes = io.BytesIO(response.raw.data).getvalue()
raw = response.raw
assert raw is not None
raw.decode_content = True
value: bytes = io.BytesIO(raw.data).getvalue()
return gzip.decompress(value)


def deserialize(
response: "Response", force_json: bool = False, force_decompress: bool = False
) -> "Any":
content_type, params = parse_header(response.headers["Content-Type"])
content_type, params = parse_header(str(response.headers["Content-Type"]))
content: bytes

if force_decompress or content_type in ["application/x-gzip", "application/gzip"]:
content = decompress(response)
else:
content = response.content
content = response.content or b""

charset = params.get("charset", "utf-8")

Expand All @@ -55,17 +56,16 @@ def deserialize(


async def aio_deserialize(
response: "ClientResponse", force_json: bool = False, force_decompress: bool = False
response: "Response", force_json: bool = False, force_decompress: bool = False
) -> "Any":
content_type, params = parse_header(response.headers["Content-Type"])
content: bytes = await response.read()
content_type, params = parse_header(str(response.headers["Content-Type"]))
content: bytes = response.content or b""

if force_decompress or content_type in ["application/x-gzip", "application/gzip"]:
content = gzip.decompress(io.BytesIO(content).getvalue())

charset: str = params.get("charset", "utf-8")

# For storage object for example where the content is json based on url ending
if force_json:
content_type = "application/json"

Expand Down
6 changes: 3 additions & 3 deletions indico/queries/model_import.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, cast

import requests
import niquests

from indico.client.request import GraphQLRequest, RequestChain
from indico.errors import IndicoInputError, IndicoRequestError
Expand Down Expand Up @@ -37,12 +37,12 @@ def process_response(self, response: "Payload") -> str:
file_content = file.read()

headers = {"Content-Type": "application/zip"}
export_response = requests.put(signed_url, data=file_content, headers=headers)
export_response = niquests.put(signed_url, data=file_content, headers=headers)

if export_response.status_code != 200:
raise IndicoRequestError(
f"Failed to upload static model export: {export_response.text}",
export_response.status_code,
export_response.status_code or 0,
)
return storage_uri

Expand Down
4 changes: 2 additions & 2 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ license = { text = "MIT License" }
authors = [{ name = "indico", email = "engineering@indico.io" }]
requires-python = ">=3.9"
classifiers = ["License :: OSI Approved :: MIT License", "Programming Language :: Python :: 3", "Programming Language :: Python :: 3.9", "Programming Language :: Python :: 3.10", "Programming Language :: Python :: 3.11", "Programming Language :: Python :: 3.12", "Programming Language :: Python :: 3.13"]
dependencies = ["requests>=2.22.0", "deprecation>=2.1.0", "jsons", "aiohttp[speedups]"]
dependencies = ["niquests>=3.0.0", "deprecation>=2.1.0", "jsons", "urllib3>=2.0.0"]
dynamic = ["version"]

[project.optional-dependencies]
Expand Down Expand Up @@ -56,4 +56,4 @@ module = "importlib_metadata"
ignore_missing_imports = true

[dependency-groups]
dev = ["pytest>=8.4.2", "pytest-mock>=3.15.1"]
dev = ["pytest>=8.4.2", "pytest-asyncio>=0.24.0", "pytest-mock>=3.15.1"]
6 changes: 3 additions & 3 deletions tests/unit/client/test_aioclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,11 +9,11 @@

@pytest.fixture(scope="function")
def indico_test_config():
return IndicoConfig(protocol="mock", host="mock")
return IndicoConfig(protocol="mock", host="mock", api_token="mock")


@pytest.fixture(scope="function")
def indico_request(requests_mock, indico_test_config, monkeypatch):
def indico_request(indico_test_config, monkeypatch):
registered = {}

async def _mock_make_request(self, method, path, *args, **kwargs):
Expand Down Expand Up @@ -58,7 +58,7 @@ async def test_client_basic_http_request(indico_request, auth, indico_test_confi


async def test_client_creation_error_handling(indico_test_config):
client = AsyncIndicoClient()
client = AsyncIndicoClient(config=indico_test_config)
with pytest.raises(IndicoError):
await client.call(HTTPRequest(method=HTTPMethod.GET, path="/users/details"))

Expand Down
9 changes: 4 additions & 5 deletions tests/unit/client/test_aiohttp_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,8 @@ async def test_handle_files_correct_filename():
with client._handle_files(request_kwargs) as file_args:
assert len(file_args) == 1
for arg in file_args:
fields = arg._fields
for field in fields:
field_dict = field[0]
filename = field_dict.get("filename")
assert filename == "testfile.txt"
assert "files" in arg
assert "file" in arg["files"]
filename, _fd = arg["files"]["file"]
assert filename == "testfile.txt"
await client.request_session.close()
Loading