一个能提供更友好的外键关系名称的 sqlacodegen 自定义 generator

效果

一个自定义 generator,用来生成类似如下的效果:将外键关系属性的名称改为“目标表名_for_外键字段名”,并将 back_populates 改为“当前表名_for_外键字段名”。

例子1

将code_mstr表的类似这部分内容

appd_det: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_build_source]', back_populates='code_mstr')
appd_det_: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_conf_repo_type]', back_populates='code_mstr_')
appd_det1: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_deploy_target]', back_populates='code_mstr1')
appd_det2: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_env]', back_populates='code_mstr2')
appd_det3: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_image_repo_type]', back_populates='code_mstr3')
appd_det4: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_kind]', back_populates='code_mstr4')

替换为

appd_det_for_appd_build_source: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_build_source]', back_populates='code_mstr_for_appd_build_source')
appd_det_for_appd_conf_repo_type: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_conf_repo_type]', back_populates='code_mstr_for_appd_conf_repo_type')
appd_det_for_appd_deploy_target: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_deploy_target]', back_populates='code_mstr_for_appd_deploy_target')
appd_det_for_appd_env: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_env]', back_populates='code_mstr_for_appd_env')
appd_det_for_appd_image_repo_type: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_image_repo_type]', back_populates='code_mstr_for_appd_image_repo_type')
appd_det_for_appd_kind: Mapped[List['AppdDet']] = relationship('AppdDet', foreign_keys='[AppdDet.appd_kind]', back_populates='code_mstr_for_appd_kind')

例子2

将appd_det表的类似这部分内容

code_mstr: Mapped[Optional['CodeMstr']] = relationship('CodeMstr', foreign_keys=[appd_build_source], back_populates='appd_det')
code_mstr_: Mapped[Optional['CodeMstr']] = relationship('CodeMstr', foreign_keys=[appd_conf_repo_type], back_populates='appd_det_')
code_mstr1: Mapped[Optional['CodeMstr']] = relationship('CodeMstr', foreign_keys=[appd_deploy_target], back_populates='appd_det1')
code_mstr2: Mapped['CodeMstr'] = relationship('CodeMstr', foreign_keys=[appd_env], back_populates='appd_det2')
code_mstr3: Mapped[Optional['CodeMstr']] = relationship('CodeMstr', foreign_keys=[appd_image_repo_type], back_populates='appd_det3')
code_mstr4: Mapped[Optional['CodeMstr']] = relationship('CodeMstr', foreign_keys=[appd_kind], back_populates='appd_det4')

替换为

code_mstr_for_appd_build_source: Mapped[Optional['CodeMstr']] = relationship('CodeMstr', foreign_keys=[appd_build_source], back_populates='appd_det_for_appd_build_source')
code_mstr_for_appd_conf_repo_type: Mapped[Optional['CodeMstr']] = relationship('CodeMstr', foreign_keys=[appd_conf_repo_type], back_populates='appd_det_for_appd_conf_repo_type')
code_mstr_for_appd_deploy_target: Mapped[Optional['CodeMstr']] = relationship('CodeMstr', foreign_keys=[appd_deploy_target], back_populates='appd_det_for_appd_deploy_target')
code_mstr_for_appd_env: Mapped['CodeMstr'] = relationship('CodeMstr', foreign_keys=[appd_env], back_populates='appd_det_for_appd_env')
code_mstr_for_appd_image_repo_type: Mapped[Optional['CodeMstr']] = relationship('CodeMstr', foreign_keys=[appd_image_repo_type], back_populates='appd_det_for_appd_image_repo_type')
code_mstr_for_appd_kind: Mapped[Optional['CodeMstr']] = relationship('CodeMstr', foreign_keys=[appd_kind], back_populates='appd_det_for_appd_kind')

实现

pyproject.toml

# Register CustomRelationshipGenerator under the name "custom_relationship" so it
# can be selected with `sqlacodegen --generator=custom_relationship`.
[project.entry-points."sqlacodegen.generators"]
custom_relationship = "sqlacodegen.custom_generator:CustomRelationshipGenerator"

