-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathwrapper.py
More file actions
292 lines (253 loc) · 10.1 KB
/
wrapper.py
File metadata and controls
292 lines (253 loc) · 10.1 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
import ast
import inspect
from pathlib import Path
from typing import Any, Optional, List, Dict, Union
import textwrap
# Marker word identifying functions to expose: the line above such a function
# must read "# expose_as [get|post]" (parsed by extract_interface_functions).
TAG: str = "expose_as"
class InterfaceFunction:
    """
    Wrap a tagged ``ast.FunctionDef`` and generate the FastAPI source code
    (pydantic models plus endpoint) that exposes it over HTTP.

    The class attributes ``generated_error_model`` and ``preamble_generated``
    are process-wide flags: once set, the shared error model / file preamble
    is not emitted again.
    """

    # True once the shared error model has been emitted by generate_IO_models.
    generated_error_model = False
    # True once build_fastapi_code has emitted the shared imports and app.
    preamble_generated = False
    # Name of the pydantic model returned when the wrapped function raises.
    error_model_name = 'ErrorOutput'
    # Parameter that generated endpoints always force to False.
    mod_parameter = 'use_filesystem'
    # Note appended to the docstring of endpoints that take mod_parameter.
    mod_msg = """Please note that the server filesystem is never used when calling these APIs
(parameter use_filesystem is always set to False).
"""
    # Statement injected into endpoints that take mod_parameter.
    mod_code = f"{mod_parameter} = False # No server filesystem access via API"

    def __init__(self, node: ast.FunctionDef, method: str, file_to_read: str):
        """
        Args:
            node: parsed definition of the function to expose.
            method: HTTP method taken from the tag comment (``get``/``post``);
                matched case-insensitively.
            file_to_read: name of the module defining the function; used both
                in the generated ``import`` and in the generated call.
        """
        self.node = node
        self.name = node.name
        self.args = node.args.args
        # Normalise so '# expose_as POST' behaves like '# expose_as post';
        # all comparisons below use lowercase names.
        self.method = method.lower()
        self.file_to_read = file_to_read

    def get_docstring(self, arg_names):
        """
        Return the wrapped function's docstring as a Python string literal.

        If ``use_filesystem`` is among ``arg_names`` a note that it is forced
        to False is appended. ``repr`` is used so the result is a complete,
        correctly escaped literal that can be pasted into generated code.
        """
        docstr = ast.get_docstring(self.node) or ""
        if InterfaceFunction.mod_parameter in arg_names:
            docstr = f"{docstr}\n{InterfaceFunction.mod_msg}\n"
        return repr(docstr)

    @staticmethod
    def resolve_type(t):
        """
        Return the source text of annotation ``t`` for use in generated code.

        A missing annotation becomes ``Any``; bare ``dict``/``list`` are made
        explicit as ``Dict[str, Any]``/``List[Any]``. Anything else, including
        already-parameterised generics, is returned unchanged.
        """
        if t is None:
            return "Any"
        t_str = ast.unparse(t)
        return {"dict": "Dict[str, Any]", "list": "List[Any]"}.get(t_str, t_str)

    def _get_args(self):
        """
        Return ``(name, annotation_source, default_source_or_None)`` for each
        positional parameter of the wrapped function.

        ``ast.arguments.defaults`` lists the defaults of the trailing
        parameters only, hence the offset when indexing it.
        """
        defaults = self.node.args.defaults
        first_with_default = len(self.args) - len(defaults)
        result = []
        for i, arg in enumerate(self.args):
            idx = i - first_with_default
            default = ast.unparse(defaults[idx]) if idx >= 0 else None
            result.append((arg.arg, self.resolve_type(arg.annotation), default))
        return result

    def generate_IO_models(self):
        """
        Return the pydantic model source code for this function.

        Emits, in order:
        * the shared error model (``error: str``), only the first time any
          instance calls this method;
        * for POST, ``<Name>Input`` with one field per parameter (``pass``
          when there are none, because the output model still refers to it);
        * ``<Name>Output`` whose ``result`` is ``<Name>Input`` for POST (the
          possibly modified arguments) or the return annotation otherwise.

        Raises:
            ValueError: if a POST function's return annotation is not ``None``
                (a missing annotation counts as ``Any`` and is rejected too),
                or a non-POST function is annotated ``-> None``.
        """
        model_name = self.name.title()
        returns = self.resolve_type(self.node.returns)
        # Validate before touching the class-level flag so a rejected function
        # cannot cause the shared error model to be skipped.
        if self.method == "post" and returns != 'None':
            raise ValueError(f'{self.name}: only no return allowed for original functions exposed as post')
        if self.method != "post" and returns == 'None':
            raise ValueError(f'{self.name}: must return something if exposed as get')

        code = []
        if not InterfaceFunction.generated_error_model:
            code.append(f"class {InterfaceFunction.error_model_name}(BaseModel):")
            code.append("    error: str")
            InterfaceFunction.generated_error_model = True

        if self.method == "post":
            # 'self' is never part of the endpoint signature (see generate_API).
            args = [a for a in self._get_args() if a[0] != "self"]
            code.append(f"class {model_name}Input(BaseModel):")
            for arg_name, arg_type, arg_default in args:
                default = f" = {arg_default}" if arg_default is not None else ""
                code.append(f"    {arg_name}: {arg_type}{default}")
            if not args:
                code.append("    pass")
            code.append("\n")
            result_type = f"{model_name}Input"
        else:
            result_type = returns
        code.append(f"class {model_name}Output(BaseModel):")
        code.append(f"    result: {result_type}\n")
        return "\n".join(code)

    def generate_API(self) -> str:
        """
        Return the source code of the FastAPI endpoint for this function.

        POST: the endpoint takes a ``<Name>Input`` payload, deep-copies each
        argument (the wrapped function may modify it in place), calls the
        function and returns the arguments; a dict argument that ``is_diff``
        reports as unchanged is returned as ``{}``.
        Other methods (GET): parameters become query parameters and the
        function's return value is returned as ``result``.
        In both cases ``use_filesystem`` (if present) is forced to False, and
        an exception from the wrapped function is returned as
        ``{"error": str(e)}``. The code is built line by line starting at
        column 0.
        """
        model_name = self.name.title()
        mod_parameter = InterfaceFunction.mod_parameter
        decorator = (
            f"@app.{self.method}('/{self.name}', "
            f"response_model={model_name}Output|{InterfaceFunction.error_model_name})"
        )
        arg_names = [arg.arg for arg in self.args if arg.arg != 'self']
        call = f"{self.file_to_read}.{self.name}({', '.join(arg_names)})"
        # Applied to GET as well: the docstring note promises it for every endpoint.
        override = [InterfaceFunction.mod_code] if mod_parameter in arg_names else []

        if self.method == 'post':
            fn_params = f'payload: {model_name}Input' if arg_names else ''
            return_keys = ", ".join(
                f'"{arg}": {arg} if type({arg}) != dict or is_diff({arg},payload.{arg}) else {{}}'
                for arg in arg_names
            )
            # Deep copy: the wrapped function may mutate its arguments in place,
            # and the originals in payload are needed for the is_diff comparison.
            body = [
                f"{arg} = copy.deepcopy(payload.{arg})"
                for arg in arg_names
                if arg != mod_parameter
            ]
            body += override
            body += [
                "try:",
                f"    {call}",
                "except Exception as e:",
                '    return {"error": str(e)}',
                f'return {{"result": {{ {return_keys} }} }}',
            ]
        else:
            fn_params = ", ".join(
                f"{arg}: {typ}" + (f" = {default}" if default is not None else "")
                for arg, typ, default in self._get_args()
                if arg != "self"
            )
            body = ["result = None"] + override + [
                "try:",
                f"    result = {call}",
                "except Exception as e:",
                '    return {"error": str(e)}',
                'return {"result": result}',
            ]

        lines = [
            "",
            decorator,
            f"def {self.name}_endpoint({fn_params}):",
            # get_docstring already returns a complete, escaped string literal.
            f"    {self.get_docstring(arg_names)}",
        ]
        lines += [f"    {line}" for line in body]
        lines.append("")
        return "\n".join(lines)
