Skip to content

Commit 8303910

Browse files
committed
More strict typing for query fallback
1 parent 0db697a commit 8303910

3 files changed

Lines changed: 139 additions & 137 deletions

File tree

example/db/mydb.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
from contextvars import ContextVar
1717
from dataclasses import dataclass
1818
from enum import StrEnum
19+
from typing import Any
1920
from typing import Literal
2021
from typing import overload
2122

@@ -282,7 +283,7 @@ def query_stream(self, *, status: MydbTaskStatus) -> AbstractAsyncContextManager
282283
return self._server_cursor((status,))
283284

284285

285-
_QUERIES: dict[str, type[Query]] = {
286+
_QUERIES: dict[str, type[Query[Any]]] = {
286287
'\n INSERT INTO users (id, username, email)\n VALUES (@id, @username, @email)\n ': Query_3ee53b6909da8b4496346dda36c9f442,
287288
'\n INSERT INTO projects (id, name, owner_id, settings)\n VALUES (@id, @name, @owner_id, @settings)\n ': Query_67ac0768d48a654b1a305124c92372e8,
288289
'\n INSERT INTO tasks (id, project_id, title, priority, assignee_id, metadata, due_date)\n VALUES (@id, @project_id, @title, @priority, @assignee_id?, @metadata?, @due_date?)\n ': Query_bd4c62c78a942bfd1f087f87a19f2743,
@@ -317,10 +318,10 @@ def mydb_sql(sql: Literal['SELECT id FROM tasks WHERE project_id = @project_id A
317318
@overload
318319
def mydb_sql(sql: Literal['SELECT count(*) FROM tasks WHERE status = @status']) -> Query_29c838280e39383dd6b0760431eb3e60: ...
319320
@overload
320-
def mydb_sql(sql: str) -> Query: ...
321+
def mydb_sql(sql: str) -> Query[Any]: ...
321322

322323

323-
def mydb_sql(sql: str, row_type: str | None = None) -> Query:
324+
def mydb_sql(sql: str, row_type: str | None = None) -> Query[Any]:
324325
if sql in _QUERIES:
325326
return _QUERIES[sql]()
326327
msg = f"Unknown statement: {sql!r}"

src/iron_sql/codegen/generator.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -511,6 +511,7 @@ def render_module( # noqa: PLR0913, PLR0917
511511
from contextvars import ContextVar
512512
from dataclasses import dataclass
513513
from enum import StrEnum
514+
from typing import Any
514515
from typing import Literal
515516
from typing import overload
516517
@@ -573,17 +574,17 @@ class Query[T](runtime.Query[T]):
573574
{"\n\n\n".join(query_classes)}
574575
575576
576-
_QUERIES: dict[str, type[Query]] = {{
577+
_QUERIES: dict[str, type[Query[Any]]] = {{
577578
{(",\n ").join(query_dict_entries)}
578579
}}
579580
580581
581582
{"\n".join(query_overloads)}
582583
@overload
583-
def {sql_fn_name}(sql: str) -> Query: ...
584+
def {sql_fn_name}(sql: str) -> Query[Any]: ...
584585
585586
586-
def {sql_fn_name}(sql: str, row_type: str | None = None) -> Query:
587+
def {sql_fn_name}(sql: str, row_type: str | None = None) -> Query[Any]:
587588
if sql in _QUERIES:
588589
return _QUERIES[sql]()
589590
msg = f"Unknown statement: {{sql!r}}"

0 commit comments

Comments
 (0)