|
1 | 1 | import ast |
2 | 2 | import re |
3 | | -from typing import Optional |
| 3 | +from typing import Optional, Union |
4 | 4 |
|
5 | 5 | from jsonschema_path.paths import SchemaPath |
6 | 6 |
|
@@ -139,14 +139,15 @@ def _build_method_call_args( |
139 | 139 |
|
140 | 140 |
|
141 | 141 | def http_method_to_func_body( |
142 | | - method_spec: SchemaPath, api_path: Optional[str] = None |
| 142 | + method_spec: SchemaPath, api_path: Optional[str] = None, use_async: bool = False |
143 | 143 | ) -> list[ast.stmt]: |
144 | 144 | """ |
145 | 145 | Generates the body of the Python function, including a docstring and request call. |
146 | 146 |
|
147 | 147 | Args: |
148 | 148 | method_spec: The SchemaPath for the operation |
149 | 149 | api_path: The original API path string (e.g., '/users/{user_id}') |
| 150 | + use_async: Whether to generate an async function |
150 | 151 |
|
151 | 152 | Returns: |
152 | 153 | A list of ast.stmt nodes representing the function body |
@@ -241,16 +242,32 @@ def http_method_to_func_body( |
241 | 242 | # Create the appropriate request method call |
242 | 243 | method_name = '_iter_request' if is_collection else '_request' |
243 | 244 |
|
244 | | - request_call = ast.Return( |
245 | | - value=ast.Call( |
246 | | - func=ast.Attribute( |
247 | | - value=ast.Name(id='self', ctx=ast.Load()), attr=method_name, ctx=ast.Load() |
248 | | - ), |
249 | | - args=[], |
250 | | - keywords=call_args, |
251 | | - ) |
| 245 | + # Create the request call |
| 246 | + request_call_expr = ast.Call( |
| 247 | + func=ast.Attribute( |
| 248 | + value=ast.Name(id='self', ctx=ast.Load()), attr=method_name, ctx=ast.Load() |
| 249 | + ), |
| 250 | + args=[], |
| 251 | + keywords=call_args, |
252 | 252 | ) |
253 | 253 |
|
| 254 | + if use_async and is_collection: |
| 255 | + # For async collection responses, use async for loop with yield |
| 256 | + if_async_for = ast.AsyncFor( |
| 257 | + target=ast.Name(id='item', ctx=ast.Store()), |
| 258 | + iter=request_call_expr, |
| 259 | + body=[ast.Expr(value=ast.Yield(value=ast.Name(id='item', ctx=ast.Load())))], |
| 260 | + orelse=[], |
| 261 | + ) |
| 262 | + return [docstring_node, if_async_for] |
| 263 | + elif use_async: |
| 264 | + # For async non-collection responses, use await |
| 265 | + request_call_expr = ast.Await(value=request_call_expr) |
| 266 | + request_call = ast.Return(value=request_call_expr) |
| 267 | + else: |
| 268 | + # For sync responses, return directly |
| 269 | + request_call = ast.Return(value=request_call_expr) |
| 270 | + |
254 | 271 | # Put it all together |
255 | 272 | return [docstring_node, request_call] |
256 | 273 |
|
@@ -345,28 +362,41 @@ def http_method_to_func_args(method_spec: SchemaPath) -> ast.arguments: |
345 | 362 |
|
346 | 363 |
|
347 | 364 | def http_method_to_func_def( |
348 | | - method_spec: SchemaPath, override_func_name: Optional[str] = None, api_path: Optional[str] = None |
349 | | -) -> ast.FunctionDef: |
| 365 | + method_spec: SchemaPath, override_func_name: Optional[str] = None, api_path: Optional[str] = None, use_async: bool = False |
| 366 | +) -> Union[ast.FunctionDef, ast.AsyncFunctionDef]: |
350 | 367 | """ |
351 | 368 | Converts an OpenAPI method spec to a Python function definition. |
352 | 369 |
|
353 | 370 | Args: |
354 | 371 | method_spec: The SchemaPath for the operation |
355 | 372 | override_func_name: An optional name to use for the function instead of the default |
356 | 373 | api_path: The original API path string (e.g., '/users/{user_id}') |
| 374 | + use_async: Whether to generate an async function |
357 | 375 |
|
358 | 376 | Returns: |
359 | | - An ast.FunctionDef node representing the Python method |
| 377 | + An ast.FunctionDef or ast.AsyncFunctionDef node representing the Python method |
360 | 378 | """ |
361 | 379 | func_name = override_func_name if override_func_name else http_method_to_func_name(method_spec) |
362 | 380 |
|
363 | 381 | # Generate function body with potentially modified path |
364 | | - func_body = http_method_to_func_body(method_spec, api_path=api_path) |
365 | | - |
366 | | - return ast.FunctionDef( |
367 | | - name=func_name, |
368 | | - args=http_method_to_func_args(method_spec), |
369 | | - body=func_body, |
370 | | - decorator_list=[], |
371 | | - returns=spec_piece_to_annotation(method_spec / 'responses'), |
372 | | - ) |
| 382 | + func_body = http_method_to_func_body(method_spec, api_path=api_path, use_async=use_async) |
| 383 | + |
| 384 | + func_args = http_method_to_func_args(method_spec) |
| 385 | + returns_annotation = spec_piece_to_annotation(method_spec / 'responses', use_async=use_async) |
| 386 | + |
| 387 | + if use_async: |
| 388 | + return ast.AsyncFunctionDef( |
| 389 | + name=func_name, |
| 390 | + args=func_args, |
| 391 | + body=func_body, |
| 392 | + decorator_list=[], |
| 393 | + returns=returns_annotation, |
| 394 | + ) |
| 395 | + else: |
| 396 | + return ast.FunctionDef( |
| 397 | + name=func_name, |
| 398 | + args=func_args, |
| 399 | + body=func_body, |
| 400 | + decorator_list=[], |
| 401 | + returns=returns_annotation, |
| 402 | + ) |
0 commit comments