diff --git a/sqlmodel/main.py b/sqlmodel/main.py index 300031de8b..86f62dd5ff 100644 --- a/sqlmodel/main.py +++ b/sqlmodel/main.py @@ -48,8 +48,9 @@ from sqlalchemy.orm.attributes import set_attribute from sqlalchemy.orm.decl_api import DeclarativeMeta from sqlalchemy.orm.instrumentation import is_instrumented -from sqlalchemy.sql.schema import MetaData +from sqlalchemy.sql.schema import MetaData, SchemaEventTarget from sqlalchemy.sql.sqltypes import LargeBinary, Time, Uuid +from sqlalchemy.types import TypeEngine from typing_extensions import deprecated from ._compat import ( # type: ignore[attr-defined] @@ -88,6 +89,9 @@ | Mapping[int, Union["IncEx", bool]] | Mapping[str, Union["IncEx", bool]] ) +SaTypeOrInstance: TypeAlias = ( + TypeEngine[Any] | type[TypeEngine[Any]] | SchemaEventTarget +) OnDeleteType = Literal["CASCADE", "SET NULL", "RESTRICT"] @@ -209,7 +213,7 @@ class FieldInfoMetadata: ondelete: OnDeleteType | UndefinedType = Undefined unique: bool | UndefinedType = Undefined index: bool | UndefinedType = Undefined - sa_type: type[Any] | UndefinedType = Undefined + sa_type: SaTypeOrInstance | UndefinedType = Undefined sa_column: Column[Any] | UndefinedType = Undefined sa_column_args: Sequence[Any] | UndefinedType = Undefined sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined @@ -268,7 +272,7 @@ def Field( unique: bool | UndefinedType = Undefined, nullable: bool | UndefinedType = Undefined, index: bool | UndefinedType = Undefined, - sa_type: type[Any] | UndefinedType = Undefined, + sa_type: SaTypeOrInstance | UndefinedType = Undefined, sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, schema_extra: dict[str, Any] | None = None, @@ -312,7 +316,7 @@ def Field( unique: bool | UndefinedType = Undefined, nullable: bool | UndefinedType = Undefined, index: bool | UndefinedType = Undefined, - sa_type: type[Any] | UndefinedType = Undefined, + sa_type: SaTypeOrInstance | UndefinedType = Undefined, sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined, schema_extra: dict[str, Any] | None = None, @@ -397,7 +401,7 @@ def Field( unique: bool | UndefinedType = Undefined, nullable: bool | UndefinedType = Undefined, index: bool | UndefinedType = Undefined, - sa_type: type[Any] | UndefinedType = Undefined, + sa_type: SaTypeOrInstance | UndefinedType = Undefined, sa_column: Column | UndefinedType = Undefined, # type: ignore sa_column_args: Sequence[Any] | UndefinedType = Undefined, sa_column_kwargs: Mapping[str, Any] | UndefinedType = Undefined,