def extract_interface_functions(file_to_read):
    """
    Parse a Python source file and return the functions tagged for exposure.

    A function is tagged when the line directly above its ``def`` (or above
    its first decorator) is a comment ``# expose_as [method]``; the method
    defaults to ``post`` when omitted and is lowercased.

    Args:
        file_to_read: ``pathlib.Path`` of the source file; its stem is used as
            the module name in the generated code.

    Returns:
        list of InterfaceFunction, in ``ast.walk`` order.
    """
    source_code = file_to_read.read_text(encoding="utf-8")
    tree = ast.parse(source_code)
    lines = source_code.splitlines()
    functions = []
    for node in ast.walk(tree):
        if not isinstance(node, ast.FunctionDef):
            continue
        # Candidate tag lines (1-based lineno -> 0-based index of the line
        # above): above the 'def' itself, and above the first decorator.
        first_line = min([node.lineno] + [d.lineno for d in node.decorator_list])
        candidates = dict.fromkeys((node.lineno - 2, first_line - 2))
        for index in candidates:
            # A negative index would silently read the end of the file.
            if index < 0:
                continue
            parts = lines[index].split()
            # Exact token match: '# expose_as_internal' must not qualify.
            if parts[:2] == ["#", TAG]:
                method = parts[2].lower() if len(parts) > 2 else "post"
                functions.append(InterfaceFunction(node, method, file_to_read.stem))
                break
    return functions
