diff --git a/.github/workflows/test-suite.yml b/.github/workflows/test-suite.yml index 113e4df..140cc00 100644 --- a/.github/workflows/test-suite.yml +++ b/.github/workflows/test-suite.yml @@ -14,7 +14,7 @@ jobs: strategy: matrix: - python-version: ["3.6", "3.7", "3.8", "3.9"] + python-version: ["3.7", "3.8", "3.9", "3.10"] services: mysql: diff --git a/docs/declaring_models.md b/docs/declaring_models.md index cc1b22f..1099ef4 100644 --- a/docs/declaring_models.md +++ b/docs/declaring_models.md @@ -52,13 +52,23 @@ All fields are required unless one of the following is set: * `allow_blank` - A boolean. Determine if empty strings are allowed. Sets the default to `""`. * `default` - A value or a callable (function). +Special keyword arguments for `DateTime` and `Date` fields: + +* `auto_now` - Automatically set the field to now every time the object is saved. Useful for “last-modified” timestamps. +* `auto_now_add` - Automatically set the field to now when the object is first created. Useful for creation timestamps. + +The default is `datetime.date.today()` for `Date` and `datetime.datetime.now()` for `DateTime`. + +!!! note + Setting `auto_now` or `auto_now_add` to `True` will cause the field to be `read_only`. + The following column types are supported. See `TypeSystem` for [type-specific validation keyword arguments][typesystem-fields]. * `orm.BigInteger()` * `orm.Boolean()` -* `orm.Date()` -* `orm.DateTime()` +* `orm.Date(auto_now, auto_now_add)` +* `orm.DateTime(auto_now, auto_now_add)` * `orm.Decimal()` * `orm.Email(max_length)` * `orm.Enum()` diff --git a/docs/making_queries.md b/docs/making_queries.md index 8f259a8..ea544e3 100644 --- a/docs/making_queries.md +++ b/docs/making_queries.md @@ -19,7 +19,7 @@ class Note(orm.Model): ``` ORM supports two types of queryset methods. 
-Some queryset methods return another queryset and can be chianed together like `.filter()` and `order_by`: +Some queryset methods return another queryset and can be chained together like `.filter()` and `order_by`: ```python Note.objects.filter(completed=True).order_by("id") @@ -141,6 +141,20 @@ await Note.objects.create(text="Call Mum.", completed=True) await Note.objects.create(text="Send invoices.", completed=True) ``` +### .bulk_create() + +You need to pass a list of dictionaries of required fields to create multiple objects: + +```python +await Product.objects.bulk_create( + [ + {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, + {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, + + ] +) +``` + ### .delete() You can `delete` instances by calling `.delete()` on a queryset: @@ -224,7 +238,7 @@ await note.update(completed=True) ### .get_or_create() To get an existing instance matching the query, or create a new one. -This will retuurn a tuple of `instance` and `created`. +This will return a tuple of `instance` and `created`. ```python note, created = await Note.objects.get_or_create( @@ -242,7 +256,7 @@ if it doesn't exist, it will use `defaults` argument to create the new instance. ### .update_or_create() To update an existing instance matching the query, or create a new one. -This will retuurn a tuple of `instance` and `created`. +This will return a tuple of `instance` and `created`. ```python note, created = await Note.objects.update_or_create( @@ -252,7 +266,7 @@ note, created = await Note.objects.update_or_create( This will query a `Note` with `text` as `"Going to car wash"`, if an instance is found, it will use the `defaults` argument to update the instance. -If it matches no records, it will use the comibnation of arguments to create the new instance. +If it matches no records, it will use the combination of arguments to create the new instance. !!! 
note Since `update_or_create()` is doing a [get()](#get), it can raise `MultipleMatches` exception. diff --git a/docs/relationships.md b/docs/relationships.md index fa86bf4..4749c88 100644 --- a/docs/relationships.md +++ b/docs/relationships.md @@ -72,7 +72,7 @@ track = await Track.objects.select_related("album").get(title="The Bird") assert track.album.name == "Malibu" ``` -To fetch an instance, filtering across a foregin key relationship: +To fetch an instance, filtering across a foreign key relationship: ```python tracks = Track.objects.filter(album__name="Fantasies") @@ -84,7 +84,7 @@ assert len(tracks) == 2 ### ForeignKey constraints -`ForeigknKey` supports specfiying a constraint through `on_delete` argument. +`ForeignKey` supports specifying a constraint through the `on_delete` argument. This will result in a SQL `ON DELETE` query being generated when the referenced object is removed. diff --git a/orm/fields.py b/orm/fields.py index b12ea9c..9015a76 100644 --- a/orm/fields.py +++ b/orm/fields.py @@ -1,4 +1,5 @@ import typing +from datetime import date, datetime import sqlalchemy import typesystem @@ -100,16 +101,31 @@ def get_column_type(self): return sqlalchemy.Boolean() -class DateTime(ModelField): +class AutoNowMixin(ModelField): + def __init__(self, auto_now=False, auto_now_add=False, **kwargs): + self.auto_now = auto_now + self.auto_now_add = auto_now_add + if auto_now_add and auto_now: + raise ValueError("auto_now and auto_now_add cannot be both True") + if auto_now_add or auto_now: + kwargs["read_only"] = True + super().__init__(**kwargs) + + +class DateTime(AutoNowMixin): def get_validator(self, **kwargs) -> typesystem.Field: + if self.auto_now_add or self.auto_now: + kwargs["default"] = datetime.now return typesystem.DateTime(**kwargs) def get_column_type(self): return sqlalchemy.DateTime() -class Date(ModelField): +class Date(AutoNowMixin): def get_validator(self, **kwargs) -> typesystem.Field: + if self.auto_now_add or self.auto_now: + 
kwargs["default"] = date.today return typesystem.Date(**kwargs) def get_column_type(self): @@ -154,7 +170,7 @@ def target(self): return self._target def get_validator(self, **kwargs) -> typesystem.Field: - return self.ForeignKeyValidator() + return self.ForeignKeyValidator(**kwargs) def get_column(self, name: str) -> sqlalchemy.Column: target = self.target diff --git a/orm/models.py b/orm/models.py index 49ae2ec..b402814 100644 --- a/orm/models.py +++ b/orm/models.py @@ -6,7 +6,7 @@ from sqlalchemy.ext.asyncio import create_async_engine from orm.exceptions import MultipleMatches, NoMatch -from orm.fields import String, Text +from orm.fields import Date, DateTime, String, Text FILTER_OPERATORS = { "exact": "__eq__", @@ -21,19 +21,29 @@ } +def _update_auto_now_fields(values, fields): + for key, value in fields.items(): + if isinstance(value, (DateTime, Date)) and value.auto_now: + values[key] = value.validator.get_default_value() + return values + + class ModelRegistry: def __init__(self, database: databases.Database) -> None: self.database = database self.models = {} - self.metadata = sqlalchemy.MetaData() + self._metadata = sqlalchemy.MetaData() + + @property + def metadata(self): + for model_cls in self.models.values(): + model_cls.build_table() + return self._metadata async def create_all(self): url = self._get_database_url() engine = create_async_engine(url) - for model_cls in self.models.values(): - model_cls.build_table() - async with self.database: async with engine.begin() as conn: await conn.run_sync(self.metadata.create_all) @@ -44,9 +54,6 @@ async def drop_all(self): url = self._get_database_url() engine = create_async_engine(url) - for model_cls in self.models.values(): - model_cls.build_table() - async with self.database: async with engine.begin() as conn: await conn.run_sync(self.metadata.drop_all) @@ -394,17 +401,19 @@ async def first(self, **kwargs): if rows: return rows[0] - async def create(self, **kwargs): + def _validate_kwargs(self, **kwargs): 
fields = self.model_cls.fields validator = typesystem.Schema( fields={key: value.validator for key, value in fields.items()} ) kwargs = validator.validate(kwargs) - for key, value in fields.items(): if value.validator.read_only and value.validator.has_default(): kwargs[key] = value.validator.get_default_value() + return kwargs + async def create(self, **kwargs): + kwargs = self._validate_kwargs(**kwargs) instance = self.model_cls(**kwargs) expr = self.table.insert().values(**kwargs) @@ -415,6 +424,12 @@ async def create(self, **kwargs): return instance + async def bulk_create(self, objs: typing.List[typing.Dict]) -> None: + new_objs = [self._validate_kwargs(**obj) for obj in objs] + + expr = self.table.insert().values(new_objs) + await self.database.execute(expr) + async def delete(self) -> None: expr = self.table.delete() for filter_clause in self.filter_clauses: @@ -429,8 +444,9 @@ async def update(self, **kwargs) -> None: if key in kwargs } validator = typesystem.Schema(fields=fields) - kwargs = validator.validate(kwargs) - + kwargs = _update_auto_now_fields( + validator.validate(kwargs), self.model_cls.fields + ) expr = self.table.update().values(**kwargs) for filter_clause in self.filter_clauses: @@ -498,7 +514,7 @@ def __str__(self): @classmethod def build_table(cls): tablename = cls.tablename - metadata = cls.registry.metadata + metadata = cls.registry._metadata columns = [] for name, field in cls.fields.items(): columns.append(field.get_column(name)) @@ -513,11 +529,9 @@ async def update(self, **kwargs): key: field.validator for key, field in self.fields.items() if key in kwargs } validator = typesystem.Schema(fields=fields) - kwargs = validator.validate(kwargs) - + kwargs = _update_auto_now_fields(validator.validate(kwargs), self.fields) pk_column = getattr(self.table.c, self.pkname) expr = self.table.update().values(**kwargs).where(pk_column == self.pk) - await self.database.execute(expr) # Update the model instance. 
diff --git a/setup.py b/setup.py index 49557ea..b20503f 100644 --- a/setup.py +++ b/setup.py @@ -50,7 +50,6 @@ def get_packages(package): author_email="tom@tomchristie.com", packages=get_packages(PACKAGE), package_data={PACKAGE: ["py.typed"]}, - data_files=[("", ["LICENSE.md"])], install_requires=["databases~=0.5", "typesystem==0.3.1"], extras_require={ "postgresql": ["asyncpg"], @@ -66,7 +65,10 @@ def get_packages(package): "Operating System :: OS Independent", "Topic :: Internet :: WWW/HTTP", "Programming Language :: Python :: 3", - "Programming Language :: Python :: 3.6", "Programming Language :: Python :: 3.7", + "Programming Language :: Python :: 3.8", + "Programming Language :: Python :: 3.9", + "Programming Language :: Python :: 3.10", + "Programming Language :: Python :: 3 :: Only", ], ) diff --git a/tests/test_columns.py b/tests/test_columns.py index c075b08..278aecd 100644 --- a/tests/test_columns.py +++ b/tests/test_columns.py @@ -33,6 +33,10 @@ class Product(orm.Model): "created": orm.DateTime(default=datetime.datetime.now), "created_day": orm.Date(default=datetime.date.today), "created_time": orm.Time(default=time), + "created_date": orm.Date(auto_now_add=True), + "created_datetime": orm.DateTime(auto_now_add=True), + "updated_datetime": orm.DateTime(auto_now=True), + "updated_date": orm.Date(auto_now=True), "data": orm.JSON(default={}), "description": orm.Text(allow_blank=True), "huge_number": orm.BigInteger(default=0), @@ -69,10 +73,13 @@ async def rollback_transactions(): async def test_model_crud(): product = await Product.objects.create() - product = await Product.objects.get(pk=product.pk) assert product.created.year == datetime.datetime.now().year assert product.created_day == datetime.date.today() + assert product.created_date == datetime.date.today() + assert product.created_datetime.date() == datetime.datetime.now().date() + assert product.updated_date == datetime.date.today() + assert product.updated_datetime.date() == 
datetime.datetime.now().date() assert product.data == {} assert product.description == "" assert product.huge_number == 0 @@ -96,6 +103,8 @@ async def test_model_crud(): assert product.price == decimal.Decimal("999.99") assert product.uuid == uuid.UUID("01175cde-c18f-4a13-a492-21bd9e1cb01b") + last_updated_datetime = product.updated_datetime + last_updated_date = product.updated_date user = await User.objects.create() assert isinstance(user.pk, uuid.UUID) @@ -114,3 +123,39 @@ async def test_model_crud(): user = await User.objects.get() assert isinstance(user.ipaddress, (ipaddress.IPv4Address, ipaddress.IPv6Address)) assert user.url == "https://encode.io" + # Test auto_now update + await product.update( + data={"foo": 1234}, + ) + assert product.updated_datetime != last_updated_datetime + assert product.updated_date == last_updated_date + + +async def test_both_auto_now_and_auto_now_add_raise_error(): + with pytest.raises(ValueError): + + class Product(orm.Model): + registry = models + fields = { + "id": orm.Integer(primary_key=True), + "created_datetime": orm.DateTime(auto_now_add=True, auto_now=True), + } + + await Product.objects.create() + + +async def test_bulk_create(): + await Product.objects.bulk_create( + [ + {"data": {"foo": 123}, "value": 123.456, "status": StatusEnum.RELEASED}, + {"data": {"foo": 456}, "value": 456.789, "status": StatusEnum.DRAFT}, + ] + ) + products = await Product.objects.all() + assert len(products) == 2 + assert products[0].data == {"foo": 123} + assert products[0].value == 123.456 + assert products[0].status == StatusEnum.RELEASED + assert products[1].data == {"foo": 456} + assert products[1].value == 456.789 + assert products[1].status == StatusEnum.DRAFT diff --git a/tests/test_foreignkey.py b/tests/test_foreignkey.py index 8017198..1ab6175 100644 --- a/tests/test_foreignkey.py +++ b/tests/test_foreignkey.py @@ -269,3 +269,12 @@ async def test_one_to_one_crud(): with pytest.raises(exceptions): await 
Person.objects.create(email="contact@encode.io", profile=profile) + + +async def test_nullable_foreign_key(): + await Member.objects.create(email="dev@encode.io") + + member = await Member.objects.get() + + assert member.email == "dev@encode.io" + assert member.team.pk is None