Skip to content

Commit 90f832d

Browse files
committed
Add foreign keys to create table API
- Add fk_table and optional fk_column support to create-table columns.
- Validate create-table requests with Pydantic while preserving existing errors.
- Document the API and cover inferred primary-key and validation cases.

Refs #2789 (comment)
1 parent 7822dc3 commit 90f832d

3 files changed

Lines changed: 321 additions & 106 deletions

File tree

datasette/views/table_create_alter.py

Lines changed: 181 additions & 106 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,15 @@
22
import re
33
from typing import Annotated, Any, Literal, Union
44

5-
from pydantic import BaseModel, ConfigDict, Field, ValidationError, model_validator
5+
from pydantic import (
6+
BaseModel,
7+
ConfigDict,
8+
Field,
9+
ValidationError,
10+
field_validator,
11+
model_validator,
12+
)
13+
from pydantic_core import PydanticCustomError
614
import sqlite_utils
715
from sqlite_utils.db import DEFAULT as SQLITE_UTILS_DEFAULT
816

@@ -25,6 +33,7 @@
2533
sqlite_type: column_type
2634
for column_type, sqlite_type in CREATE_TABLE_SQLITE_TYPES.items()
2735
}
36+
TABLE_NAME_RE = re.compile(r"^(?!sqlite_)[^\n]+$")
2837
ALTER_TABLE_COLUMN_TYPES = CREATE_TABLE_COLUMN_TYPES
2938
ALTER_TABLE_TYPE_FOR_SQLITE_TYPE = {
3039
SQLiteType.TEXT: "text",
@@ -98,6 +107,137 @@ class _StrictPydanticModel(BaseModel):
98107
model_config = ConfigDict(extra="forbid")
99108

100109

110+
class CreateTableColumn(BaseModel):
    """One entry of the ``columns`` list in a create-table request.

    ``name`` and ``type`` are declared as ``Any`` so that the model
    validator below, not Pydantic's built-in type checks, decides which
    error message is produced for them.
    """

    model_config = ConfigDict(extra="forbid")

    name: Any = None
    type: Any = "text"
    # Table referenced by this column's foreign key, if any.
    fk_table: str | None = None
    # Referenced column in fk_table; only valid together with fk_table.
    fk_column: str | None = None

    @model_validator(mode="after")
    def validate_column(self):
        """Check the column name, type and foreign-key options.

        A falsy ``type`` is replaced with ``"text"``.

        Raises:
            PydanticCustomError: with type ``"create_table"`` for a missing
                or non-string name or an unsupported type, and with type
                ``"create_table_with_location"`` when ``fk_column`` is given
                without ``fk_table``.
        """
        if not self.name or not isinstance(self.name, str):
            raise PydanticCustomError("create_table", "Column name is required")
        if not self.type:
            self.type = "text"
        elif (
            # Reject non-strings before the membership test: an unhashable
            # JSON value (list/object) would otherwise raise TypeError if the
            # supported-types collection is hash-based, and TypeError is not
            # turned into a validation error.
            not isinstance(self.type, str)
            or self.type not in CREATE_TABLE_COLUMN_TYPES
        ):
            raise PydanticCustomError(
                "create_table", "Unsupported column type: {type}", {"type": self.type}
            )
        if self.fk_column and not self.fk_table:
            raise PydanticCustomError(
                "create_table_with_location",
                "fk_column requires fk_table",
            )
        return self
134+
135+
136+
class CreateTableRequest(_StrictPydanticModel):
    """Validated JSON body of a create-table request.

    Most fields are declared as ``Any`` and checked by hand in
    ``validate_request`` so that the API's own error messages are used
    instead of Pydantic's generic ones.
    """

    table: Any = None
    rows: Any = None
    row: Any = None
    columns: list[CreateTableColumn] | None = None
    pk: Any = None
    pks: Any = None
    ignore: bool | None = None
    replace: bool | None = None
    alter: bool | None = None

    @field_validator("columns", mode="before")
    @classmethod
    def validate_columns_list(cls, value):
        """Ensure ``columns`` is a list of objects before items are parsed.

        Raises:
            PydanticCustomError: if ``value`` is not a list, or any item in
                it is not a dict.
        """
        if value is None:
            return value
        if not isinstance(value, list):
            raise PydanticCustomError("create_table", "columns must be a list")
        if not all(isinstance(column, dict) for column in value):
            raise PydanticCustomError(
                "create_table", "columns must be a list of objects"
            )
        return value

    @model_validator(mode="after")
    def validate_request(self):
        """Apply the cross-field rules of the create-table API.

        Raises:
            PydanticCustomError: with type ``"create_table"`` for the first
                rule that is violated.
        """
        if not self.table:
            raise PydanticCustomError("create_table", "Table is required")
        # fullmatch, not match: "$" also matches just before a trailing
        # newline, so match() would accept a name such as "t\n".
        if not isinstance(self.table, str) or not TABLE_NAME_RE.fullmatch(
            self.table
        ):
            raise PydanticCustomError("create_table", "Invalid table name")
        if not self.columns and not self.rows and not self.row:
            raise PydanticCustomError(
                "create_table", "columns, rows or row is required"
            )
        if self.rows and self.row:
            raise PydanticCustomError(
                "create_table", "Cannot specify both rows and row"
            )
        if self.columns and (self.rows or self.row):
            raise PydanticCustomError(
                "create_table", "Cannot specify columns with rows or row"
            )
        if self.columns is not None:
            # Report each repeated name once, in order of first repetition.
            seen = set()
            duplicates = []
            for column in self.columns:
                if column.name in seen and column.name not in duplicates:
                    duplicates.append(column.name)
                seen.add(column.name)
            if duplicates:
                raise PydanticCustomError(
                    "create_table",
                    "Duplicate column name: {names}",
                    {"names": ", ".join(duplicates)},
                )
        if self.rows is not None:
            if not isinstance(self.rows, list):
                raise PydanticCustomError("create_table", "rows must be a list")
            if not all(isinstance(row, dict) for row in self.rows):
                raise PydanticCustomError(
                    "create_table", "rows must be a list of objects"
                )
        # A single ``row`` is inserted as a one-item ``rows`` list (see
        # rows_list), so it must be an object too; same message as for rows.
        if self.row and not isinstance(self.row, dict):
            raise PydanticCustomError(
                "create_table", "rows must be a list of objects"
            )
        if self.pk is not None and not isinstance(self.pk, str):
            raise PydanticCustomError("create_table", "pk must be a string")
        if self.pk and self.pks:
            raise PydanticCustomError("create_table", "Cannot specify both pk and pks")
        if self.pks is not None:
            if not isinstance(self.pks, list):
                raise PydanticCustomError("create_table", "pks must be a list")
            if not all(isinstance(pk, str) for pk in self.pks):
                raise PydanticCustomError(
                    "create_table", "pks must be a list of strings"
                )
        if self.ignore and self.replace:
            raise PydanticCustomError(
                "create_table", "ignore and replace are mutually exclusive"
            )
        # model_fields_set holds the keys the client actually sent, so these
        # rules apply even when ignore/replace were sent as false.
        if {"ignore", "replace"} & self.model_fields_set:
            if not self.row and not self.rows:
                raise PydanticCustomError(
                    "create_table", "ignore and replace require row or rows"
                )
            if not self.pk and not self.pks:
                raise PydanticCustomError(
                    "create_table", "ignore and replace require pk or pks"
                )
        return self

    @property
    def rows_list(self):
        """Return the rows to insert: ``[row]`` if ``row`` is set, else ``rows``."""
        return [self.row] if self.row else self.rows

    @property
    def foreign_keys(self):
        """Return foreign-key tuples built from ``columns``, or ``None``.

        Each entry is ``(column, fk_table, fk_column)``, or
        ``(column, fk_table)`` when no ``fk_column`` was given. ``None`` is
        returned when there are no columns or none declares a foreign key.
        """
        if not self.columns:
            return None
        foreign_keys = []
        for column in self.columns:
            if column.fk_table and column.fk_column:
                foreign_keys.append((column.name, column.fk_table, column.fk_column))
            elif column.fk_table:
                # Two-item form: referenced column left unspecified
                # (presumably resolved by sqlite-utils - TODO confirm).
                foreign_keys.append((column.name, column.fk_table))
        return foreign_keys or None
239+
240+
101241
class _DefaultArgsMixin(_StrictPydanticModel):
102242
default: Any | None = None
103243
default_expr: DefaultExpr | None = None
@@ -209,6 +349,27 @@ def _pydantic_errors(validation_error):
209349
return errors
210350

211351

352+
def _create_table_pydantic_errors(validation_error):
    """Convert a create-table validation error into a list of message strings.

    Unknown top-level keys take priority: if any are present, the result is
    a single "Invalid keys: ..." message listing them in sorted order and
    every other error is dropped. Otherwise each error becomes one message;
    errors of type "create_table" are emitted as-is, all others are prefixed
    with their dotted location when that location is non-empty.
    """
    details = validation_error.errors()

    unknown_keys = []
    for detail in details:
        if detail["type"] == "extra_forbidden" and len(detail["loc"]) == 1:
            unknown_keys.append(str(detail["loc"][0]))
    if unknown_keys:
        unknown_keys.sort()
        return ["Invalid keys: {}".format(", ".join(unknown_keys))]

    messages = []
    for detail in details:
        text = detail["msg"]
        if detail["type"] != "create_table":
            path = ".".join(map(str, detail["loc"]))
            if path:
                text = "{}: {}".format(path, text)
        messages.append(text)
    return messages
371+
372+
212373
def _table_schema_from_conn(conn, table_name):
213374
row = conn.execute(
214375
"select sql from sqlite_master where type = 'table' and name = ?",
@@ -236,21 +397,6 @@ def _literal_default(db, value):
236397
class TableCreateView(BaseView):
237398
name = "table-create"
238399

239-
_valid_keys = {
240-
"table",
241-
"rows",
242-
"row",
243-
"columns",
244-
"pk",
245-
"pks",
246-
"ignore",
247-
"replace",
248-
"alter",
249-
}
250-
_supported_column_types = set(CREATE_TABLE_COLUMN_TYPES)
251-
# Any string that does not contain a newline or start with sqlite_
252-
_table_name_re = re.compile(r"^(?!sqlite_)[^\n]+$")
253-
254400
def __init__(self, datasette):
255401
self.ds = datasette
256402

@@ -274,26 +420,13 @@ async def post(self, request):
274420
if not isinstance(data, dict):
275421
return _error(["JSON must be an object"])
276422

277-
invalid_keys = set(data.keys()) - self._valid_keys
278-
if invalid_keys:
279-
return _error(["Invalid keys: {}".format(", ".join(invalid_keys))])
280-
281-
# ignore and replace are mutually exclusive
282-
if data.get("ignore") and data.get("replace"):
283-
return _error(["ignore and replace are mutually exclusive"])
284-
285-
# ignore and replace only allowed with row or rows
286-
if "ignore" in data or "replace" in data:
287-
if not data.get("row") and not data.get("rows"):
288-
return _error(["ignore and replace require row or rows"])
289-
290-
# ignore and replace require pk or pks
291-
if "ignore" in data or "replace" in data:
292-
if not data.get("pk") and not data.get("pks"):
293-
return _error(["ignore and replace require pk or pks"])
423+
try:
424+
create_request = CreateTableRequest.model_validate(data)
425+
except ValidationError as e:
426+
return _error(_create_table_pydantic_errors(e))
294427

295-
ignore = data.get("ignore")
296-
replace = data.get("replace")
428+
ignore = create_request.ignore
429+
replace = create_request.replace
297430

298431
if replace:
299432
# Must have update-row permission
@@ -304,24 +437,12 @@ async def post(self, request):
304437
):
305438
return _error(["Permission denied: need update-row"], 403)
306439

307-
table_name = data.get("table")
308-
if not table_name:
309-
return _error(["Table is required"])
310-
311-
if not self._table_name_re.match(table_name):
312-
return _error(["Invalid table name"])
440+
table_name = create_request.table
441+
table_exists = await db.table_exists(table_name)
442+
columns = create_request.columns
443+
rows = create_request.rows_list
313444

314-
table_exists = await db.table_exists(data["table"])
315-
columns = data.get("columns")
316-
rows = data.get("rows")
317-
row = data.get("row")
318-
if not columns and not rows and not row:
319-
return _error(["columns, rows or row is required"])
320-
321-
if rows and row:
322-
return _error(["Cannot specify both rows and row"])
323-
324-
if rows or row:
445+
if rows:
325446
# Must have insert-row permission
326447
if not await self.ds.allowed(
327448
action="insert-row",
@@ -331,13 +452,13 @@ async def post(self, request):
331452
return _error(["Permission denied: need insert-row"], 403)
332453

333454
alter = False
334-
if rows or row:
455+
if rows:
335456
if not table_exists:
336457
# if table is being created for the first time, alter=True
337458
alter = True
338459
else:
339460
# alter=True only if they request it AND they have permission
340-
if data.get("alter"):
461+
if create_request.alter:
341462
if not await self.ds.allowed(
342463
action="alter-table",
343464
resource=DatabaseResource(database=database_name),
@@ -346,64 +467,17 @@ async def post(self, request):
346467
return _error(["Permission denied: need alter-table"], 403)
347468
alter = True
348469

349-
if columns:
350-
if rows or row:
351-
return _error(["Cannot specify columns with rows or row"])
352-
if not isinstance(columns, list):
353-
return _error(["columns must be a list"])
354-
for column in columns:
355-
if not isinstance(column, dict):
356-
return _error(["columns must be a list of objects"])
357-
if not column.get("name") or not isinstance(column.get("name"), str):
358-
return _error(["Column name is required"])
359-
if not column.get("type"):
360-
column["type"] = "text"
361-
if column["type"] not in self._supported_column_types:
362-
return _error(
363-
["Unsupported column type: {}".format(column["type"])]
364-
)
365-
# No duplicate column names
366-
dupes = {c["name"] for c in columns if columns.count(c) > 1}
367-
if dupes:
368-
return _error(["Duplicate column name: {}".format(", ".join(dupes))])
369-
370-
if row:
371-
rows = [row]
372-
373-
if rows:
374-
if not isinstance(rows, list):
375-
return _error(["rows must be a list"])
376-
for row in rows:
377-
if not isinstance(row, dict):
378-
return _error(["rows must be a list of objects"])
379-
380-
pk = data.get("pk")
381-
pks = data.get("pks")
382-
383-
if pk and pks:
384-
return _error(["Cannot specify both pk and pks"])
385-
if pk:
386-
if not isinstance(pk, str):
387-
return _error(["pk must be a string"])
388-
if pks:
389-
if not isinstance(pks, list):
390-
return _error(["pks must be a list"])
391-
for pk in pks:
392-
if not isinstance(pk, str):
393-
return _error(["pks must be a list of strings"])
470+
pk = create_request.pk
471+
pks = create_request.pks
394472

395473
# If table exists already, read pks from that instead
396474
if table_exists:
397475
actual_pks = await db.primary_keys(table_name)
398476
# if pk passed and table already exists check it does not change
399477
bad_pks = False
400-
if len(actual_pks) == 1 and data.get("pk") and data["pk"] != actual_pks[0]:
478+
if len(actual_pks) == 1 and pk and pk != actual_pks[0]:
401479
bad_pks = True
402-
elif (
403-
len(actual_pks) > 1
404-
and data.get("pks")
405-
and set(data["pks"]) != set(actual_pks)
406-
):
480+
elif len(actual_pks) > 1 and pks and set(pks) != set(actual_pks):
407481
bad_pks = True
408482
if bad_pks:
409483
return _error(["pk cannot be changed for existing table"])
@@ -423,8 +497,9 @@ def create_table(conn):
423497
)
424498
else:
425499
table.create(
426-
{c["name"]: c["type"] for c in columns},
500+
{column.name: column.type for column in columns},
427501
pk=pks or pk,
502+
foreign_keys=create_request.foreign_keys,
428503
)
429504
return table.schema
430505

0 commit comments

Comments
 (0)