src/sqlacodegen/custom_generator.py

from __future__ import annotations

from typing import ClassVar, Sequence, Any

from sqlacodegen.generators import DeclarativeGenerator
from sqlacodegen.models import ModelClass, Model, RelationshipAttribute, RelationshipType
from sqlacodegen.utils import get_column_names, get_common_fk_constraints, get_constraint_sort_key, \
qualified_table_name, render_callable, uses_default_name
from sqlalchemy import Connection, Engine, Index, MetaData
from sqlalchemy import SmallInteger, Integer, BigInteger
from sqlalchemy.dialects.mysql import TINYINT
from sqlalchemy.schema import PrimaryKeyConstraint, UniqueConstraint, DefaultClause, Column, Computed, Identity


class CustomRelationshipGenerator(DeclarativeGenerator):
    """Declarative generator with ``<table>_for_<fk columns>`` relationship names.

    For each foreign key constraint of a mapped table, the forward end
    (many-to-one / one-to-one) is named ``<target table>_for_<fk columns>`` and
    the reverse end (one-to-many / one-to-one) ``<source table>_for_<fk columns>``,
    where ``<fk columns>`` are the constraint's column names joined with ``_``.
    Several foreign keys pointing at the same table therefore get distinguishable
    names instead of numbered ones (``code_mstr``, ``code_mstr_``, ``code_mstr1``...).

    Additionally, quoted integer server defaults are rendered without quotes
    (``text("1")`` instead of ``text("'1'")``), and every index is rendered with
    a ``None`` name.
    """

    valid_options: ClassVar[set[str]] = DeclarativeGenerator.valid_options

    def __init__(
        self,
        metadata: MetaData,
        bind: Connection | Engine,
        options: Sequence[str],
        *,
        indentation: str = "    ",
        base_class_name: str = "Base",
    ):
        # Pure pass-through to the base generator. Generated code is indented with
        # four spaces by default.
        super().__init__(
            metadata, bind, options, indentation=indentation, base_class_name=base_class_name
        )

    @staticmethod
    def _for_name(table_name: str, column_names: Sequence[str]) -> str:
        """Build the ``<table_name>_for_<col1>_<col2>...`` relationship name."""
        return f"{table_name}_for_{'_'.join(column_names)}"

    @staticmethod
    def _unquote_integer_literal(default_text: str) -> str | None:
        """Return the bare integer if *default_text* is a single-quoted integer.

        ``"'1'"`` -> ``"1"`` and ``"'-5'"`` -> ``"-5"``; anything else -> ``None``.
        Exactly one pair of enclosing quotes is removed, and only ASCII digits are
        accepted: ``str.isdigit`` alone also accepts characters such as ``"²"``
        that are not valid SQL integer literals.
        """
        if len(default_text) < 2 or default_text[0] != "'" or default_text[-1] != "'":
            return None
        value = default_text[1:-1]
        digits = value[1:] if value.startswith("-") else value
        if digits.isascii() and digits.isdigit():
            return value
        return None

    def render_column(
        self, column: Column[Any], show_name: bool, is_table: bool = False
    ) -> str:
        """Render a column, emitting quoted integer server defaults without quotes.

        Behaves like the inherited column rendering, except that a
        ``DefaultClause`` whose text is a single-quoted integer (e.g. ``'1'``) on an
        integer-typed column (TINYINT, SMALLINT, INTEGER, BIGINT) is rendered as
        ``server_default=text("1")`` instead of ``server_default=text("'1'")``.
        All other server defaults are rendered verbatim via ``repr``.
        """
        args: list[str] = []
        kwargs: dict[str, Any] = {}
        is_sole_pk = column.primary_key and len(column.table.primary_key) == 1
        # Single-column FKs with default-style names are rendered inline on the column.
        dedicated_fks = [
            c
            for c in column.foreign_keys
            if c.constraint
            and len(c.constraint.columns) == 1
            and uses_default_name(c.constraint)
        ]
        is_unique = any(
            isinstance(c, UniqueConstraint)
            and set(c.columns) == {column}
            and uses_default_name(c)
            for c in column.table.constraints
        )
        is_unique = is_unique or any(
            i.unique and set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )
        is_primary = (
            any(
                isinstance(c, PrimaryKeyConstraint)
                and column.name in c.columns
                and uses_default_name(c)
                for c in column.table.constraints
            )
            or column.primary_key
        )
        has_index = any(
            set(i.columns) == {column} and uses_default_name(i)
            for i in column.table.indexes
        )

        if show_name:
            args.append(repr(column.name))

        # Render the column type if there are no foreign keys on it or any of them
        # points back to itself
        if not dedicated_fks or any(fk.column is column for fk in dedicated_fks):
            args.append(self.render_column_type(column.type))

        for fk in dedicated_fks:
            args.append(self.render_constraint(fk))

        if column.default:
            args.append(repr(column.default))

        if column.key != column.name:
            kwargs["key"] = column.key
        if is_primary:
            kwargs["primary_key"] = True
        if not column.nullable and not is_sole_pk and is_table:
            kwargs["nullable"] = False

        if is_unique:
            column.unique = True
            kwargs["unique"] = True
        if has_index:
            column.index = True
            kwargs["index"] = True

        server_default = column.server_default
        if isinstance(server_default, DefaultClause):
            default_text = server_default.arg.text
            bare_integer = None
            # TINYINT/SmallInteger/BigInteger are listed explicitly for readability.
            if isinstance(column.type, (TINYINT, SmallInteger, Integer, BigInteger)):
                bare_integer = self._unquote_integer_literal(default_text)
            if bare_integer is not None:
                kwargs["server_default"] = render_callable("text", f'"{bare_integer}"')
            else:
                kwargs["server_default"] = render_callable("text", repr(default_text))
        elif isinstance(server_default, Computed):
            expression = str(server_default.sqltext)

            computed_kwargs = {}
            if server_default.persisted is not None:
                computed_kwargs["persisted"] = server_default.persisted

            args.append(
                render_callable("Computed", repr(expression), kwargs=computed_kwargs)
            )
        elif isinstance(server_default, Identity):
            args.append(repr(server_default))
        elif server_default:
            kwargs["server_default"] = repr(server_default)

        comment = getattr(column, "comment", None)
        if comment:
            kwargs["comment"] = repr(comment)

        return self.render_column_callable(is_table, *args, **kwargs)

    def generate_relationships(
        self,
        source: ModelClass,
        models_by_table_name: dict[str, Model],
        association_tables: list[Model],
    ) -> list[RelationshipAttribute]:
        """Create the relationships of *source*, pre-named ``<table>_for_<columns>``.

        For each FK constraint of ``source`` whose target is a mapped class:

        * if the FK columns equal the primary key (and ``nojoined`` is not set),
          ``source`` is attached to ``target`` as a child class instead;
        * otherwise a many-to-one (or one-to-one, when the FK columns match a
          primary key / unique constraint) relationship named
          ``<target table>_for_<fk columns>`` is added to ``source``, and unless
          ``nobidi`` is set, a reverse relationship named
          ``<source table>_for_<fk columns>`` is added to ``target``.

        Many-to-many relationships through ``association_tables`` are named the
        same way. Relationships are appended directly to ``source.relationships``
        and ``target.relationships``; the returned list itself is never filled.
        """
        relationships: list[RelationshipAttribute] = []
        reverse_relationship: RelationshipAttribute | None

        # Many-to-one / one-to-one relationships and their reverse ends
        pk_column_names = {col.name for col in source.table.primary_key.columns}
        for constraint in sorted(
            source.table.foreign_key_constraints, key=get_constraint_sort_key
        ):
            target = models_by_table_name[
                qualified_table_name(constraint.elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            column_names = get_column_names(constraint)

            # An FK that covers exactly the primary key links a child class to its
            # parent class rather than producing a relationship attribute.
            if "nojoined" not in self.options and set(column_names) == pk_column_names:
                source.parent_class = target
                target.children.append(source)
                continue

            # uselist=False (one-to-one) when the FK columns are themselves unique
            if any(
                isinstance(c, (PrimaryKeyConstraint, UniqueConstraint))
                and {col.name for col in c.columns} == set(column_names)
                for c in constraint.table.constraints
            ):
                r_type = RelationshipType.ONE_TO_ONE
            else:
                r_type = RelationshipType.MANY_TO_ONE

            relationship = RelationshipAttribute(r_type, source, target, constraint)
            relationship.name = self._for_name(target.table.name, column_names)
            source.relationships.append(relationship)

            # Self-referential relationships need remote_side set
            if source is target:
                relationship.remote_side = [
                    source.get_column_attribute(col.name)
                    for col in constraint.referred_table.primary_key
                ]

            # With more than one FK between the two tables, SQLAlchemy needs the
            # foreign_keys spelled out to know which columns this relationship uses.
            if len(get_common_fk_constraints(source.table, target.table)) > 1:
                relationship.foreign_keys = [
                    source.get_column_attribute(key) for key in constraint.column_keys
                ]

            if "nobidi" in self.options:
                continue

            # Opposite end of the relationship, in the target class
            if r_type is RelationshipType.MANY_TO_ONE:
                r_type = RelationshipType.ONE_TO_MANY

            reverse_relationship = RelationshipAttribute(
                r_type,
                target,
                source,
                constraint,
                foreign_keys=relationship.foreign_keys,
                backref=relationship,
            )
            # NOTE(review): for a self-referential FK (source is target) this is the
            # same string as relationship.name; de-duplication is left to
            # find_free_name in generate_relationship_name.
            reverse_relationship.name = self._for_name(source.table.name, column_names)
            relationship.backref = reverse_relationship
            target.relationships.append(reverse_relationship)

            # Self-referential: the reverse end's remote side is the local FK columns
            if source is target:
                reverse_relationship.remote_side = [
                    source.get_column_attribute(colname)
                    for colname in constraint.column_keys
                ]

        # Many-to-many relationships through association tables
        for association_table in association_tables:
            fk_constraints = sorted(
                association_table.table.foreign_key_constraints,
                key=get_constraint_sort_key,
            )
            target = models_by_table_name[
                qualified_table_name(fk_constraints[1].elements[0].column.table)
            ]
            if not isinstance(target, ModelClass):
                continue

            relationship = RelationshipAttribute(
                RelationshipType.MANY_TO_MANY,
                source,
                target,
                fk_constraints[1],
                association_table,
            )
            # Suffix: columns of fk_constraints[0], the constraint the reverse end uses
            relationship.name = self._for_name(
                target.table.name, get_column_names(fk_constraints[0])
            )
            source.relationships.append(relationship)

            reverse_relationship = None
            if "nobidi" not in self.options:
                reverse_relationship = RelationshipAttribute(
                    RelationshipType.MANY_TO_MANY,
                    target,
                    source,
                    fk_constraints[0],
                    association_table,
                    relationship,
                )
                # Suffix: columns of fk_constraints[1], the forward end's constraint
                reverse_relationship.name = self._for_name(
                    source.table.name, get_column_names(fk_constraints[1])
                )
                relationship.backref = reverse_relationship
                target.relationships.append(reverse_relationship)

            if source is target:
                self._add_self_referential_joins(relationship, reverse_relationship)

        return relationships

    @staticmethod
    def _add_self_referential_joins(
        relationship: RelationshipAttribute,
        reverse_relationship: RelationshipAttribute | None,
    ) -> None:
        """Set primaryjoin/secondaryjoin on a self-referential many-to-many pair.

        The forward end uses the association table's FK constraints in sort order;
        the reverse end uses them in reversed order.
        """
        both_relationships = [relationship]
        if reverse_relationship:
            both_relationships.append(reverse_relationship)

        for rel, reverse in zip(both_relationships, (False, True)):
            if not rel.association_table or not rel.constraint:
                continue

            constraints = sorted(
                rel.constraint.table.foreign_key_constraints,
                key=get_constraint_sort_key,
                reverse=reverse,
            )
            pri_pairs = zip(get_column_names(constraints[0]), constraints[0].elements)
            sec_pairs = zip(get_column_names(constraints[1]), constraints[1].elements)
            rel.primaryjoin = [
                (rel.source, elem.column.name, rel.association_table, col)
                for col, elem in pri_pairs
            ]
            rel.secondaryjoin = [
                (rel.target, elem.column.name, rel.association_table, col)
                for col, elem in sec_pairs
            ]

    def generate_relationship_name(
        self,
        relationship: RelationshipAttribute,
        global_names: set[str],
        local_names: set[str],
    ) -> None:
        """Finalize a relationship's attribute name.

        Names already assigned by ``generate_relationships`` (anything containing
        ``_for_``) are kept and only passed through ``find_free_name`` to avoid
        clashes. Any other relationship falls back to:

        * ``<backref name>_reverse`` for a self-referential relationship whose
          backref is already named;
        * otherwise ``<target table>_for_<fk columns>`` (or just the target table
          name when no FK columns apply), optionally pluralized/singularized when
          the ``use_inflect`` option is set.

        ``use_inflect`` therefore has no effect on ``_for_`` names.
        """
        if relationship.name and "_for_" in relationship.name:
            relationship.name = self.find_free_name(
                relationship.name, global_names, local_names
            )
            return

        if (
            relationship.source is relationship.target
            and relationship.backref
            and relationship.backref.name
        ):
            # Self-referential: derive the name from the already-named opposite end
            preferred_name = relationship.backref.name + "_reverse"
        else:
            preferred_name = relationship.target.table.name

            constraint = relationship.constraint
            if constraint:
                # Only use FK columns when this end owns the constraint, or when the
                # relationship is not a one-to-one / one-to-many reverse end.
                is_source = relationship.source.table is constraint.table
                if is_source or relationship.type not in (
                    RelationshipType.ONE_TO_ONE,
                    RelationshipType.ONE_TO_MANY,
                ):
                    column_names = [c.name for c in constraint.columns]
                    if column_names:
                        preferred_name = self._for_name(preferred_name, column_names)

            if "use_inflect" in self.options:
                if relationship.type in (
                    RelationshipType.ONE_TO_MANY,
                    RelationshipType.MANY_TO_MANY,
                ):
                    inflected_name = self.inflect_engine.plural_noun(preferred_name)
                else:
                    inflected_name = self.inflect_engine.singular_noun(preferred_name)
                if inflected_name:
                    preferred_name = inflected_name

        relationship.name = self.find_free_name(
            preferred_name, global_names, local_names
        )

    def render_index(self, index: Index) -> str:
        """Render an ``Index(...)`` call with ``None`` in place of the index name.

        Column names are passed positionally, and ``unique=True`` is added for
        unique indexes. The reflected index name is deliberately dropped.
        """
        column_args = [repr(col.name) for col in index.columns]
        kwargs = {"unique": True} if index.unique else {}
        return render_callable("Index", "None", *column_args, kwargs=kwargs)

执行命令

# NOTE(review): the connection URL below embeds a literal database user and password;
# substitute your own user/password/host/database (e.g. from an environment variable)
# rather than committing real credentials.
uv run sqlacodegen --generator=custom_relationship mysql+pymysql://root:123654qwe@127.0.0.1:3306/cmdb_design

一个能提供更友好的外键关系名称的 sqlacodegen 自定义 generator
https://itxiaopang.github.io/p/9e351bb3a0e342d6b8c90991875e5034/
作者
挨踢小胖
发布于
2025年3月7日
许可协议