-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathgenerate_drizzle.py
More file actions
218 lines (185 loc) · 9.62 KB
/
generate_drizzle.py
File metadata and controls
218 lines (185 loc) · 9.62 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
import sys
import argparse
import xml.etree.ElementTree as ET
import re
from collections import OrderedDict
def pg_type_to_drizzle_type(pg_type):
    """Map a PostgreSQL column type to a Drizzle ORM column builder name.

    Args:
        pg_type: PostgreSQL type string from the XML schema (case-insensitive),
            e.g. 'UUID', 'VARCHAR(255)', 'TIMESTAMP WITH TIME ZONE'.

    Returns:
        Either the Drizzle builder name as a plain string (e.g. 'text'), or a
        ('name', options) tuple when the builder needs extra options
        (timestamp time zone, decimal precision/scale). Unrecognized types
        fall back to 'text'.
    """
    pg_type = pg_type.upper()
    if pg_type == 'UUID':
        return 'uuid'
    elif pg_type == 'SERIAL':
        return 'serial'
    elif pg_type.startswith('VARCHAR') or pg_type == 'TEXT':
        return 'text'
    # Covers INT, INTEGER, INT2/INT4/INT8 — but exclude INTERVAL, which also
    # starts with 'INT' and was previously mis-mapped to integer.
    elif pg_type.startswith('INT') and not pg_type.startswith('INTERVAL'):
        return 'integer'
    elif pg_type == 'BOOLEAN':
        return 'boolean'
    elif pg_type.startswith('TIMESTAMP'):
        if 'WITH TIME ZONE' in pg_type:
            return 'timestamp', {'withTimezone': True}
        return 'timestamp'
    elif pg_type.startswith('DATE'):
        return 'date'
    elif pg_type == 'JSON':
        return 'json'
    elif pg_type == 'JSONB':
        return 'jsonb'
    elif pg_type.startswith('DECIMAL') or pg_type.startswith('NUMERIC'):
        # Accept both DECIMAL(p, s) and DECIMAL(p); the scale group is
        # optional (the old regex required both and silently dropped the
        # precision for DECIMAL(p)). Precision/scale stay as strings since
        # they are interpolated directly into the generated source.
        match = re.match(r'(\w+)\((\d+)(?:,\s*(\d+))?\)', pg_type)
        if match:
            precision = match.group(2)
            scale = match.group(3)
            if scale is not None:
                return 'decimal', {'precision': precision, 'scale': scale}
            return 'decimal', {'precision': precision}
        return 'decimal'
    # Add more type mappings as needed
    return 'text'  # Default fallback
def _format_type_options(options):
    """Render a Drizzle options dict as a JS object literal, e.g. '{ withTimezone: true }'."""
    rendered = ', '.join(
        f"{key}: {str(value).lower() if isinstance(value, bool) else value}"
        for key, value in options.items()
    )
    return f"{{ {rendered} }}"


def _column_definition(child, table_name):
    """Build one Drizzle column line from an <addColumn> XML element.

    Args:
        child: the <addColumn> element.
        table_name: enclosing table name, used only for warning messages.

    Returns:
        (col_def, col_name, drizzle_type, is_primary_key), or None when the
        element is missing required attributes (a warning goes to stderr).
    """
    col_name = child.attrib.get('name')
    col_type = child.attrib.get('type')
    if not col_name or not col_type:
        sys.stderr.write(f"Warning: <addColumn> missing name or type in table {table_name}.\n")
        return None
    # Convert PostgreSQL type to Drizzle type; a tuple result carries
    # builder options (e.g. timestamp time zone, decimal precision).
    mapped = pg_type_to_drizzle_type(col_type)
    if isinstance(mapped, tuple):
        drizzle_type, options = mapped
        type_options = _format_type_options(options)
    else:
        drizzle_type = mapped
        type_options = ""
    col_def = f" {col_name}: {drizzle_type}('{col_name}'"
    if type_options:
        col_def += f", {type_options}"
    col_def += ")"
    # Nullability: only nullable="false" adds .notNull(); default is nullable.
    if child.attrib.get('nullable', 'true').lower() == 'false':
        col_def += ".notNull()"
    default_val = child.attrib.get('default')
    if default_val is not None:
        if default_val.startswith("sql`") and default_val.endswith("`"):
            # Pass an explicit sql`...` default through verbatim.
            col_def += f".default({default_val})"
        elif default_val == "now()":
            col_def += ".defaultNow()"
        elif default_val == "uuid_generate_v4()":
            col_def += ".default(sql`uuid_generate_v4()`)"
        else:
            # For text-like types, quote the default unless already quoted.
            if drizzle_type in ['text', 'varchar']:
                if not (default_val.startswith("'") or default_val.startswith('"')):
                    default_val = f"'{default_val}'"
            col_def += f".default({default_val})"
    is_primary_key = child.attrib.get('primaryKey', 'false').lower() == 'true'
    if is_primary_key:
        col_def += ".primaryKey()"
    col_def += ","
    return col_def, col_name, drizzle_type, is_primary_key