# Source of the helper emitted once into the generated module. POST endpoints
# call it to decide whether a dict argument was modified by the wrapped call.
IS_DIFF_SOURCE = '''
def is_diff(dict1:dict, dict2:dict)-> bool:
    # A key added or removed in place is a difference too.
    if dict1.keys() != dict2.keys():
        return True
    for key in dict1:
        if type(dict1[key]) == dict and type(dict2[key]) == dict:
            if is_diff(dict1[key],dict2[key]):
                return True
        elif dict1[key] != dict2[key]:
            return True
    return False
'''


def build_fastapi_code(functions_in_file):
    """
    Return the FastAPI source for all tagged functions of one input module.

    The first call in the process also emits the shared preamble (framework
    imports, the ``is_diff`` helper and ``app = FastAPI()``). Later calls only
    add ``import <module>`` plus models and endpoints, so the outputs of
    successive calls can be concatenated into a single file.

    Args:
        functions_in_file: InterfaceFunction objects from a single module.

    Returns:
        The generated code, or ``""`` when the list is empty.
    """
    if not functions_in_file:
        return ""
    first_call = not InterfaceFunction.preamble_generated
    code = []
    if first_call:
        code += [
            "from fastapi import FastAPI",
            "from pydantic import BaseModel, Field",
            "from typing import Any, Optional, List, Dict, Union, Tuple",
            "import copy",
        ]
    # Every function in the list comes from the same module.
    code.append(f"import {functions_in_file[0].file_to_read}")
    if first_call:
        code += [IS_DIFF_SOURCE, "app = FastAPI()", ""]
        InterfaceFunction.preamble_generated = True
    for func in functions_in_file:
        code.append(func.generate_IO_models())
        code.append(textwrap.dedent(func.generate_API()))
    return "\n".join(code)
def main(inputs: list[str], output: str):
    """
    Generate the FastAPI wrapper for every tagged function in ``inputs`` and
    write it to ``output``.

    Args:
        inputs: paths of the Python source files to scan.
        output: path of the file to write. Its file name (not the directory)
            must contain ``wrapped``, as a guard against overwriting
            unrelated files.

    Raises:
        ValueError: if the output file name does not contain ``wrapped``.
        FileNotFoundError: if an input path is not an existing file.
    """
    wrapped_file = Path(output)
    # Check only the file name: a directory such as 'wrapped_dir/' must not
    # make any file inside it eligible for overwriting.
    if 'wrapped' not in wrapped_file.name:
        raise ValueError(f"Refuse to possibly overwrite file {output}")
    fastapi_code = []
    for input_path in inputs:
        file_to_read = Path(input_path)
        if not file_to_read.is_file():
            raise FileNotFoundError(f"Cannot read file {input_path}")
        functions = extract_interface_functions(file_to_read)
        if not functions:
            print(f"Warning: file {file_to_read} does not contain any function to extract")
            continue
        fastapi_code.append(build_fastapi_code(functions))
    if not fastapi_code:
        print(f"Warning: no function to expose, {output} not written")
        return
    wrapped_file.write_text("\n".join(fastapi_code), encoding="utf-8")
    print(f"✅ FastAPI wrapper generated in {output}")
if __name__ == "__main__":
    # Command-line entry point: parse -i/-o and delegate to main().
    import argparse
    import sys

    parser = argparse.ArgumentParser(
        description=__doc__,
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument(
        '-i', '--input',
        dest='inputs',
        required=True,
        nargs="+",
        help='specifies the name of the files containing the functions',
    )
    parser.add_argument(
        '-o', '--output',
        dest='output',
        required=True,
        help='specifies the name of the file to write',
    )
    args, unknown = parser.parse_known_args()
    if unknown:
        print(f'Unknown options {unknown}', file=sys.stderr)
        parser.print_help(sys.stderr)
        # sys.exit rather than the site-provided exit(), which is not
        # guaranteed to exist (e.g. under ``python -S``).
        sys.exit(-1)
    main(args.inputs, args.output)
# Reference notes: functions of each input module (presumably the ones
# tagged with "# expose_as"):
# if_lib:
#   generate_random_challenge, read_HMAC, read_keypair, get_id_person, get_location_id,
#   get_unit_id, get_resource_spec_id, get_resource, get_process, create_event, make_transfer, reduce_resource, set_user_location
# if_dpp:
#   trace_query, check_traces, er_before, get_dpp