def _history_table_definition(table_name, primary_key_field):
    """Build the companion history_<table> definition used for audit logging.

    Args:
        table_name: name of the base table.
        primary_key_field: (column_name, drizzle_type) of the base table's
            primary key, or None when the table has no primary key column.

    Returns:
        (history_table_name, table_definition_source).
    """
    hist_table_name = f"history_{table_name}"
    hist_columns = [" historyid: serial('historyid').primaryKey(),"]
    if primary_key_field:
        _pk_name, pk_type = primary_key_field
        # A serial key is stored as a plain integer in the history copy.
        pk_column_type = 'integer' if pk_type == 'serial' else pk_type
        hist_columns.append(f" primarykey: {pk_column_type}('primarykey'),")
    # Standard audit columns: when, what operation, and a JSON snapshot.
    hist_columns.append(" changed_at: timestamp('changed_at', { withTimezone: true }).defaultNow(),")
    hist_columns.append(" operation: text('operation'),")
    hist_columns.append(" historyjson: jsonb('historyjson'),")
    hist_table_def = f"export const {hist_table_name} = pgTable('{hist_table_name}', {{\n"
    hist_table_def += "\n".join(hist_columns)
    hist_table_def += "\n});\n"
    return hist_table_name, hist_table_def


def generate_drizzle_schema(xml_file):
    """Generate Drizzle ORM schema source text from an XML schema file.

    Args:
        xml_file: path to the XML schema; <addTable> elements (with nested
            <addColumn> children) become pgTable definitions. Tables with
            history="true|yes|1" also get a history_<table> audit table.

    Returns:
        The complete TypeScript schema module as a single string.

    Exits:
        Via sys.exit with an error message when the XML cannot be parsed.
    """
    try:
        tree = ET.parse(xml_file)
    except Exception as e:
        sys.exit(f"Error parsing XML file: {e}")
    root = tree.getroot()
    # Imports for the generated module; `sql` is needed for raw defaults.
    drizzle_imports = (
        "import { pgTable, serial, uuid, text, integer, boolean, "
        "timestamp, date, json, jsonb, decimal, primaryKey, foreignKey } "
        "from 'drizzle-orm/pg-core';\n"
        "import { sql } from 'drizzle-orm';\n\n"
    )
    table_definitions = []
    for command in root:
        if command.tag != 'addTable':
            continue
        table_name = command.attrib.get('name')
        if not table_name:
            sys.stderr.write("Warning: <addTable> without a name attribute.\n")
            continue
        # History logging flag for this table.
        is_history = command.attrib.get('history', 'false').lower() in ['true', 'yes', '1']
        columns = []
        primary_key_field = None  # (columnName, drizzleType) of first PK
        for child in command:
            # NOTE: foreign-key support (<addForeignKey>) is intentionally
            # not implemented; only <addColumn> children are processed.
            if child.tag != 'addColumn':
                continue
            parsed = _column_definition(child, table_name)
            if parsed is None:
                continue
            col_def, col_name, drizzle_type, is_primary_key = parsed
            if is_primary_key and primary_key_field is None:
                primary_key_field = (col_name, drizzle_type)
            columns.append(col_def)
        table_def = f"export const {table_name} = pgTable('{table_name}', {{\n"
        table_def += "\n".join(columns)
        table_def += "\n});\n"
        table_definitions.append(table_def)
        if is_history:
            _hist_name, hist_table_def = _history_table_definition(table_name, primary_key_field)
            table_definitions.append(hist_table_def)
    return drizzle_imports + "\n".join(table_definitions) + "\n"
def save_to_file(content, output_file):
    """Save generated Drizzle schema text to a file.

    Args:
        content: the schema source text to write.
        output_file: destination path.

    Exits:
        Via sys.exit with an error message when the file cannot be written.
    """
    try:
        # Write UTF-8 explicitly: the generated TypeScript must not depend
        # on the platform's default encoding.
        with open(output_file, 'w', encoding='utf-8') as f:
            f.write(content)
        print(f"Drizzle schema successfully written to {output_file}")
    except Exception as e:
        sys.exit(f"Error writing to output file: {e}")
if __name__ == "__main__":
    # CLI entry point: read the XML schema and emit the Drizzle schema file.
    arg_parser = argparse.ArgumentParser(description='Generate Drizzle ORM schema from XML definition.')
    arg_parser.add_argument('input_file', help='Path to the input XML schema file')
    arg_parser.add_argument('-o', '--output', required=True, help='Path to the output Drizzle schema file')
    cli_args = arg_parser.parse_args()
    schema_text = generate_drizzle_schema(cli_args.input_file)
    # Write the generated schema to the requested location.
    save_to_file(schema_text, cli_args.output)