diff --git a/.gitignore b/.gitignore
index cbfef3fe..b2e76af0 100644
--- a/.gitignore
+++ b/.gitignore
@@ -7,6 +7,7 @@ dist
build
_build
.cache
+venv
# Installer logs
pip-log.txt
@@ -18,7 +19,11 @@ nosetests.xml
.DS_Store
.idea/*
+.python-version
+.pytest_cache
+/setup.py
+/requirements.txt
/test.py
/test_*.py
app.py
@@ -34,3 +39,5 @@ profile.html
benchmark.py
results.json
*.so
+pyproject.lock
+poetry.lock
diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml
new file mode 100644
index 00000000..7d92bbd3
--- /dev/null
+++ b/.pre-commit-config.yaml
@@ -0,0 +1,6 @@
+repos:
+- repo: https://github.com/ambv/black
+ rev: stable
+ hooks:
+ - id: black
+ python_version: python3.7
diff --git a/.travis.yml b/.travis.yml
index b761169a..c4572bc3 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -1,32 +1,58 @@
language: python
-# Once mysql 5.6 is available on the container infra, this can be removed.
-sudo: required
-dist: trusty
+stages:
+ - linting
+ - test
-python:
- - "2.7"
- - "3.5"
+cache:
+ pip: true
+ directories:
+ - $HOME/.cache/pypoetry
-env:
- - MYSQL_PACKAGE=pymysql
- - MYSQL_PACKAGE=mysqlclient
+services:
+ - mysql
addons:
- postgresql: '9.4'
-
-before_install:
- # Manually install mysql 5.6 since the default is v5.5.
- - sudo apt-get update -qq
- - sudo apt-get install -qq mysql-server-5.6 mysql-client-5.6 mysql-client-core-5.6
+ postgresql: '9.6'
install:
- - pip install -r tests-requirements.txt -U
- - if [[ $MYSQL_PACKAGE == 'pymysql' ]]; then pip install pymysql; fi
- - if [[ $MYSQL_PACKAGE == 'mysqlclient' ]]; then pip install mysqlclient; fi
+ - curl -fsS -o get-poetry.py https://raw.githubusercontent.com/sdispater/poetry/master/get-poetry.py
+ - python get-poetry.py --preview -y
+ - source $HOME/.poetry/env
+ - if [[ $MYSQL_PACKAGE == 'pymysql' ]]; then poetry install --extras mysql-python --extras pgsql; fi
+ - if [[ $MYSQL_PACKAGE == 'mysqlclient' ]]; then poetry install --extras mysql --extras pgsql; fi
-script: py.test tests/
+script: pytest tests/
before_script:
- psql -c 'create database orator_test;' -U postgres
- - mysql -u root -e 'create database orator_test;'
+ - mysql -e 'create database orator_test;'
+
+jobs:
+ include:
+ - python: "2.7"
+ env: MYSQL_PACKAGE=pymysql
+ - python: "2.7"
+ env: MYSQL_PACKAGE=mysqlclient
+ - python: "3.5"
+ env: MYSQL_PACKAGE=pymysql
+ - python: "3.5"
+ env: MYSQL_PACKAGE=mysqlclient
+ - python: "3.6"
+ env: MYSQL_PACKAGE=pymysql
+ - python: "3.6"
+ env: MYSQL_PACKAGE=mysqlclient
+ - python: "3.7"
+ env: MYSQL_PACKAGE=pymysql
+ dist: xenial
+ - python: "3.7"
+ env: MYSQL_PACKAGE=mysqlclient
+ dist: xenial
+
+ - stage: linting
+ python: "3.6"
+ install:
+ - pip install pre-commit
+ - pre-commit install-hooks
+ script:
+ - pre-commit run --all-files
diff --git a/CHANGELOG.md b/CHANGELOG.md
index af811624..220ee067 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,5 +1,80 @@
# Change Log
+## [0.9.9] - 2019-07-15
+
+### Fixed
+
+- Fixed missing relationships when eager loading multiple nested relationships.
+- Fixed a possible `AttributeError` when starting a transaction.
+- Fixed an infinite recursion when using `where_exists()` on a soft-deletable model.
+- Fixed some cases where a reconnection would not occur for PostgreSQL.
+
+
+## [0.9.8] - 2018-10-10
+
+### Fixed
+
+- Fixed the `morphed_by_many()` decorator.
+- Fixed decoding errors for MySQL.
+- Fixed connection errors check.
+- Fixed the `touches()` method.
+- Fixed `has_many` not showing `DISTINCT`.
+- Fixed `save_many()` for Python 3.
+- Fixed an error when listing columns for recent MySQL versions.
+
+
+## [0.9.7] - 2017-05-17
+
+### Fixed
+
+- Fixed `orator` command no longer working.
+
+
+## [0.9.6] - 2017-05-16
+
+### Added
+
+- Added support for `DATE` types in models.
+- Added support for fractional seconds for the `TIMESTAMP` type in MySQL 5.6.4+.
+- Added support for fractional seconds for the `TIMESTAMP` and `TIME` types in PostgreSQL.
+
+### Changed
+
+- Improved implementation of the `chunk` method.
+
+### Fixed
+
+- Fixed timezone offset errors when inserting datetime aware objects into PostgreSQL.
+- Fixed a bug occurring when using `__touches__` with an optional relationship.
+- Fixed collections serialization when using the query builder.
+
+
+## [0.9.5] - 2017-02-11
+
+### Changed
+
+- `make:migration` now shows the name of the created migration file. (Thanks to [denislins](https://github.com/denislins))
+
+### Fixed
+
+- Fixed transactions not working for PostgreSQL and SQLite.
+
+
+## [0.9.4] - 2017-01-12
+
+### Fixed
+
+- Fixed `BelongsTo.associate()` for unsaved models.
+- Fixed reconnection for PostgreSQL.
+- Fixed dependencies (changed `fake-factory` to `Faker`). (Thanks to [acristoffers](https://github.com/acristoffers))
+
+
+## [0.9.3] - 2016-11-10
+
+### Fixed
+
+- Fixed `compile_table_exists()` method in PostgreSQL schema grammar that could break migrations.
+
## [0.9.2] - 2016-10-17
@@ -361,7 +436,14 @@
Initial release
-
+[Unreleased]: https://github.com/sdispater/orator/compare/0.9.9...0.9
+[0.9.9]: https://github.com/sdispater/orator/releases/0.9.9
+[0.9.8]: https://github.com/sdispater/orator/releases/0.9.8
+[0.9.7]: https://github.com/sdispater/orator/releases/0.9.7
+[0.9.6]: https://github.com/sdispater/orator/releases/0.9.6
+[0.9.5]: https://github.com/sdispater/orator/releases/0.9.5
+[0.9.4]: https://github.com/sdispater/orator/releases/0.9.4
+[0.9.3]: https://github.com/sdispater/orator/releases/0.9.3
[0.9.2]: https://github.com/sdispater/orator/releases/0.9.2
[0.9.1]: https://github.com/sdispater/orator/releases/0.9.1
[0.9.0]: https://github.com/sdispater/orator/releases/0.9.0
diff --git a/MANIFEST.in b/MANIFEST.in
deleted file mode 100644
index c602245e..00000000
--- a/MANIFEST.in
+++ /dev/null
@@ -1,6 +0,0 @@
-include README.rst LICENSE requirements.txt test-requirements.txt
-recursive-exclude tests *
-recursive-exclude benchmark *
-recursive-exclude seeders *
-recursive-exclude seeds *
-recursive-exclude migrations *
diff --git a/Makefile b/Makefile
index 9fd505fd..26c890c4 100644
--- a/Makefile
+++ b/Makefile
@@ -47,7 +47,7 @@ setup-mysql: drop-mysql
TO 'orator'@'localhost';"
drop-mysql:
- @type -p psql > /dev/null || { echo 'Install and setup PostgreSQL'; exit 1; }
+ @type -p mysql > /dev/null || { echo 'Install and setup MySQL'; exit 1; }
@-mysql -u root -e 'DROP DATABASE orator_test;' > /dev/null 2>&1
@-mysql -u root -e "DROP USER 'orator'@'localhost';" > /dev/null 2>&1
@@ -62,4 +62,6 @@ test:
# run tests against all supported python versions
tox:
+ @poet make:setup
@tox
+ @rm -f setup.py
diff --git a/README.rst b/README.rst
index a055a5dd..afffd0e9 100644
--- a/README.rst
+++ b/README.rst
@@ -361,7 +361,7 @@ Note that entire collections of models can also be converted to dictionaries:
.. code-block:: python
- return User.all().serailize()
+ return User.all().serialize()
Converting a model to JSON
diff --git a/orator/__init__.py b/orator/__init__.py
index dba4d8b3..0316577d 100644
--- a/orator/__init__.py
+++ b/orator/__init__.py
@@ -1,5 +1,7 @@
# -*- coding: utf-8 -*-
+__version__ = "0.9.9"
+
from .orm import Model, SoftDeletes, Collection, accessor, mutator, scope
from .database_manager import DatabaseManager
from .query.expression import QueryExpression
diff --git a/orator/commands/__init__.py b/orator/commands/__init__.py
index 633f8661..40a96afc 100644
--- a/orator/commands/__init__.py
+++ b/orator/commands/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/orator/commands/application.py b/orator/commands/application.py
index c9313ffd..9f1435f1 100644
--- a/orator/commands/application.py
+++ b/orator/commands/application.py
@@ -1,15 +1,19 @@
# -*- coding: utf-8 -*-
from cleo import Application
-from ..version import VERSION
+from .. import __version__
-application = Application('Orator', VERSION, complete=True)
+application = Application("Orator", __version__, complete=True)
# Migrations
from .migrations import (
- InstallCommand, MigrateCommand,
- MigrateMakeCommand, RollbackCommand,
- StatusCommand, ResetCommand, RefreshCommand
+ InstallCommand,
+ MigrateCommand,
+ MigrateMakeCommand,
+ RollbackCommand,
+ StatusCommand,
+ ResetCommand,
+ RefreshCommand,
)
application.add(InstallCommand())
diff --git a/orator/commands/command.py b/orator/commands/command.py
index 19351d56..2a114400 100644
--- a/orator/commands/command.py
+++ b/orator/commands/command.py
@@ -23,18 +23,18 @@ def configure(self):
if self.needs_config and not self.resolver:
# Checking if a default config file is present
if not self._check_config():
- self.add_option('config', 'c',
- InputOption.VALUE_REQUIRED,
- 'The config file path')
+ self.add_option(
+ "config", "c", InputOption.VALUE_REQUIRED, "The config file path"
+ )
def execute(self, i, o):
"""
Executes the command.
"""
- self.set_style('question', fg='blue')
+ self.set_style("question", fg="blue")
if self.needs_config and not self.resolver:
- self._handle_config(self.option('config'))
+ self._handle_config(self.option("config"))
return self.handle()
@@ -50,28 +50,24 @@ def call_silent(self, name, options=None):
return super(Command, self).call_silent(name, options)
- def confirm(self, question, default=False, true_answer_regex='(?i)^y'):
- """
- Confirm a question with the user.
+ def confirm_to_proceed(self, message=None):
+ if message is None:
+ message = "Do you really wish to run this command?: "
- :param question: The question to ask
- :type question: str
+ if self.option("force"):
+ return True
- :param default: The default value
- :type default: bool
+ confirmed = self.confirm(message)
- :param true_answer_regex: A regex to match the "yes" answer
- :type true_answer_regex: str
+ if not confirmed:
+ self.comment("Command Cancelled!")
- :rtype: bool
- """
- if not self.input.is_interactive():
- return True
+ return False
- return super(Command, self).confirm(question, default=False, true_answer_regex='(?i)^y')
+ return True
def _get_migration_path(self):
- return os.path.join(os.getcwd(), 'migrations')
+ return os.path.join(os.getcwd(), "migrations")
def _check_config(self):
"""
@@ -81,7 +77,7 @@ def _check_config(self):
"""
current_path = os.path.relpath(os.getcwd())
- accepted_files = ['orator.yml', 'orator.py']
+ accepted_files = ["orator.yml", "orator.py"]
for accepted_file in accepted_files:
config_file = os.path.join(current_path, accepted_file)
if os.path.exists(config_file):
@@ -101,7 +97,9 @@ def _handle_config(self, config_file):
"""
config = self._get_config(config_file)
- self.resolver = DatabaseManager(config.get('databases', config.get('DATABASES', {})))
+ self.resolver = DatabaseManager(
+ config.get("databases", config.get("DATABASES", {}))
+ )
return True
@@ -111,22 +109,22 @@ def _get_config(self, path=None):
:rtype: dict
"""
- if not path and not self.option('config'):
- raise Exception('The --config|-c option is missing.')
+ if not path and not self.option("config"):
+ raise Exception("The --config|-c option is missing.")
if not path:
- path = self.option('config')
+ path = self.option("config")
filename, ext = os.path.splitext(path)
- if ext in ['.yml', '.yaml']:
+ if ext in [".yml", ".yaml"]:
with open(path) as fd:
config = yaml.load(fd)
- elif ext in ['.py']:
+ elif ext in [".py"]:
config = {}
with open(path) as fh:
exec(fh.read(), {}, config)
else:
- raise RuntimeError('Config file [%s] is not supported.' % path)
+ raise RuntimeError("Config file [%s] is not supported." % path)
return config
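
The `_get_config()` change above keeps support for both YAML (`orator.yml`) and Python (`orator.py`) config files; for the Python form, the file is simply exec'd and its `DATABASES` (or `databases`) mapping is handed to `DatabaseManager`. A minimal sketch of such a file, with illustrative connection values rather than project defaults:

```python
# orator.py -- hypothetical config file found by _check_config() in the
# working directory; _handle_config() passes DATABASES to DatabaseManager.
# Driver and database values below are examples only.
DATABASES = {
    "default": "sqlite",
    "sqlite": {
        "driver": "sqlite",
        "database": "orator.db",
    },
}
```
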
diff --git a/orator/commands/migrations/base_command.py b/orator/commands/migrations/base_command.py
index 85188c84..ba7ec414 100644
--- a/orator/commands/migrations/base_command.py
+++ b/orator/commands/migrations/base_command.py
@@ -6,6 +6,5 @@
class BaseCommand(Command):
-
def _get_migration_path(self):
- return os.path.join(os.getcwd(), 'migrations')
+ return os.path.join(os.getcwd(), "migrations")
diff --git a/orator/commands/migrations/install_command.py b/orator/commands/migrations/install_command.py
index 161a5485..b8738698 100644
--- a/orator/commands/migrations/install_command.py
+++ b/orator/commands/migrations/install_command.py
@@ -16,10 +16,10 @@ def handle(self):
"""
Executes the command
"""
- database = self.option('database')
- repository = DatabaseMigrationRepository(self.resolver, 'migrations')
+ database = self.option("database")
+ repository = DatabaseMigrationRepository(self.resolver, "migrations")
repository.set_source(database)
repository.create_repository()
- self.info('Migration table created successfully')
+ self.info("Migration table created successfully")
diff --git a/orator/commands/migrations/make_command.py b/orator/commands/migrations/make_command.py
index d0a46e38..e5c929c0 100644
--- a/orator/commands/migrations/make_command.py
+++ b/orator/commands/migrations/make_command.py
@@ -24,20 +24,20 @@ def handle(self):
"""
creator = MigrationCreator()
- name = self.argument('name')
- table = self.option('table')
- create = bool(self.option('create'))
+ name = self.argument("name")
+ table = self.option("table")
+ create = bool(self.option("create"))
if not table and create is not False:
table = create
- path = self.option('path')
+ path = self.option("path")
if path is None:
path = self._get_migration_path()
- self._write_migration(creator, name, table, create, path)
+ migration_name = self._write_migration(creator, name, table, create, path)
- self.info('Migration created successfully')
+ self.line("Created migration: {}".format(migration_name))
def _write_migration(self, creator, name, table, create, path):
"""
diff --git a/orator/commands/migrations/migrate_command.py b/orator/commands/migrations/migrate_command.py
index 9faef082..24dec942 100644
--- a/orator/commands/migrations/migrate_command.py
+++ b/orator/commands/migrations/migrate_command.py
@@ -15,26 +15,25 @@ class MigrateCommand(BaseCommand):
{--seed-path= : The path of seeds files to be executed.
Defaults to ./seeders.}
{--P|pretend : Dump the SQL queries that would be run.}
+ {--f|force : Force the operation to run.}
"""
def handle(self):
- confirm = self.confirm(
- 'Are you sure you want to proceed with the migration? ',
- False
- )
- if not confirm:
+ if not self.confirm_to_proceed(
+ "Are you sure you want to proceed with the migration? "
+ ):
return
- database = self.option('database')
- repository = DatabaseMigrationRepository(self.resolver, 'migrations')
+ database = self.option("database")
+ repository = DatabaseMigrationRepository(self.resolver, "migrations")
migrator = Migrator(repository, self.resolver)
self._prepare_database(migrator, database)
- pretend = self.option('pretend')
+ pretend = self.option("pretend")
- path = self.option('path')
+ path = self.option("path")
if path is None:
path = self._get_migration_path()
@@ -46,21 +45,19 @@ def handle(self):
# If the "seed" option has been given, we will rerun the database seed task
# to repopulate the database.
- if self.option('seed'):
- options = [
- ('-n', True)
- ]
+ if self.option("seed"):
+ options = [("--force", self.option("force"))]
if database:
- options.append(('--database', database))
+ options.append(("--database", database))
- if self.get_definition().has_option('config'):
- options.append(('--config', self.option('config')))
+ if self.get_definition().has_option("config"):
+ options.append(("--config", self.option("config")))
- if self.option('seed-path'):
- options.append(('--path', self.option('seed-path')))
+ if self.option("seed-path"):
+ options.append(("--path", self.option("seed-path")))
- self.call('db:seed', options)
+ self.call("db:seed", options)
def _prepare_database(self, migrator, database):
migrator.set_connection(database)
@@ -69,9 +66,9 @@ def _prepare_database(self, migrator, database):
options = []
if database:
- options.append(('--database', database))
+ options.append(("--database", database))
- if self.get_definition().has_option('config'):
- options.append(('--config', self.option('config')))
+ if self.get_definition().has_option("config"):
+ options.append(("--config", self.option("config")))
- self.call('migrate:install', options)
+ self.call("migrate:install", options)
diff --git a/orator/commands/migrations/refresh_command.py b/orator/commands/migrations/refresh_command.py
index d9c44028..7dfded18 100644
--- a/orator/commands/migrations/refresh_command.py
+++ b/orator/commands/migrations/refresh_command.py
@@ -12,59 +12,53 @@ class RefreshCommand(BaseCommand):
{--p|path= : The path of migrations files to be executed.}
{--s|seed : Indicates if the seed task should be re-run.}
{--seed-path= : The path of seeds files to be executed.
- Defaults to ./seeders.}
+ Defaults to ./seeds.}
{--seeder=database_seeder : The name of the root seeder.}
+ {--f|force : Force the operation to run.}
"""
def handle(self):
"""
Executes the command.
"""
- confirm = self.confirm(
- 'Are you sure you want to refresh the database? ',
- False
- )
- if not confirm:
+ if not self.confirm_to_proceed(
+ "Are you sure you want to refresh the database?: "
+ ):
return
- database = self.option('database')
+ database = self.option("database")
- options = [
- ('-n', True)
- ]
+ options = [("--force", True)]
- if self.option('path'):
- options.append(('--path', self.option('path')))
+ if self.option("path"):
+ options.append(("--path", self.option("path")))
if database:
- options.append(('--database', database))
+ options.append(("--database", database))
- if self.get_definition().has_option('config'):
- options.append(('--config', self.option('config')))
+ if self.get_definition().has_option("config"):
+ options.append(("--config", self.option("config")))
- self.call('migrate:reset', options)
+ self.call("migrate:reset", options)
- self.call('migrate', options)
+ self.call("migrate", options)
if self._needs_seeding():
self._run_seeder(database)
def _needs_seeding(self):
- return self.option('seed') or self.option('seeder')
+ return self.option("seed")
def _run_seeder(self, database):
- options = [
- ('--seeder', self.option('seeder')),
- ('-n', True)
- ]
+ options = [("--seeder", self.option("seeder")), ("--force", True)]
if database:
- options.append(('--database', database))
+ options.append(("--database", database))
- if self.get_definition().has_option('config'):
- options.append(('--config', self.option('config')))
+ if self.get_definition().has_option("config"):
+ options.append(("--config", self.option("config")))
- if self.option('seed-path'):
- options.append(('--path', self.option('seed-path')))
+ if self.option("seed-path"):
+ options.append(("--path", self.option("seed-path")))
- self.call('db:seed', options)
+ self.call("db:seed", options)
diff --git a/orator/commands/migrations/reset_command.py b/orator/commands/migrations/reset_command.py
index 92fc5bc5..55978eec 100644
--- a/orator/commands/migrations/reset_command.py
+++ b/orator/commands/migrations/reset_command.py
@@ -12,29 +12,28 @@ class ResetCommand(BaseCommand):
{--d|database= : The database connection to use.}
{--p|path= : The path of migrations files to be executed.}
{--P|pretend : Dump the SQL queries that would be run.}
+ {--f|force : Force the operation to run.}
"""
def handle(self):
"""
Executes the command.
"""
- confirm = self.confirm(
- 'Are you sure you want to reset all of the migrations? ',
- False
- )
- if not confirm:
+ if not self.confirm_to_proceed(
+ "Are you sure you want to reset all of the migrations?: "
+ ):
return
- database = self.option('database')
- repository = DatabaseMigrationRepository(self.resolver, 'migrations')
+ database = self.option("database")
+ repository = DatabaseMigrationRepository(self.resolver, "migrations")
migrator = Migrator(repository, self.resolver)
self._prepare_database(migrator, database)
- pretend = bool(self.option('pretend'))
+ pretend = bool(self.option("pretend"))
- path = self.option('path')
+ path = self.option("path")
if path is None:
path = self._get_migration_path()
diff --git a/orator/commands/migrations/rollback_command.py b/orator/commands/migrations/rollback_command.py
index 47607872..2b789d65 100644
--- a/orator/commands/migrations/rollback_command.py
+++ b/orator/commands/migrations/rollback_command.py
@@ -12,29 +12,28 @@ class RollbackCommand(BaseCommand):
{--d|database= : The database connection to use.}
{--p|path= : The path of migrations files to be executed.}
{--P|pretend : Dump the SQL queries that would be run.}
+ {--f|force : Force the operation to run.}
"""
def handle(self):
"""
Executes the command.
"""
- confirm = self.confirm(
- 'Are you sure you want to rollback the last migration? ',
- True
- )
- if not confirm:
+ if not self.confirm_to_proceed(
+ "Are you sure you want to rollback the last migration?: "
+ ):
return
- database = self.option('database')
- repository = DatabaseMigrationRepository(self.resolver, 'migrations')
+ database = self.option("database")
+ repository = DatabaseMigrationRepository(self.resolver, "migrations")
migrator = Migrator(repository, self.resolver)
self._prepare_database(migrator, database)
- pretend = self.option('pretend')
+ pretend = self.option("pretend")
- path = self.option('path')
+ path = self.option("path")
if path is None:
path = self._get_migration_path()
diff --git a/orator/commands/migrations/status_command.py b/orator/commands/migrations/status_command.py
index bd74f881..7560c480 100644
--- a/orator/commands/migrations/status_command.py
+++ b/orator/commands/migrations/status_command.py
@@ -17,20 +17,20 @@ def handle(self):
"""
Executes the command.
"""
- database = self.option('database')
+ database = self.option("database")
self.resolver.set_default_connection(database)
- repository = DatabaseMigrationRepository(self.resolver, 'migrations')
+ repository = DatabaseMigrationRepository(self.resolver, "migrations")
migrator = Migrator(repository, self.resolver)
if not migrator.repository_exists():
- return self.error('No migrations found')
+ return self.error("No migrations found")
self._prepare_database(migrator, database)
- path = self.option('path')
+ path = self.option("path")
if path is None:
path = self._get_migration_path()
@@ -40,18 +40,15 @@ def handle(self):
migrations = []
for migration in migrator._get_migration_files(path):
if migration in ran:
- migrations.append(['%s>' % migration, 'Yes>'])
+ migrations.append(["%s>" % migration, "Yes>"])
else:
- migrations.append(['%s>' % migration, 'No>'])
+ migrations.append(["%s>" % migration, "No>"])
if migrations:
- table = self.table(
- ['Migration', 'Ran?'],
- migrations
- )
+ table = self.table(["Migration", "Ran?"], migrations)
table.render()
else:
- return self.error('No migrations found')
+ return self.error("No migrations found")
for note in migrator.get_notes():
self.line(note)
diff --git a/orator/commands/models/make_command.py b/orator/commands/models/make_command.py
index 1280dc31..22be9f5b 100644
--- a/orator/commands/models/make_command.py
+++ b/orator/commands/models/make_command.py
@@ -18,39 +18,39 @@ class ModelMakeCommand(Command):
"""
def handle(self):
- name = self.argument('name')
+ name = self.argument("name")
singular = inflection.singularize(inflection.tableize(name))
directory = self._get_path()
- filepath = self._get_path(singular + '.py')
+ filepath = self._get_path(singular + ".py")
if os.path.exists(filepath):
- raise RuntimeError('The model file already exists.')
+ raise RuntimeError("The model file already exists.")
mkdir_p(directory)
- parent = os.path.join(directory, '__init__.py')
+ parent = os.path.join(directory, "__init__.py")
if not os.path.exists(parent):
- with open(parent, 'w'):
+ with open(parent, "w"):
pass
stub = self._get_stub()
stub = self._populate_stub(name, stub)
- with open(filepath, 'w') as f:
+ with open(filepath, "w") as f:
f.write(stub)
- self.info('Model %s> successfully created.' % name)
+ self.info("Model %s> successfully created." % name)
- if self.option('migration'):
+ if self.option("migration"):
table = inflection.tableize(name)
self.call(
- 'make:migration',
+ "make:migration",
[
- ('name', 'create_%s_table' % table),
- ('--table', table),
- ('--create', True)
- ]
+ ("name", "create_%s_table" % table),
+ ("--table", table),
+ ("--create", True),
+ ],
)
def _get_stub(self):
@@ -73,15 +73,15 @@ def _populate_stub(self, name, stub):
:rtype: str
"""
- stub = stub.replace('DummyClass', name)
+ stub = stub.replace("DummyClass", name)
return stub
def _get_path(self, name=None):
- if self.option('path'):
- directory = self.option('path')
+ if self.option("path"):
+ directory = self.option("path")
else:
- directory = os.path.join(os.getcwd(), 'models')
+ directory = os.path.join(os.getcwd(), "models")
if name:
return os.path.join(directory, name)
diff --git a/orator/commands/seeds/base_command.py b/orator/commands/seeds/base_command.py
index f14646b4..4169466b 100644
--- a/orator/commands/seeds/base_command.py
+++ b/orator/commands/seeds/base_command.py
@@ -5,6 +5,5 @@
class BaseCommand(Command):
-
def _get_seeders_path(self):
- return os.path.join(os.getcwd(), 'seeds')
+ return os.path.join(os.getcwd(), "seeds")
diff --git a/orator/commands/seeds/make_command.py b/orator/commands/seeds/make_command.py
index a3cd6d82..a6d5cc1f 100644
--- a/orator/commands/seeds/make_command.py
+++ b/orator/commands/seeds/make_command.py
@@ -24,9 +24,9 @@ def handle(self):
Executes the command.
"""
# Making root seeder
- self._make('database_seeder', True)
+ self._make("database_seeder", True)
- self._make(self.argument('name'))
+ self._make(self.argument("name"))
def _make(self, name, root=False):
name = self._parse_name(name)
@@ -34,24 +34,24 @@ def _make(self, name, root=False):
path = self._get_path(name)
if os.path.exists(path):
if not root:
- self.error('%s already exists' % name)
+ self.error("%s already exists" % name)
return False
self._make_directory(os.path.dirname(path))
- with open(path, 'w') as fh:
+ with open(path, "w") as fh:
fh.write(self._build_class(name))
if root:
- with open(os.path.join(os.path.dirname(path), '__init__.py'), 'w'):
+ with open(os.path.join(os.path.dirname(path), "__init__.py"), "w"):
pass
- self.info('%s> created successfully.' % name)
+ self.info("%s> created successfully." % name)
def _parse_name(self, name):
- if name.endswith('.py'):
- name = name.replace('.py', '', -1)
+ if name.endswith(".py"):
+ name = name.replace(".py", "", -1)
return name
@@ -64,11 +64,11 @@ def _get_path(self, name):
:rtype: str
"""
- path = self.option('path')
+ path = self.option("path")
if path is None:
path = self._get_seeders_path()
- return os.path.join(path, '%s.py' % name)
+ return os.path.join(path, "%s.py" % name)
def _make_directory(self, path):
try:
@@ -83,7 +83,7 @@ def _build_class(self, name):
stub = self._get_stub()
klass = self._get_class_name(name)
- stub = stub.replace('DummyClass', klass)
+ stub = stub.replace("DummyClass", klass)
return stub
diff --git a/orator/commands/seeds/seed_command.py b/orator/commands/seeds/seed_command.py
index 607c4630..9e0ce6c1 100644
--- a/orator/commands/seeds/seed_command.py
+++ b/orator/commands/seeds/seed_command.py
@@ -18,34 +18,33 @@ class SeedCommand(BaseCommand):
{--p|path= : The path to seeders files.
Defaults to ./seeds.}
{--seeder=database_seeder : The name of the root seeder.}
+ {--f|force : Force the operation to run.}
"""
def handle(self):
"""
Executes the command.
"""
- confirm = self.confirm(
- 'Are you sure you want to seed the database? ',
- False
- )
- if not confirm:
+ if not self.confirm_to_proceed(
+ "Are you sure you want to seed the database?: "
+ ):
return
- self.resolver.set_default_connection(self.option('database'))
+ self.resolver.set_default_connection(self.option("database"))
self._get_seeder().run()
- self.info('Database seeded!')
+ self.info("Database seeded!")
def _get_seeder(self):
- name = self._parse_name(self.option('seeder'))
+ name = self._parse_name(self.option("seeder"))
seeder_file = self._get_path(name)
# Loading parent module
- load_module('seeds', self._get_path('__init__'))
+ load_module("seeds", self._get_path("__init__"))
# Loading module
- mod = load_module('seeds.%s' % name, seeder_file)
+ mod = load_module("seeds.%s" % name, seeder_file)
klass = getattr(mod, inflection.camelize(name))
@@ -56,8 +55,8 @@ def _get_seeder(self):
return instance
def _parse_name(self, name):
- if name.endswith('.py'):
- name = name.replace('.py', '', -1)
+ if name.endswith(".py"):
+ name = name.replace(".py", "", -1)
return name
@@ -70,8 +69,8 @@ def _get_path(self, name):
:rtype: str
"""
- path = self.option('path')
+ path = self.option("path")
if path is None:
path = self._get_seeders_path()
- return os.path.join(path, '%s.py' % name)
+ return os.path.join(path, "%s.py" % name)
diff --git a/orator/connections/connection.py b/orator/connections/connection.py
index 14cace24..469b5989 100644
--- a/orator/connections/connection.py
+++ b/orator/connections/connection.py
@@ -14,14 +14,15 @@
from ..exceptions.query import QueryException
-query_logger = logging.getLogger('orator.connection.queries')
-connection_logger = logging.getLogger('orator.connection')
+query_logger = logging.getLogger("orator.connection.queries")
+connection_logger = logging.getLogger("orator.connection")
def run(wrapped):
"""
Special decorator encapsulating query method.
"""
+
@wraps(wrapped)
def _run(self, query, bindings=None, *args, **kwargs):
self._reconnect_if_missing_connection()
@@ -46,8 +47,15 @@ class Connection(ConnectionInterface):
name = None
- def __init__(self, connection, database='', table_prefix='', config=None,
- builder_class=QueryBuilder, builder_default_kwargs=None):
+ def __init__(
+ self,
+ connection,
+ database="",
+ table_prefix="",
+ config=None,
+ builder_class=QueryBuilder,
+ builder_default_kwargs=None,
+ ):
"""
:param connection: A dbapi connection instance
:type connection: Connector
@@ -69,7 +77,7 @@ def __init__(self, connection, database='', table_prefix='', config=None,
self._database = database
if table_prefix is None:
- table_prefix = ''
+ table_prefix = ""
self._table_prefix = table_prefix
@@ -91,13 +99,13 @@ def __init__(self, connection, database='', table_prefix='', config=None,
self._builder_default_kwargs = builder_default_kwargs
- self._logging_queries = config.get('log_queries', False)
+ self._logging_queries = config.get("log_queries", False)
self._logged_queries = []
# Setting the marker based on config
self._marker = None
- if self._config.get('use_qmark'):
- self._marker = '?'
+ if self._config.get("use_qmark"):
+ self._marker = "?"
self._query_grammar = self.get_default_query_grammar()
@@ -163,7 +171,9 @@ def query(self):
:rtype: QueryBuilder
"""
query = self._builder_class(
- self, self._query_grammar, self._post_processor,
+ self,
+ self._query_grammar,
+ self._post_processor,
**self._builder_default_kwargs
)
@@ -200,6 +210,34 @@ def select(self, query, bindings=None, use_read_connection=True):
return cursor.fetchall()
+ def select_many(
+ self, size, query, bindings=None, use_read_connection=True, abort=False
+ ):
+ if self.pretending():
+ yield []
+ else:
+ bindings = self.prepare_bindings(bindings)
+ cursor = self._get_cursor_for_select(use_read_connection)
+
+ try:
+ cursor.execute(query, bindings)
+ except Exception as e:
+ if self._caused_by_lost_connection(e) and not abort:
+ self.reconnect()
+
+ for results in self.select_many(
+ size, query, bindings, use_read_connection, True
+ ):
+ yield results
+ else:
+ raise
+ else:
+ results = cursor.fetchmany(size)
+ while results:
+ yield results
+
+ results = cursor.fetchmany(size)
+
def _get_cursor_for_select(self, use_read_connection=True):
if use_read_connection:
self._cursor = self.get_read_connection().cursor()
@@ -308,7 +346,9 @@ def pretend(self):
self._pretending = False
- def _try_again_if_caused_by_lost_connection(self, e, query, bindings, callback, *args, **kwargs):
+ def _try_again_if_caused_by_lost_connection(
+ self, e, query, bindings, callback, *args, **kwargs
+ ):
if self._caused_by_lost_connection(e):
self.reconnect()
@@ -317,18 +357,28 @@ def _try_again_if_caused_by_lost_connection(self, e, query, bindings, callback,
raise QueryException(query, bindings, e)
def _caused_by_lost_connection(self, e):
- message = str(e)
-
- for s in ['server has gone away',
- 'no connection to the server',
- 'Lost Connection']:
+ message = str(e).lower()
+
+ for s in [
+ "server has gone away",
+ "no connection to the server",
+ "lost connection",
+ "is dead or not enabled",
+ "error while sending",
+ "decryption failed or bad record mac",
+ "server closed the connection unexpectedly",
+ "ssl connection has been closed unexpectedly",
+ "error writing data to the connection",
+ "connection timed out",
+ "resource deadlock avoided",
+ ]:
if s in message:
return True
return False
def disconnect(self):
- connection_logger.debug('%s is disconnecting' % self.__class__.__name__)
+ connection_logger.debug("%s is disconnecting" % self.__class__.__name__)
if self._connection:
self._connection.close()
@@ -337,14 +387,14 @@ def disconnect(self):
self.set_connection(None).set_read_connection(None)
- connection_logger.debug('%s disconnected' % self.__class__.__name__)
+ connection_logger.debug("%s disconnected" % self.__class__.__name__)
def reconnect(self):
- connection_logger.debug('%s is reconnecting' % self.__class__.__name__)
+ connection_logger.debug("%s is reconnecting" % self.__class__.__name__)
if self._reconnector is not None and callable(self._reconnector):
return self._reconnector(self)
- raise Exception('Lost connection and no reconnector available')
+ raise Exception("Lost connection and no reconnector available")
def _reconnect_if_missing_connection(self):
if self.get_connection() is None or self.get_read_connection() is None:
@@ -360,17 +410,14 @@ def log_query(self, query, bindings, time_=None):
query = self._get_cursor_query(query, bindings)
if query:
- log = 'Executed %s' % (query,)
+ log = "Executed %s" % (query,)
if time_:
- log += ' in %sms' % time_
+ log += " in %sms" % time_
- query_logger.debug(log,
- extra={
- 'query': query,
- 'bindings': bindings,
- 'elapsed_time': time_
- })
+ query_logger.debug(
+ log, extra={"query": query, "bindings": bindings, "elapsed_time": time_}
+ )
def _get_elapsed_time(self, start):
return round((time.time() - start) * 1000, 2)
@@ -398,8 +445,9 @@ def get_read_connection(self):
def set_connection(self, connection):
if self._transactions >= 1:
- raise RuntimeError("Can't swap dbapi connection"
- "while within transaction.")
+ raise RuntimeError(
+ "Can't swap dbapi connection" "while within transaction."
+ )
self._connection = connection
@@ -416,7 +464,7 @@ def set_reconnector(self, reconnector):
return self
def get_name(self):
- return self._config.get('name')
+ return self._config.get("name")
def get_config(self, option):
return self._config.get(option)
@@ -502,6 +550,7 @@ def set_builder_class(self, klass, default_kwargs=None):
return self
def __enter__(self):
+ self._reconnect_if_missing_connection()
self.begin_transaction()
return self
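
The `select_many()` generator added above streams query results in chunks via `cursor.fetchmany(size)` instead of buffering them all with `fetchall()`, and retries once if the connection was lost. A minimal usage sketch, assuming an installed driver and a reachable database; the connection config and `users` table are illustrative only:

```python
from orator import DatabaseManager

# Illustrative single-connection config (not project defaults).
db = DatabaseManager({"sqlite": {"driver": "sqlite", "database": "orator.db"}})
conn = db.connection()

# Rows arrive in chunks of 500 rather than one large fetchall() result.
for chunk in conn.select_many(500, "SELECT * FROM users WHERE active = ?", [1]):
    for row in chunk:
        print(row["name"])
```
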
diff --git a/orator/connections/connection_interface.py b/orator/connections/connection_interface.py
index bbce1d86..b0f7d30d 100644
--- a/orator/connections/connection_interface.py
+++ b/orator/connections/connection_interface.py
@@ -2,7 +2,6 @@
class ConnectionInterface(object):
-
def table(self, table):
"""
Begin a fluent query against a database table
@@ -170,5 +169,3 @@ def transaction_level(self):
def pretend(self):
raise NotImplementedError()
-
-
diff --git a/orator/connections/connection_resolver_interface.py b/orator/connections/connection_resolver_interface.py
index d9329b90..d28987f4 100644
--- a/orator/connections/connection_resolver_interface.py
+++ b/orator/connections/connection_resolver_interface.py
@@ -2,7 +2,6 @@
class ConnectionResolverInterface(object):
-
def connection(self, name=None):
raise NotImplementedError()
diff --git a/orator/connections/mysql_connection.py b/orator/connections/mysql_connection.py
index 3fe0beb7..57558a7e 100644
--- a/orator/connections/mysql_connection.py
+++ b/orator/connections/mysql_connection.py
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
+from ..utils import decode
from ..utils import PY2
from .connection import Connection
from ..query.grammars.mysql_grammar import MySQLQueryGrammar
@@ -11,7 +12,7 @@
class MySQLConnection(Connection):
- name = 'mysql'
+ name = "mysql"
def get_default_query_grammar(self):
return MySQLQueryGrammar(marker=self._marker)
@@ -37,7 +38,16 @@ def get_schema_manager(self):
return MySQLSchemaManager(self)
def begin_transaction(self):
- self._connection.autocommit(False)
+ self._reconnect_if_missing_connection()
+
+ try:
+ self._connection.autocommit(False)
+ except Exception as e:
+ if self._caused_by_lost_connection(e):
+ self.reconnect()
+ self._connection.autocommit(False)
+ else:
+ raise
super(MySQLConnection, self).begin_transaction()
@@ -58,10 +68,10 @@ def rollback(self):
self._transactions -= 1
def _get_cursor_query(self, query, bindings):
- if not hasattr(self._cursor, '_last_executed') or self._pretending:
+ if not hasattr(self._cursor, "_last_executed") or self._pretending:
return super(MySQLConnection, self)._get_cursor_query(query, bindings)
if PY2:
- return self._cursor._last_executed.decode()
+ return decode(self._cursor._last_executed)
return self._cursor._last_executed
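
Together with the `__enter__` change in `connection.py`, the hardened `begin_transaction()` above (reconnect before disabling autocommit) is likely what the 0.9.9 changelog entry about an `AttributeError` when starting a transaction refers to. A sketch of the context-manager transaction this protects; the MySQL config and table are illustrative only:

```python
from orator import DatabaseManager

# Illustrative MySQL config (see the Makefile's setup-mysql target for the
# test credentials the project itself uses).
db = DatabaseManager(
    {"mysql": {"driver": "mysql", "database": "orator_test", "user": "orator"}}
)

# __enter__ now reconnects if needed before starting the transaction;
# __exit__ commits on success and rolls back on error.
with db.connection() as tx:
    tx.table("users").insert({"name": "alice"})
```
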
diff --git a/orator/connections/postgres_connection.py b/orator/connections/postgres_connection.py
index b5e44ce2..b9617978 100644
--- a/orator/connections/postgres_connection.py
+++ b/orator/connections/postgres_connection.py
@@ -11,7 +11,7 @@
class PostgresConnection(Connection):
- name = 'pgsql'
+ name = "pgsql"
def get_default_query_grammar(self):
return PostgresQueryGrammar(marker=self._marker)
@@ -64,7 +64,7 @@ def _get_cursor_query(self, query, bindings):
return self._cursor.mogrify(query, bindings).decode()
- if not hasattr(self._cursor, 'query'):
+ if not hasattr(self._cursor, "query"):
return super(PostgresConnection, self)._get_cursor_query(query, bindings)
if PY2:
diff --git a/orator/connections/sqlite_connection.py b/orator/connections/sqlite_connection.py
index 060c6efc..4f7b4e62 100644
--- a/orator/connections/sqlite_connection.py
+++ b/orator/connections/sqlite_connection.py
@@ -10,7 +10,7 @@
class SQLiteConnection(Connection):
- name = 'sqlite'
+ name = "sqlite"
def get_default_query_grammar(self):
return self.with_table_prefix(SQLiteQueryGrammar())
@@ -25,7 +25,7 @@ def get_schema_manager(self):
return SQLiteSchemaManager(self)
def begin_transaction(self):
- self._connection.isolation_level = 'DEFERRED'
+ self._connection.isolation_level = "DEFERRED"
super(SQLiteConnection, self).begin_transaction()
diff --git a/orator/connectors/connection_factory.py b/orator/connectors/connection_factory.py
index c0fe3be8..49ee29e9 100644
--- a/orator/connectors/connection_factory.py
+++ b/orator/connectors/connection_factory.py
@@ -6,31 +6,27 @@
from .mysql_connector import MySQLConnector
from .postgres_connector import PostgresConnector
from .sqlite_connector import SQLiteConnector
-from ..connections import (
- MySQLConnection,
- PostgresConnection,
- SQLiteConnection
-)
+from ..connections import MySQLConnection, PostgresConnection, SQLiteConnection
class ConnectionFactory(object):
CONNECTORS = {
- 'sqlite': SQLiteConnector,
- 'mysql': MySQLConnector,
- 'postgres': PostgresConnector,
- 'pgsql': PostgresConnector
+ "sqlite": SQLiteConnector,
+ "mysql": MySQLConnector,
+ "postgres": PostgresConnector,
+ "pgsql": PostgresConnector,
}
CONNECTIONS = {
- 'sqlite': SQLiteConnection,
- 'mysql': MySQLConnection,
- 'postgres': PostgresConnection,
- 'pgsql': PostgresConnection
+ "sqlite": SQLiteConnection,
+ "mysql": MySQLConnection,
+ "postgres": PostgresConnection,
+ "pgsql": PostgresConnection,
}
def make(self, config, name=None):
- if 'read' in config:
+ if "read" in config:
return self._create_read_write_connection(config)
return self._create_single_connection(config)
@@ -39,11 +35,7 @@ def _create_single_connection(self, config):
conn = self.create_connector(config).connect(config)
return self._create_connection(
- config['driver'],
- conn,
- config['database'],
- config.get('prefix', ''),
- config
+ config["driver"], conn, config["database"], config.get("prefix", ""), config
)
def _create_read_write_connection(self, config):
@@ -59,12 +51,12 @@ def _create_read_connection(self, config):
return self.create_connector(read_config).connect(read_config)
def _get_read_config(self, config):
- read_config = self._get_read_write_config(config, 'read')
+ read_config = self._get_read_write_config(config, "read")
return self._merge_read_write_config(config, read_config)
def _get_write_config(self, config):
- write_config = self._get_read_write_config(config, 'write')
+ write_config = self._get_read_write_config(config, "write")
return self._merge_read_write_config(config, write_config)
@@ -78,16 +70,16 @@ def _merge_read_write_config(self, config, merge):
config = config.copy()
config.update(merge)
- del config['read']
- del config['write']
+ del config["read"]
+ del config["write"]
return config
def create_connector(self, config):
- if 'driver' not in config:
- raise ArgumentError('A driver must be specified')
+ if "driver" not in config:
+ raise ArgumentError("A driver must be specified")
- driver = config['driver']
+ driver = config["driver"]
if driver not in self.CONNECTORS:
raise UnsupportedDriver(driver)
@@ -102,7 +94,7 @@ def register_connector(cls, name, connector):
def register_connection(cls, name, connection):
cls.CONNECTIONS[name] = connection
- def _create_connection(self, driver, connection, database, prefix='', config=None):
+ def _create_connection(self, driver, connection, database, prefix="", config=None):
if config is None:
config = {}
diff --git a/orator/connectors/connector.py b/orator/connectors/connector.py
index babd3ba1..60916adb 100644
--- a/orator/connectors/connector.py
+++ b/orator/connectors/connector.py
@@ -6,9 +6,7 @@
class Connector(object):
- RESERVED_KEYWORDS = [
- 'log_queries', 'driver', 'prefix', 'name'
- ]
+ RESERVED_KEYWORDS = ["log_queries", "driver", "prefix", "name"]
SUPPORTED_PACKAGES = []
@@ -47,16 +45,16 @@ def get_params(self):
return self._params
def get_database(self):
- return self._params.get('database')
+ return self._params.get("database")
def get_host(self):
- return self._params.get('host')
+ return self._params.get("host")
def get_user(self):
- return self._params.get('user')
+ return self._params.get("user")
def get_password(self):
- return self._params.get('password')
+ return self._params.get("password")
def get_database_platform(self):
if self._platform is None:
diff --git a/orator/connectors/mysql_connector.py b/orator/connectors/mysql_connector.py
index 580158c6..858493d7 100644
--- a/orator/connectors/mysql_connector.py
+++ b/orator/connectors/mysql_connector.py
@@ -1,29 +1,32 @@
# -*- coding: utf-8 -*-
import re
-from pendulum import Pendulum
+from pendulum import Pendulum, Date
try:
import MySQLdb as mysql
# Fix for understanding Pendulum object
import MySQLdb.converters
+
MySQLdb.converters.conversions[Pendulum] = MySQLdb.converters.DateTime2literal
+ MySQLdb.converters.conversions[Date] = MySQLdb.converters.Thing2Literal
from MySQLdb.cursors import DictCursor as cursor_class
- keys_fix = {
- 'password': 'passwd',
- 'database': 'db'
- }
+
+ keys_fix = {"password": "passwd", "database": "db"}
except ImportError as e:
try:
import pymysql as mysql
# Fix for understanding Pendulum object
import pymysql.converters
+
pymysql.converters.conversions[Pendulum] = pymysql.converters.escape_datetime
+ pymysql.converters.conversions[Date] = pymysql.converters.escape_date
from pymysql.cursors import DictCursor as cursor_class
+
keys_fix = {}
except ImportError as e:
mysql = None
@@ -32,19 +35,21 @@
from ..dbal.platforms import MySQLPlatform, MySQL57Platform
from .connector import Connector
from ..utils.qmarker import qmark, denullify
+from ..utils.helpers import serialize
class Record(dict):
-
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(item)
+ def serialize(self):
+ return serialize(self)
-class BaseDictCursor(cursor_class):
+class BaseDictCursor(cursor_class):
def _fetch_row(self, size=1):
# Overridden for mysqlclient
if not self._result:
@@ -52,14 +57,13 @@ def _fetch_row(self, size=1):
rows = self._result.fetch_row(size, self._fetch_type)
return tuple(Record(r) for r in rows)
-
+
def _conv_row(self, row):
# Overridden for pymysql
return Record(super(BaseDictCursor, self)._conv_row(row))
class DictCursor(BaseDictCursor):
-
def execute(self, query, args=None):
query = qmark(query)
@@ -68,20 +72,22 @@ def execute(self, query, args=None):
def executemany(self, query, args):
query = qmark(query)
- return super(DictCursor, self).executemany(
- query, denullify(args)
- )
+ return super(DictCursor, self).executemany(query, denullify(args))
class MySQLConnector(Connector):
RESERVED_KEYWORDS = [
- 'log_queries', 'driver', 'prefix',
- 'engine', 'collation',
- 'name', 'use_qmark'
+ "log_queries",
+ "driver",
+ "prefix",
+ "engine",
+ "collation",
+ "name",
+ "use_qmark",
]
- SUPPORTED_PACKAGES = ['PyMySQL', 'mysqlclient']
+ SUPPORTED_PACKAGES = ["PyMySQL", "mysqlclient"]
def _do_connect(self, config):
config = dict(config.items())
@@ -89,19 +95,16 @@ def _do_connect(self, config):
config[value] = config[key]
del config[key]
- config['autocommit'] = True
- config['cursorclass'] = self.get_cursor_class(config)
+ config["autocommit"] = True
+ config["cursorclass"] = self.get_cursor_class(config)
return self.get_api().connect(**self.get_config(config))
def get_default_config(self):
- return {
- 'charset': 'utf8',
- 'use_unicode': True
- }
+ return {"charset": "utf8", "use_unicode": True}
def get_cursor_class(self, config):
- if config.get('use_qmark'):
+ if config.get("use_qmark"):
return DictCursor
return BaseDictCursor
@@ -112,25 +115,27 @@ def get_api(self):
def get_server_version(self):
version = self._connection.get_server_info()
- version_parts = re.match('^(?P<major>\d+)(?:\.(?P<minor>\d+)(?:\.(?P<patch>\d+))?)?', version)
+ version_parts = re.match(
+ "^(?P\d+)(?:\.(?P\d+)(?:\.(?P\d+))?)?", version
+ )
- major = int(version_parts.group('major'))
- minor = version_parts.group('minor') or 0
- patch = version_parts.group('patch') or 0
+ major = int(version_parts.group("major"))
+ minor = version_parts.group("minor") or 0
+ patch = version_parts.group("patch") or 0
minor, patch = int(minor), int(patch)
- server_version = (major, minor, patch, '')
+ server_version = (major, minor, patch, "")
- if 'mariadb' in version.lower():
- server_version = (major, minor, patch, 'mariadb')
+ if "mariadb" in version.lower():
+ server_version = (major, minor, patch, "mariadb")
return server_version
def _create_database_platform_for_version(self, version):
major, minor, _, extra = version
- if extra == 'mariadb':
+ if extra == "mariadb":
return self.get_dbal_platform()
if (major, minor) >= (5, 7):
diff --git a/orator/connectors/postgres_connector.py b/orator/connectors/postgres_connector.py
index 4f635ebb..fd28f737 100644
--- a/orator/connectors/postgres_connector.py
+++ b/orator/connectors/postgres_connector.py
@@ -1,7 +1,5 @@
# -*- coding: utf-8 -*-
-import re
-
try:
import psycopg2
import psycopg2.extras
@@ -20,34 +18,31 @@
from ..dbal.platforms import PostgresPlatform
from .connector import Connector
from ..utils.qmarker import qmark, denullify
+from ..utils.helpers import serialize
class BaseDictConnection(connection_class):
-
def cursor(self, *args, **kwargs):
- kwargs.setdefault('cursor_factory', BaseDictCursor)
+ kwargs.setdefault("cursor_factory", BaseDictCursor)
return super(BaseDictConnection, self).cursor(*args, **kwargs)
class DictConnection(BaseDictConnection):
-
def cursor(self, *args, **kwargs):
- kwargs.setdefault('cursor_factory', DictCursor)
+ kwargs.setdefault("cursor_factory", DictCursor)
return super(DictConnection, self).cursor(*args, **kwargs)
class BaseDictCursor(cursor_class):
-
def __init__(self, *args, **kwargs):
- kwargs['row_factory'] = DictRow
+ kwargs["row_factory"] = DictRow
super(cursor_class, self).__init__(*args, **kwargs)
self._prefetch = 1
class DictCursor(BaseDictCursor):
-
def execute(self, query, vars=None):
query = qmark(query)
@@ -56,27 +51,36 @@ def execute(self, query, vars=None):
def executemany(self, query, args_seq):
query = qmark(query)
- return super(DictCursor, self).executemany(
- query, denullify(args_seq))
+ return super(DictCursor, self).executemany(query, denullify(args_seq))
class DictRow(row_class):
-
def __getattr__(self, item):
try:
return self[item]
except KeyError:
raise AttributeError(item)
+ def serialize(self):
+ serialized = {}
+ for column, index in self._index.items():
+ serialized[column] = list.__getitem__(self, index)
+
+ return serialize(serialized)
+
class PostgresConnector(Connector):
RESERVED_KEYWORDS = [
- 'log_queries', 'driver', 'prefix', 'name',
- 'register_unicode', 'use_qmark'
+ "log_queries",
+ "driver",
+ "prefix",
+ "name",
+ "register_unicode",
+ "use_qmark",
]
- SUPPORTED_PACKAGES = ['psycopg2']
+ SUPPORTED_PACKAGES = ["psycopg2"]
def _do_connect(self, config):
connection = self.get_api().connect(
@@ -84,7 +88,7 @@ def _do_connect(self, config):
**self.get_config(config)
)
- if config.get('use_unicode', True):
+ if config.get("use_unicode", True):
extensions.register_type(extensions.UNICODE, connection)
extensions.register_type(extensions.UNICODEARRAY, connection)
@@ -93,7 +97,7 @@ def _do_connect(self, config):
return connection
def get_connection_class(self, config):
- if config.get('use_qmark'):
+ if config.get("use_qmark"):
return DictConnection
return BaseDictConnection
@@ -101,6 +105,14 @@ def get_connection_class(self, config):
def get_api(self):
return psycopg2
+ @property
+ def autocommit(self):
+ return self._connection.autocommit
+
+ @autocommit.setter
+ def autocommit(self, value):
+ self._connection.autocommit = value
+
def get_dbal_platform(self):
return PostgresPlatform()
@@ -113,4 +125,4 @@ def get_server_version(self):
minor = int_version // 100 % 100
fix = int_version % 10
- return major, minor, fix, ''
+ return major, minor, fix, ""
diff --git a/orator/connectors/sqlite_connector.py b/orator/connectors/sqlite_connector.py
index 9d18fb7d..2b165f55 100644
--- a/orator/connectors/sqlite_connector.py
+++ b/orator/connectors/sqlite_connector.py
@@ -1,22 +1,23 @@
# -*- coding: utf-8 -*-
-from pendulum import Pendulum
+from pendulum import Pendulum, Date
try:
import sqlite3
from sqlite3 import register_adapter
- register_adapter(Pendulum, lambda val: val.isoformat(' '))
+ register_adapter(Pendulum, lambda val: val.isoformat(" "))
+ register_adapter(Date, lambda val: val.isoformat())
except ImportError:
sqlite3 = None
from ..dbal.platforms import SQLitePlatform
+from ..utils.helpers import serialize
from .connector import Connector
-class DictCursor(object):
-
+class DictCursor(dict):
def __init__(self, cursor, row):
self.dict = {}
self.cursor = cursor
@@ -24,30 +25,27 @@ def __init__(self, cursor, row):
for idx, col in enumerate(cursor.description):
self.dict[col[0]] = row[idx]
+ super(DictCursor, self).__init__(self.dict)
+
def __getattr__(self, item):
try:
return self[item]
except KeyError:
return getattr(self.cursor, item)
- def __getitem__(self, item):
- return self.dict[item]
-
- def keys(self):
- return self.dict.keys()
-
- def values(self):
- return self.dict.values()
-
- def items(self):
- return self.dict.items()
+ def serialize(self):
+ return serialize(self)
class SQLiteConnector(Connector):
RESERVED_KEYWORDS = [
- 'log_queries', 'driver', 'prefix', 'name',
- 'foreign_keys', 'use_qmark'
+ "log_queries",
+ "driver",
+ "prefix",
+ "name",
+ "foreign_keys",
+ "use_qmark",
]
def _do_connect(self, config):
@@ -56,7 +54,7 @@ def _do_connect(self, config):
connection.row_factory = DictCursor
# We activate foreign keys support by default
- if config.get('foreign_keys', True):
+ if config.get("foreign_keys", True):
connection.execute("PRAGMA foreign_keys = ON")
return connection
@@ -64,6 +62,14 @@ def _do_connect(self, config):
def get_api(self):
return sqlite3
+ @property
+ def isolation_level(self):
+ return self._connection.isolation_level
+
+ @isolation_level.setter
+ def isolation_level(self, value):
+ self._connection.isolation_level = value
+
def get_dbal_platform(self):
return SQLitePlatform()
@@ -71,9 +77,9 @@ def is_version_aware(self):
return False
def get_server_version(self):
- sql = 'select sqlite_version() AS sqlite_version'
+ sql = "select sqlite_version() AS sqlite_version"
rows = self._connection.execute(sql).fetchall()
- version = rows[0]['sqlite_version']
+ version = rows[0]["sqlite_version"]
- return tuple(version.split('.')[:3] + [''])
+ return tuple(version.split(".")[:3] + [""])
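
The connector hunks above also give the row classes (`Record`, `DictRow`, `DictCursor`) a `serialize()` method backed by the new `orator.utils.helpers.serialize` helper, which is presumably what the query-builder serialization fix in the changelog relies on. A small sketch, with an illustrative config and table name:

```python
import json

from orator import DatabaseManager

# Illustrative values only; serialize() on the returned row is what the
# hunks above add.
db = DatabaseManager({"sqlite": {"driver": "sqlite", "database": "orator.db"}})

row = db.connection().table("users").where("id", "=", 1).first()
if row is not None:
    print(json.dumps(row.serialize()))
```
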
diff --git a/orator/database_manager.py b/orator/database_manager.py
index 2c0b8bd7..e304d755 100644
--- a/orator/database_manager.py
+++ b/orator/database_manager.py
@@ -6,11 +6,10 @@
from .connectors.connection_factory import ConnectionFactory
from .exceptions import ArgumentError
-logger = logging.getLogger('orator.database_manager')
+logger = logging.getLogger("orator.database_manager")
class BaseDatabaseManager(ConnectionResolverInterface):
-
def __init__(self, config, factory=ConnectionFactory()):
"""
:param config: The connections configuration
@@ -39,7 +38,7 @@ def connection(self, name=None):
name, type = self._parse_connection_name(name)
if name not in self._connections:
- logger.debug('Initiating connection %s' % name)
+ logger.debug("Initiating connection %s" % name)
connection = self._make_connection(name)
self._set_connection_for_type(connection, type)
@@ -61,8 +60,8 @@ def _parse_connection_name(self, name):
if name is None:
name = self.get_default_connection()
- if name.endswith(('::read', '::write')):
- return name.split('::', 1)
+ if name.endswith(("::read", "::write")):
+ return name.split("::", 1)
return name, None
@@ -87,7 +86,7 @@ def disconnect(self, name=None):
if name is None:
name = self.get_default_connection()
- logger.debug('Disconnecting %s' % name)
+ logger.debug("Disconnecting %s" % name)
if name in self._connections:
self._connections[name].disconnect()
@@ -96,7 +95,7 @@ def reconnect(self, name=None):
if name is None:
name = self.get_default_connection()
- logger.debug('Reconnecting %s' % name)
+ logger.debug("Reconnecting %s" % name)
self.disconnect(name)
@@ -106,25 +105,27 @@ def reconnect(self, name=None):
return self._refresh_api_connections(name)
def _refresh_api_connections(self, name):
- logger.debug('Refreshing api connections for %s' % name)
+ logger.debug("Refreshing api connections for %s" % name)
fresh = self._make_connection(name)
- return self._connections[name]\
- .set_connection(fresh.get_connection())\
+ return (
+ self._connections[name]
+ .set_connection(fresh.get_connection())
.set_read_connection(fresh.get_read_connection())
+ )
def _make_connection(self, name):
- logger.debug('Making connection for %s' % name)
+ logger.debug("Making connection for %s" % name)
config = self._get_config(name)
- if 'name' not in config:
- config['name'] = name
+ if "name" not in config:
+ config["name"] = name
if name in self._extensions:
return self._extensions[name](config, name)
- driver = config['driver']
+ driver = config["driver"]
if driver in self._extensions:
return self._extensions[driver](config, name)
@@ -132,7 +133,7 @@ def _make_connection(self, name):
return self._factory.make(config, name)
def _prepare(self, connection):
- logger.debug('Preparing connection %s' % connection.get_name())
+ logger.debug("Preparing connection %s" % connection.get_name())
def reconnector(connection_):
self.reconnect(connection_.get_name())
@@ -142,9 +143,9 @@ def reconnector(connection_):
return connection
def _set_connection_for_type(self, connection, type):
- if type == 'read':
+ if type == "read":
connection.set_connection(connection.get_read_api())
- elif type == 'write':
+ elif type == "write":
connection.set_read_connection(connection.get_api())
return connection
@@ -157,7 +158,7 @@ def _get_config(self, name):
config = connections.get(name)
if not config:
- raise ArgumentError('Database [%s] not configured' % name)
+ raise ArgumentError("Database [%s] not configured" % name)
return config
@@ -165,11 +166,11 @@ def get_default_connection(self):
if len(self._config) == 1:
return list(self._config.keys())[0]
- return self._config['default']
+ return self._config["default"]
def set_default_connection(self, name):
if name is not None:
- self._config['default'] = name
+ self._config["default"] = name
def extend(self, name, resolver):
self._extensions[name] = resolver
diff --git a/orator/dbal/__init__.py b/orator/dbal/__init__.py
index 633f8661..40a96afc 100644
--- a/orator/dbal/__init__.py
+++ b/orator/dbal/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/orator/dbal/abstract_asset.py b/orator/dbal/abstract_asset.py
index 2fe1f3a0..06287e1f 100644
--- a/orator/dbal/abstract_asset.py
+++ b/orator/dbal/abstract_asset.py
@@ -24,8 +24,8 @@ def _set_name(self, name):
self._quoted = True
name = self._trim_quotes(name)
- if '.' in name:
- parts = name.split('.', 1)
+ if "." in name:
+ parts = name.split(".", 1)
self._namespace = parts[0]
name = parts[1]
@@ -47,7 +47,7 @@ def get_shortest_name(self, default_namespace):
def get_full_qualified_name(self, default_namespace):
name = self.get_name()
if not self._namespace:
- name = default_namespace + '.' + name
+ name = default_namespace + "." + name
return name.lower()
@@ -55,33 +55,34 @@ def is_quoted(self):
return self._quoted
def _is_identifier_quoted(self, identifier):
- return len(identifier) > 0\
- and (identifier[0] == '`' or identifier[0] == '"' or identifier[0] == '[')
+ return len(identifier) > 0 and (
+ identifier[0] == "`" or identifier[0] == '"' or identifier[0] == "["
+ )
def _trim_quotes(self, identifier):
- return re.sub('[`"\[\]]', '', identifier)
+ return re.sub('[`"\[\]]', "", identifier)
def get_name(self):
if self._namespace:
- return self._namespace + '.' + self._name
+ return self._namespace + "." + self._name
return self._name
def get_quoted_name(self, platform):
keywords = platform.get_reserved_keywords_list()
- parts = self.get_name().split('.')
+ parts = self.get_name().split(".")
for k, v in enumerate(parts):
if self._quoted or keywords.is_keyword(v):
parts[k] = platform.quote_identifier(v)
- return '.'.join(parts)
+ return ".".join(parts)
- def _generate_identifier_name(self, columns, prefix='', max_size=30):
+ def _generate_identifier_name(self, columns, prefix="", max_size=30):
"""
Generates an identifier from a list of column names obeying a certain string length.
"""
- hash = ''
+ hash = ""
for column in columns:
- hash += '%x' % binascii.crc32(encode(str(column)))
+ hash += "%x" % binascii.crc32(encode(str(column)))
- return (prefix + '_' + hash)[:max_size]
+ return (prefix + "_" + hash)[:max_size]
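The identifier generation reformatted above builds one CRC32 hex digest per column name, concatenates the digests after a prefix, and truncates the result to max_size characters. A self-contained sketch of the same scheme; generate_identifier_name is a standalone stand-in for the method and str.encode replaces the module's encode helper:

import binascii

def generate_identifier_name(columns, prefix="", max_size=30):
    # One CRC32 hex digest per column name, concatenated after the prefix.
    digest = "".join("%x" % binascii.crc32(str(c).encode("utf-8")) for c in columns)
    # Truncate so the identifier stays within common database name limits.
    return (prefix + "_" + digest)[:max_size]

print(generate_identifier_name(["user_id", "post_id"], prefix="IDX"))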
diff --git a/orator/dbal/column.py b/orator/dbal/column.py
index ce970af6..59ce0b3b 100644
--- a/orator/dbal/column.py
+++ b/orator/dbal/column.py
@@ -5,7 +5,6 @@
class Column(AbstractAsset):
-
def __init__(self, name, type, options=None):
self._set_name(name)
self._type = type
@@ -25,7 +24,7 @@ def __init__(self, name, type, options=None):
def set_options(self, options):
for key, value in options.items():
- method = 'set_%s' % key
+ method = "set_%s" % key
if hasattr(self, method):
getattr(self, method)(value)
@@ -59,7 +58,11 @@ def set_length(self, length):
return self
def set_precision(self, precision):
- if precision is None or isinstance(precision, basestring) and not precision.isdigit():
+ if (
+ precision is None
+ or isinstance(precision, basestring)
+ and not precision.isdigit()
+ ):
precision = 10
self._precision = int(precision)
@@ -123,21 +126,19 @@ def get_default(self):
def to_dict(self):
d = {
- 'name': self._name,
- 'type': self._type,
- 'default': self._default,
- 'notnull': self._notnull,
- 'length': self._length,
- 'precision': self._precision,
- 'scale': self._scale,
- 'fixed': self._fixed,
- 'unsigned': self._unsigned,
- 'autoincrement': self._autoincrement,
- 'extra': self._extra
+ "name": self._name,
+ "type": self._type,
+ "default": self._default,
+ "notnull": self._notnull,
+ "length": self._length,
+ "precision": self._precision,
+ "scale": self._scale,
+ "fixed": self._fixed,
+ "unsigned": self._unsigned,
+ "autoincrement": self._autoincrement,
+ "extra": self._extra,
}
d.update(self._platform_options)
return d
-
-
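The set_options loop above applies column options by name: every key in the options dict is routed to a matching set_<key> method when one exists and is silently ignored otherwise. A toy illustration of that dispatch pattern; the ToyColumn class is invented for the example:

class ToyColumn(object):
    def __init__(self, **options):
        self._length = None
        self._notnull = False
        for key, value in options.items():
            method = "set_%s" % key
            # Unknown options are skipped, mirroring the hasattr() guard above.
            if hasattr(self, method):
                getattr(self, method)(value)

    def set_length(self, length):
        self._length = length

    def set_notnull(self, notnull):
        self._notnull = bool(notnull)

column = ToyColumn(length=255, notnull=True, unknown="ignored")
assert (column._length, column._notnull) == (255, True)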
diff --git a/orator/dbal/column_diff.py b/orator/dbal/column_diff.py
index a3a3a796..4ad0da67 100644
--- a/orator/dbal/column_diff.py
+++ b/orator/dbal/column_diff.py
@@ -4,8 +4,9 @@
class ColumnDiff(object):
-
- def __init__(self, old_column_name, column, changed_properties=None, from_column=None):
+ def __init__(
+ self, old_column_name, column, changed_properties=None, from_column=None
+ ):
self.old_column_name = old_column_name
self.column = column
self.changed_properties = changed_properties
diff --git a/orator/dbal/comparator.py b/orator/dbal/comparator.py
index 838dbf5a..4c889528 100644
--- a/orator/dbal/comparator.py
+++ b/orator/dbal/comparator.py
@@ -42,12 +42,16 @@ def diff_table(self, table1, table2):
continue
# See if column has changed properties in table2
- changed_properties = self.diff_column(column, table2.get_column(column_name))
+ changed_properties = self.diff_column(
+ column, table2.get_column(column_name)
+ )
if changed_properties:
- column_diff = ColumnDiff(column.get_name(),
- table2.get_column(column_name),
- changed_properties)
+ column_diff = ColumnDiff(
+ column.get_name(),
+ table2.get_column(column_name),
+ changed_properties,
+ )
column_diff.from_column = column
table_differences.changed_columns[column.get_name()] = column_diff
changes += 1
@@ -59,7 +63,9 @@ def diff_table(self, table1, table2):
# See if all the fields in table1 exist in table2
for index_name, index in table2_indexes.items():
- if (index.is_primary() and not table1.has_primary_key()) or table1.has_index(index_name):
+ if (
+ index.is_primary() and not table1.has_primary_key()
+ ) or table1.has_index(index_name):
continue
table_differences.added_indexes[index_name] = index
@@ -67,8 +73,9 @@ def diff_table(self, table1, table2):
# See if there are any removed fields in table2
for index_name, index in table1_indexes.items():
- if (index.is_primary() and not table2.has_primary_key())\
- or (not index.is_primary() and not table2.has_index(index_name)):
+ if (index.is_primary() and not table2.has_primary_key()) or (
+ not index.is_primary() and not table2.has_index(index_name)
+ ):
table_differences.removed_indexes[index_name] = index
changes += 1
continue
@@ -127,7 +134,9 @@ def detect_column_renamings(self, table_differences):
if added_column.get_name() not in rename_candidates:
rename_candidates[added_column.get_name()] = []
- rename_candidates[added_column.get_name()].append((removed_column, added_column, added_column_name))
+ rename_candidates[added_column.get_name()].append(
+ (removed_column, added_column, added_column_name)
+ )
for candidate_columns in rename_candidates.values():
if len(candidate_columns) == 1:
@@ -136,7 +145,9 @@ def detect_column_renamings(self, table_differences):
added_column_name = added_column.get_name().lower()
if removed_column_name not in table_differences.renamed_columns:
- table_differences.renamed_columns[removed_column_name] = added_column
+ table_differences.renamed_columns[
+ removed_column_name
+ ] = added_column
del table_differences.added_columns[added_column_name]
del table_differences.removed_columns[removed_column_name]
@@ -159,7 +170,9 @@ def detect_index_renamings(self, table_differences):
if added_index.get_name() not in rename_candidates:
rename_candidates[added_index.get_name()] = []
- rename_candidates[added_index.get_name()].append((removed_index, added_index, added_index_name))
+ rename_candidates[added_index.get_name()].append(
+ (removed_index, added_index, added_index_name)
+ )
for candidate_indexes in rename_candidates.values():
# If the current rename candidate contains exactly one semantically equal index,
@@ -185,19 +198,30 @@ def diff_foreign_key(self, key1, key2):
:rtype: bool
"""
- key1_unquoted_local_columns = [c.lower() for c in key1.get_unquoted_local_columns()]
- key2_unquoted_local_columns = [c.lower() for c in key2.get_unquoted_local_columns()]
+ key1_unquoted_local_columns = [
+ c.lower() for c in key1.get_unquoted_local_columns()
+ ]
+ key2_unquoted_local_columns = [
+ c.lower() for c in key2.get_unquoted_local_columns()
+ ]
if key1_unquoted_local_columns != key2_unquoted_local_columns:
return True
- key1_unquoted_foreign_columns = [c.lower() for c in key1.get_unquoted_foreign_columns()]
- key2_unquoted_foreign_columns = [c.lower() for c in key2.get_unquoted_foreign_columns()]
+ key1_unquoted_foreign_columns = [
+ c.lower() for c in key1.get_unquoted_foreign_columns()
+ ]
+ key2_unquoted_foreign_columns = [
+ c.lower() for c in key2.get_unquoted_foreign_columns()
+ ]
if key1_unquoted_foreign_columns != key2_unquoted_foreign_columns:
return True
- if key1.get_unqualified_foreign_table_name() != key2.get_unqualified_foreign_table_name():
+ if (
+ key1.get_unqualified_foreign_table_name()
+ != key2.get_unqualified_foreign_table_name()
+ ):
return True
if key1.on_update() != key2.on_update():
@@ -222,34 +246,39 @@ def diff_column(self, column1, column2):
changed_properties = []
- for prop in ['type', 'notnull', 'unsigned', 'autoincrement']:
+ for prop in ["type", "notnull", "unsigned", "autoincrement"]:
if properties1[prop] != properties2[prop]:
changed_properties.append(prop)
- if properties1['default'] != properties2['default']\
- or (properties1['default'] is None and properties2['default'] is not None)\
- or (properties2['default'] is None and properties1['default'] is not None):
- changed_properties.append('default')
-
- if properties1['type'] == 'string' and properties1['type'] != 'guid'\
- or properties1['type'] in ['binary', 'blob']:
- length1 = properties1['length'] or 255
- length2 = properties2['length'] or 255
+ if (
+ properties1["default"] != properties2["default"]
+ or (properties1["default"] is None and properties2["default"] is not None)
+ or (properties2["default"] is None and properties1["default"] is not None)
+ ):
+ changed_properties.append("default")
+
+ if (
+ properties1["type"] == "string"
+ and properties1["type"] != "guid"
+ or properties1["type"] in ["binary", "blob"]
+ ):
+ length1 = properties1["length"] or 255
+ length2 = properties2["length"] or 255
if length1 != length2:
- changed_properties.append('length')
+ changed_properties.append("length")
- if properties1['fixed'] != properties2['fixed']:
- changed_properties.append('fixed')
- elif properties1['type'] in ['decimal', 'float', 'double precision']:
- precision1 = properties1['precision'] or 10
- precision2 = properties2['precision'] or 10
+ if properties1["fixed"] != properties2["fixed"]:
+ changed_properties.append("fixed")
+ elif properties1["type"] in ["decimal", "float", "double precision"]:
+ precision1 = properties1["precision"] or 10
+ precision2 = properties2["precision"] or 10
if precision1 != precision2:
- changed_properties.append('precision')
+ changed_properties.append("precision")
- if properties1['scale'] != properties2['scale']:
- changed_properties.append('scale')
+ if properties1["scale"] != properties2["scale"]:
+ changed_properties.append("scale")
return list(set(changed_properties))
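diff_column above boils two column definitions down to the list of property names that differ, and diff_table turns any non-empty list into a ColumnDiff. A simplified dictionary-based sketch of the comparison; plain dicts stand in for Column.to_dict(), diff_column_props is a made-up helper, and only the string/length branch is shown:

def diff_column_props(props1, props2):
    changed = []
    for prop in ["type", "notnull", "unsigned", "autoincrement"]:
        if props1[prop] != props2[prop]:
            changed.append(prop)
    if props1["type"] == "string":
        # A missing length falls back to the VARCHAR default of 255, as above.
        if (props1.get("length") or 255) != (props2.get("length") or 255):
            changed.append("length")
    return changed

old = {"type": "string", "notnull": True, "unsigned": False,
       "autoincrement": False, "length": 100}
new = dict(old, length=255, notnull=False)
assert sorted(diff_column_props(old, new)) == ["length", "notnull"]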
diff --git a/orator/dbal/exceptions/__init__.py b/orator/dbal/exceptions/__init__.py
index 191a411d..6a0d8270 100644
--- a/orator/dbal/exceptions/__init__.py
+++ b/orator/dbal/exceptions/__init__.py
@@ -7,7 +7,6 @@ class DBALException(Exception):
class InvalidPlatformSpecified(DBALException):
-
def __init__(self, index_name, table_name):
message = 'Invalid "platform" option specified, need to give an instance of dbal.platforms.Platform'
@@ -20,7 +19,6 @@ class SchemaException(DBALException):
class IndexDoesNotExist(SchemaException):
-
def __init__(self, index_name, table_name):
message = 'Index "%s" does not exist on table "%s".' % (index_name, table_name)
@@ -28,15 +26,16 @@ def __init__(self, index_name, table_name):
class IndexAlreadyExists(SchemaException):
-
def __init__(self, index_name, table_name):
- message = 'An index with name "%s" already exists on table "%s".' % (index_name, table_name)
+ message = 'An index with name "%s" already exists on table "%s".' % (
+ index_name,
+ table_name,
+ )
super(IndexAlreadyExists, self).__init__(message)
class IndexNameInvalid(SchemaException):
-
def __init__(self, index_name):
message = 'Invalid index name "%s" given, has to be [a-zA-Z0-9_]' % index_name
@@ -44,7 +43,6 @@ def __init__(self, index_name):
class ColumnDoesNotExist(SchemaException):
-
def __init__(self, column, table_name):
message = 'Column "%s" does not exist on table "%s".' % (column, table_name)
@@ -52,16 +50,20 @@ def __init__(self, column, table_name):
class ColumnAlreadyExists(SchemaException):
-
def __init__(self, column, table_name):
- message = 'An column with name "%s" already exists on table "%s".' % (column, table_name)
+        message = 'A column with name "%s" already exists on table "%s".' % (
+ column,
+ table_name,
+ )
super(ColumnAlreadyExists, self).__init__(message)
class ForeignKeyDoesNotExist(SchemaException):
-
def __init__(self, constraint, table_name):
- message = 'Foreign key "%s" does not exist on table "%s".' % (constraint, table_name)
+ message = 'Foreign key "%s" does not exist on table "%s".' % (
+ constraint,
+ table_name,
+ )
super(ForeignKeyDoesNotExist, self).__init__(message)
diff --git a/orator/dbal/foreign_key_constraint.py b/orator/dbal/foreign_key_constraint.py
index 192c99bb..52987bbd 100644
--- a/orator/dbal/foreign_key_constraint.py
+++ b/orator/dbal/foreign_key_constraint.py
@@ -10,9 +10,14 @@ class ForeignKeyConstraint(AbstractAsset):
An abstraction class for a foreign key constraint.
"""
- def __init__(self, local_column_names,
- foreign_table_name, foreign_column_names,
- name=None, options=None):
+ def __init__(
+ self,
+ local_column_names,
+ foreign_table_name,
+ foreign_column_names,
+ name=None,
+ options=None,
+ ):
"""
Constructor.
@@ -152,7 +157,7 @@ def get_unqualified_foreign_table_name(self):
:rtype: str
"""
- parts = self.get_foreign_table_name().split('.')
+ parts = self.get_foreign_table_name().split(".")
return parts[-1].lower()
@@ -227,7 +232,7 @@ def on_update(self):
:rtype: str or None
"""
- return self._on_event('on_update')
+ return self._on_event("on_update")
def on_delete(self):
"""
@@ -236,7 +241,7 @@ def on_delete(self):
:rtype: str or None
"""
- return self._on_event('on_delete')
+ return self._on_event("on_delete")
def _on_event(self, event):
"""
@@ -251,7 +256,7 @@ def _on_event(self, event):
if self.has_option(event):
on_event = self.get_option(event).upper()
- if on_event not in ['NO ACTION', 'RESTRICT']:
+ if on_event not in ["NO ACTION", "RESTRICT"]:
return on_event
return False
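The on_update()/on_delete() accessors above normalize the referential action to uppercase and report False for the implicit defaults NO ACTION and RESTRICT, so only meaningful actions end up in generated SQL. A standalone sketch of that rule; on_event is a stand-in for the method and the options are passed as a plain dict for illustration:

def on_event(options, event):
    if event in options:
        action = options[event].upper()
        # The database applies NO ACTION / RESTRICT anyway, so omit them.
        if action not in ["NO ACTION", "RESTRICT"]:
            return action
    return False

assert on_event({"on_delete": "cascade"}, "on_delete") == "CASCADE"
assert on_event({"on_delete": "restrict"}, "on_delete") is False
assert on_event({}, "on_update") is False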
diff --git a/orator/dbal/identifier.py b/orator/dbal/identifier.py
index 3fc779a9..0ffbfcda 100644
--- a/orator/dbal/identifier.py
+++ b/orator/dbal/identifier.py
@@ -4,6 +4,5 @@
class Identifier(AbstractAsset):
-
def __init__(self, identifier):
self._set_name(identifier)
diff --git a/orator/dbal/index.py b/orator/dbal/index.py
index b3616d91..2dbb6690 100644
--- a/orator/dbal/index.py
+++ b/orator/dbal/index.py
@@ -10,7 +10,9 @@ class Index(AbstractAsset):
An abstraction class for an index.
"""
- def __init__(self, name, columns, is_unique=False, is_primary=False, flags=None, options=None):
+ def __init__(
+ self, name, columns, is_unique=False, is_primary=False, flags=None, options=None
+ ):
"""
Constructor.
@@ -124,7 +126,9 @@ def spans_columns(self, column_names):
for i in range(number_of_columns):
column = self._trim_quotes(columns[i].lower())
- if i >= len(column_names) or column != self._trim_quotes(column_names[i].lower()):
+ if i >= len(column_names) or column != self._trim_quotes(
+ column_names[i].lower()
+ ):
same_columns = False
return same_columns
@@ -176,11 +180,14 @@ def same_partial_index(self, other):
:rtype: bool
"""
- if (self.has_option('where') and other.has_option('where')
- and self.get_option('where') == other.get_option('where')):
+ if (
+ self.has_option("where")
+ and other.has_option("where")
+ and self.get_option("where") == other.get_option("where")
+ ):
return True
- if not self.has_option('where') and not other.has_option('where'):
+ if not self.has_option("where") and not other.has_option("where"):
return True
return False
@@ -201,7 +208,11 @@ def overrules(self, other):
return False
same_columns = self.spans_columns(other.get_columns())
- if same_columns and (self.is_primary() or self.is_unique()) and self.same_partial_index(other):
+ if (
+ same_columns
+ and (self.is_primary() or self.is_unique())
+ and self.same_partial_index(other)
+ ):
return True
return False
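spans_columns above compares positionally, ignoring case and identifier quotes: the index covers a column list when each of its own columns matches the entry at the same position. A simplified sketch with standalone trim_quotes and spans_columns helpers written for the example; quote trimming is reduced to the same character-class regex:

import re

def trim_quotes(identifier):
    return re.sub(r'[`"\[\]]', "", identifier)

def spans_columns(index_columns, column_names):
    for i, column in enumerate(index_columns):
        if i >= len(column_names):
            return False
        if trim_quotes(column.lower()) != trim_quotes(column_names[i].lower()):
            return False
    return True

assert spans_columns(["`user_id`", "post_id"], ["USER_ID", "post_id", "extra"])
assert not spans_columns(["user_id", "post_id"], ["post_id"])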
diff --git a/orator/dbal/mysql_schema_manager.py b/orator/dbal/mysql_schema_manager.py
index d9486df9..d72a5bd6 100644
--- a/orator/dbal/mysql_schema_manager.py
+++ b/orator/dbal/mysql_schema_manager.py
@@ -9,25 +9,24 @@
class MySQLSchemaManager(SchemaManager):
-
def _get_portable_table_column_definition(self, table_column):
- db_type = table_column['type'].lower()
- type_match = re.match('(.+)\((.*)\).*', db_type)
+ db_type = table_column["type"].lower()
+ type_match = re.match("(.+)\((.*)\).*", db_type)
if type_match:
db_type = type_match.group(1)
- if 'length' in table_column:
- length = table_column['length']
+ if "length" in table_column:
+ length = table_column["length"]
else:
- if type_match and type_match.group(2) and ',' not in type_match.group(2):
+ if type_match and type_match.group(2) and "," not in type_match.group(2):
length = int(type_match.group(2))
else:
length = 0
fixed = None
- if 'name' not in table_column:
- table_column['name'] = ''
+ if "name" not in table_column:
+ table_column["name"] = ""
precision = None
scale = None
@@ -35,55 +34,55 @@ def _get_portable_table_column_definition(self, table_column):
type = self._platform.get_type_mapping(db_type)
- if db_type in ['char', 'binary']:
+ if db_type in ["char", "binary"]:
fixed = True
- elif db_type in ['float', 'double', 'real', 'decimal', 'numeric']:
- match = re.match('([A-Za-z]+\(([0-9]+),([0-9]+)\))', table_column['type'])
+ elif db_type in ["float", "double", "real", "decimal", "numeric"]:
+ match = re.match("([A-Za-z]+\(([0-9]+),([0-9]+)\))", table_column["type"])
if match:
precision = match.group(1)
scale = match.group(2)
length = None
- elif db_type == 'tinytext':
+ elif db_type == "tinytext":
length = MySQLPlatform.LENGTH_LIMIT_TINYTEXT
- elif db_type == 'text':
+ elif db_type == "text":
length = MySQLPlatform.LENGTH_LIMIT_TEXT
- elif db_type == 'mediumtext':
+ elif db_type == "mediumtext":
length = MySQLPlatform.LENGTH_LIMIT_MEDIUMTEXT
- elif db_type == 'tinyblob':
+ elif db_type == "tinyblob":
length = MySQLPlatform.LENGTH_LIMIT_TINYBLOB
- elif db_type == 'blob':
+ elif db_type == "blob":
length = MySQLPlatform.LENGTH_LIMIT_BLOB
- elif db_type == 'mediumblob':
+ elif db_type == "mediumblob":
length = MySQLPlatform.LENGTH_LIMIT_MEDIUMBLOB
- elif db_type in ['tinyint', 'smallint', 'mediumint', 'int', 'bigint', 'year']:
+ elif db_type in ["tinyint", "smallint", "mediumint", "int", "bigint", "year"]:
length = None
- elif db_type == 'enum':
+ elif db_type == "enum":
length = None
- extra['definition'] = '({})'.format(type_match.group(2))
+ extra["definition"] = "({})".format(type_match.group(2))
if length is None or length == 0:
length = None
options = {
- 'length': length,
- 'unsigned': table_column['type'].find('unsigned') != -1,
- 'fixed': fixed,
- 'notnull': table_column['null'] != 'YES',
- 'default': table_column.get('default'),
- 'precision': None,
- 'scale': None,
- 'autoincrement': table_column['extra'].find('auto_increment') != -1,
- 'extra': extra,
+ "length": length,
+ "unsigned": table_column["type"].find("unsigned") != -1,
+ "fixed": fixed,
+ "notnull": table_column["null"] != "YES",
+ "default": table_column.get("default"),
+ "precision": None,
+ "scale": None,
+ "autoincrement": table_column["extra"].find("auto_increment") != -1,
+ "extra": extra,
}
if scale is not None and precision is not None:
- options['scale'] = scale
- options['precision'] = precision
+ options["scale"] = scale
+ options["precision"] = precision
- column = Column(table_column['field'], type, options)
+ column = Column(table_column["field"], type, options)
- if 'collation' in table_column:
- column.set_platform_option('collation', table_column['collation'])
+ if "collation" in table_column:
+ column.set_platform_option("collation", table_column["collation"])
return column
@@ -91,55 +90,61 @@ def _get_portable_table_indexes_list(self, table_indexes, table_name):
new = []
for v in table_indexes:
v = dict((k.lower(), value) for k, value in v.items())
- if v['key_name'] == 'PRIMARY':
- v['primary'] = True
+ if v["key_name"] == "PRIMARY":
+ v["primary"] = True
else:
- v['primary'] = False
+ v["primary"] = False
- if 'FULLTEXT' in v['index_type']:
- v['flags'] = {'FULLTEXT': True}
+ if "FULLTEXT" in v["index_type"]:
+ v["flags"] = {"FULLTEXT": True}
else:
- v['flags'] = {'SPATIAL': True}
-
+ v["flags"] = {"SPATIAL": True}
+
new.append(v)
-
- return super(MySQLSchemaManager, self)._get_portable_table_indexes_list(new, table_name)
+
+ return super(MySQLSchemaManager, self)._get_portable_table_indexes_list(
+ new, table_name
+ )
def _get_portable_table_foreign_keys_list(self, table_foreign_keys):
foreign_keys = OrderedDict()
for value in table_foreign_keys:
value = dict((k.lower(), v) for k, v in value.items())
- name = value.get('constraint_name', '')
+ name = value.get("constraint_name", "")
if name not in foreign_keys:
- if 'delete_rule' not in value or value['delete_rule'] == 'RESTRICT':
- value['delete_rule'] = ''
+ if "delete_rule" not in value or value["delete_rule"] == "RESTRICT":
+ value["delete_rule"] = ""
- if 'update_rule' not in value or value['update_rule'] == 'RESTRICT':
- value['update_rule'] = ''
+ if "update_rule" not in value or value["update_rule"] == "RESTRICT":
+ value["update_rule"] = ""
foreign_keys[name] = {
- 'name': name,
- 'local': [],
- 'foreign': [],
- 'foreign_table': value['referenced_table_name'],
- 'on_delete': value['delete_rule'],
- 'on_update': value['update_rule']
+ "name": name,
+ "local": [],
+ "foreign": [],
+ "foreign_table": value["referenced_table_name"],
+ "on_delete": value["delete_rule"],
+ "on_update": value["update_rule"],
}
- foreign_keys[name]['local'].append(value['column_name'])
- foreign_keys[name]['foreign'].append(value['referenced_column_name'])
+ foreign_keys[name]["local"].append(value["column_name"])
+ foreign_keys[name]["foreign"].append(value["referenced_column_name"])
result = []
for constraint in foreign_keys.values():
- result.append(ForeignKeyConstraint(
- constraint['local'], constraint['foreign_table'],
- constraint['foreign'], constraint['name'],
- {
- 'on_delete': constraint['on_delete'],
- 'on_update': constraint['on_update']
- }
- ))
+ result.append(
+ ForeignKeyConstraint(
+ constraint["local"],
+ constraint["foreign_table"],
+ constraint["foreign"],
+ constraint["name"],
+ {
+ "on_delete": constraint["on_delete"],
+ "on_update": constraint["on_update"],
+ },
+ )
+ )
return result
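The column definition builder above splits a raw MySQL type such as varchar(255) or int(10) unsigned into a base type, an optional length, and an unsigned flag before mapping it to a portable type. A minimal sketch of just that parsing step; parse_mysql_type is a made-up helper for illustration:

import re

def parse_mysql_type(raw_type):
    db_type = raw_type.lower()
    match = re.match(r"(.+)\((.*)\).*", db_type)
    length = None
    if match:
        db_type = match.group(1)
        # enum(...) and decimal(p,s) carry commas and are handled separately.
        if "," not in match.group(2):
            length = int(match.group(2))
    return db_type, length, "unsigned" in raw_type.lower()

assert parse_mysql_type("varchar(255)") == ("varchar", 255, False)
assert parse_mysql_type("int(10) unsigned") == ("int", 10, True)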
diff --git a/orator/dbal/platforms/keywords/__init__.py b/orator/dbal/platforms/keywords/__init__.py
index 633f8661..40a96afc 100644
--- a/orator/dbal/platforms/keywords/__init__.py
+++ b/orator/dbal/platforms/keywords/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/orator/dbal/platforms/keywords/mysql_keywords.py b/orator/dbal/platforms/keywords/mysql_keywords.py
index 390f86c5..1814afa8 100644
--- a/orator/dbal/platforms/keywords/mysql_keywords.py
+++ b/orator/dbal/platforms/keywords/mysql_keywords.py
@@ -6,232 +6,232 @@
class MySQLKeywords(KeywordList):
KEYWORDS = [
- 'ADD',
- 'ALL',
- 'ALTER',
- 'ANALYZE',
- 'AND',
- 'AS',
- 'ASC',
- 'ASENSITIVE',
- 'BEFORE',
- 'BETWEEN',
- 'BIGINT',
- 'BINARY',
- 'BLOB',
- 'BOTH',
- 'BY',
- 'CALL',
- 'CASCADE',
- 'CASE',
- 'CHANGE',
- 'CHAR',
- 'CHARACTER',
- 'CHECK',
- 'COLLATE',
- 'COLUMN',
- 'CONDITION',
- 'CONNECTION',
- 'CONSTRAINT',
- 'CONTINUE',
- 'CONVERT',
- 'CREATE',
- 'CROSS',
- 'CURRENT_DATE',
- 'CURRENT_TIME',
- 'CURRENT_TIMESTAMP',
- 'CURRENT_USER',
- 'CURSOR',
- 'DATABASE',
- 'DATABASES',
- 'DAY_HOUR',
- 'DAY_MICROSECOND',
- 'DAY_MINUTE',
- 'DAY_SECOND',
- 'DEC',
- 'DECIMAL',
- 'DECLARE',
- 'DEFAULT',
- 'DELAYED',
- 'DELETE',
- 'DESC',
- 'DESCRIBE',
- 'DETERMINISTIC',
- 'DISTINCT',
- 'DISTINCTROW',
- 'DIV',
- 'DOUBLE',
- 'DROP',
- 'DUAL',
- 'EACH',
- 'ELSE',
- 'ELSEIF',
- 'ENCLOSED',
- 'ESCAPED',
- 'EXISTS',
- 'EXIT',
- 'EXPLAIN',
- 'FALSE',
- 'FETCH',
- 'FLOAT',
- 'FLOAT4',
- 'FLOAT8',
- 'FOR',
- 'FORCE',
- 'FOREIGN',
- 'FROM',
- 'FULLTEXT',
- 'GOTO',
- 'GRANT',
- 'GROUP',
- 'HAVING',
- 'HIGH_PRIORITY',
- 'HOUR_MICROSECOND',
- 'HOUR_MINUTE',
- 'HOUR_SECOND',
- 'IF',
- 'IGNORE',
- 'IN',
- 'INDEX',
- 'INFILE',
- 'INNER',
- 'INOUT',
- 'INSENSITIVE',
- 'INSERT',
- 'INT',
- 'INT1',
- 'INT2',
- 'INT3',
- 'INT4',
- 'INT8',
- 'INTEGER',
- 'INTERVAL',
- 'INTO',
- 'IS',
- 'ITERATE',
- 'JOIN',
- 'KEY',
- 'KEYS',
- 'KILL',
- 'LABEL',
- 'LEADING',
- 'LEAVE',
- 'LEFT',
- 'LIKE',
- 'LIMIT',
- 'LINES',
- 'LOAD',
- 'LOCALTIME',
- 'LOCALTIMESTAMP',
- 'LOCK',
- 'LONG',
- 'LONGBLOB',
- 'LONGTEXT',
- 'LOOP',
- 'LOW_PRIORITY',
- 'MATCH',
- 'MEDIUMBLOB',
- 'MEDIUMINT',
- 'MEDIUMTEXT',
- 'MIDDLEINT',
- 'MINUTE_MICROSECOND',
- 'MINUTE_SECOND',
- 'MOD',
- 'MODIFIES',
- 'NATURAL',
- 'NOT',
- 'NO_WRITE_TO_BINLOG',
- 'NULL',
- 'NUMERIC',
- 'ON',
- 'OPTIMIZE',
- 'OPTION',
- 'OPTIONALLY',
- 'OR',
- 'ORDER',
- 'OUT',
- 'OUTER',
- 'OUTFILE',
- 'PRECISION',
- 'PRIMARY',
- 'PROCEDURE',
- 'PURGE',
- 'RAID0',
- 'RANGE',
- 'READ',
- 'READS',
- 'REAL',
- 'REFERENCES',
- 'REGEXP',
- 'RELEASE',
- 'RENAME',
- 'REPEAT',
- 'REPLACE',
- 'REQUIRE',
- 'RESTRICT',
- 'RETURN',
- 'REVOKE',
- 'RIGHT',
- 'RLIKE',
- 'SCHEMA',
- 'SCHEMAS',
- 'SECOND_MICROSECOND',
- 'SELECT',
- 'SENSITIVE',
- 'SEPARATOR',
- 'SET',
- 'SHOW',
- 'SMALLINT',
- 'SONAME',
- 'SPATIAL',
- 'SPECIFIC',
- 'SQL',
- 'SQLEXCEPTION',
- 'SQLSTATE',
- 'SQLWARNING',
- 'SQL_BIG_RESULT',
- 'SQL_CALC_FOUND_ROWS',
- 'SQL_SMALL_RESULT',
- 'SSL',
- 'STARTING',
- 'STRAIGHT_JOIN',
- 'TABLE',
- 'TERMINATED',
- 'THEN',
- 'TINYBLOB',
- 'TINYINT',
- 'TINYTEXT',
- 'TO',
- 'TRAILING',
- 'TRIGGER',
- 'TRUE',
- 'UNDO',
- 'UNION',
- 'UNIQUE',
- 'UNLOCK',
- 'UNSIGNED',
- 'UPDATE',
- 'USAGE',
- 'USE',
- 'USING',
- 'UTC_DATE',
- 'UTC_TIME',
- 'UTC_TIMESTAMP',
- 'VALUES',
- 'VARBINARY',
- 'VARCHAR',
- 'VARCHARACTER',
- 'VARYING',
- 'WHEN',
- 'WHERE',
- 'WHILE',
- 'WITH',
- 'WRITE',
- 'X509',
- 'XOR',
- 'YEAR_MONTH',
- 'ZEROFILL'
+ "ADD",
+ "ALL",
+ "ALTER",
+ "ANALYZE",
+ "AND",
+ "AS",
+ "ASC",
+ "ASENSITIVE",
+ "BEFORE",
+ "BETWEEN",
+ "BIGINT",
+ "BINARY",
+ "BLOB",
+ "BOTH",
+ "BY",
+ "CALL",
+ "CASCADE",
+ "CASE",
+ "CHANGE",
+ "CHAR",
+ "CHARACTER",
+ "CHECK",
+ "COLLATE",
+ "COLUMN",
+ "CONDITION",
+ "CONNECTION",
+ "CONSTRAINT",
+ "CONTINUE",
+ "CONVERT",
+ "CREATE",
+ "CROSS",
+ "CURRENT_DATE",
+ "CURRENT_TIME",
+ "CURRENT_TIMESTAMP",
+ "CURRENT_USER",
+ "CURSOR",
+ "DATABASE",
+ "DATABASES",
+ "DAY_HOUR",
+ "DAY_MICROSECOND",
+ "DAY_MINUTE",
+ "DAY_SECOND",
+ "DEC",
+ "DECIMAL",
+ "DECLARE",
+ "DEFAULT",
+ "DELAYED",
+ "DELETE",
+ "DESC",
+ "DESCRIBE",
+ "DETERMINISTIC",
+ "DISTINCT",
+ "DISTINCTROW",
+ "DIV",
+ "DOUBLE",
+ "DROP",
+ "DUAL",
+ "EACH",
+ "ELSE",
+ "ELSEIF",
+ "ENCLOSED",
+ "ESCAPED",
+ "EXISTS",
+ "EXIT",
+ "EXPLAIN",
+ "FALSE",
+ "FETCH",
+ "FLOAT",
+ "FLOAT4",
+ "FLOAT8",
+ "FOR",
+ "FORCE",
+ "FOREIGN",
+ "FROM",
+ "FULLTEXT",
+ "GOTO",
+ "GRANT",
+ "GROUP",
+ "HAVING",
+ "HIGH_PRIORITY",
+ "HOUR_MICROSECOND",
+ "HOUR_MINUTE",
+ "HOUR_SECOND",
+ "IF",
+ "IGNORE",
+ "IN",
+ "INDEX",
+ "INFILE",
+ "INNER",
+ "INOUT",
+ "INSENSITIVE",
+ "INSERT",
+ "INT",
+ "INT1",
+ "INT2",
+ "INT3",
+ "INT4",
+ "INT8",
+ "INTEGER",
+ "INTERVAL",
+ "INTO",
+ "IS",
+ "ITERATE",
+ "JOIN",
+ "KEY",
+ "KEYS",
+ "KILL",
+ "LABEL",
+ "LEADING",
+ "LEAVE",
+ "LEFT",
+ "LIKE",
+ "LIMIT",
+ "LINES",
+ "LOAD",
+ "LOCALTIME",
+ "LOCALTIMESTAMP",
+ "LOCK",
+ "LONG",
+ "LONGBLOB",
+ "LONGTEXT",
+ "LOOP",
+ "LOW_PRIORITY",
+ "MATCH",
+ "MEDIUMBLOB",
+ "MEDIUMINT",
+ "MEDIUMTEXT",
+ "MIDDLEINT",
+ "MINUTE_MICROSECOND",
+ "MINUTE_SECOND",
+ "MOD",
+ "MODIFIES",
+ "NATURAL",
+ "NOT",
+ "NO_WRITE_TO_BINLOG",
+ "NULL",
+ "NUMERIC",
+ "ON",
+ "OPTIMIZE",
+ "OPTION",
+ "OPTIONALLY",
+ "OR",
+ "ORDER",
+ "OUT",
+ "OUTER",
+ "OUTFILE",
+ "PRECISION",
+ "PRIMARY",
+ "PROCEDURE",
+ "PURGE",
+ "RAID0",
+ "RANGE",
+ "READ",
+ "READS",
+ "REAL",
+ "REFERENCES",
+ "REGEXP",
+ "RELEASE",
+ "RENAME",
+ "REPEAT",
+ "REPLACE",
+ "REQUIRE",
+ "RESTRICT",
+ "RETURN",
+ "REVOKE",
+ "RIGHT",
+ "RLIKE",
+ "SCHEMA",
+ "SCHEMAS",
+ "SECOND_MICROSECOND",
+ "SELECT",
+ "SENSITIVE",
+ "SEPARATOR",
+ "SET",
+ "SHOW",
+ "SMALLINT",
+ "SONAME",
+ "SPATIAL",
+ "SPECIFIC",
+ "SQL",
+ "SQLEXCEPTION",
+ "SQLSTATE",
+ "SQLWARNING",
+ "SQL_BIG_RESULT",
+ "SQL_CALC_FOUND_ROWS",
+ "SQL_SMALL_RESULT",
+ "SSL",
+ "STARTING",
+ "STRAIGHT_JOIN",
+ "TABLE",
+ "TERMINATED",
+ "THEN",
+ "TINYBLOB",
+ "TINYINT",
+ "TINYTEXT",
+ "TO",
+ "TRAILING",
+ "TRIGGER",
+ "TRUE",
+ "UNDO",
+ "UNION",
+ "UNIQUE",
+ "UNLOCK",
+ "UNSIGNED",
+ "UPDATE",
+ "USAGE",
+ "USE",
+ "USING",
+ "UTC_DATE",
+ "UTC_TIME",
+ "UTC_TIMESTAMP",
+ "VALUES",
+ "VARBINARY",
+ "VARCHAR",
+ "VARCHARACTER",
+ "VARYING",
+ "WHEN",
+ "WHERE",
+ "WHILE",
+ "WITH",
+ "WRITE",
+ "X509",
+ "XOR",
+ "YEAR_MONTH",
+ "ZEROFILL",
]
def get_name(self):
- return 'MySQL'
+ return "MySQL"
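These reserved-word lists feed AbstractAsset.get_quoted_name(): a name part gets quoted only when the asset was declared quoted or the part is a reserved keyword on the target platform. A tiny sketch of that check; MYSQL_KEYWORDS here is a small stand-in set, and the case-insensitive membership test is an assumption about how KeywordList.is_keyword() behaves:

MYSQL_KEYWORDS = {"SELECT", "ORDER", "KEY"}

def quoted_name(name, quote_char="`"):
    parts = []
    for part in name.split("."):
        # Quote reserved words so they survive as identifiers in SQL.
        if part.upper() in MYSQL_KEYWORDS:
            part = quote_char + part + quote_char
        parts.append(part)
    return ".".join(parts)

assert quoted_name("db.order") == "db.`order`"
assert quoted_name("db.users") == "db.users"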
diff --git a/orator/dbal/platforms/keywords/postgresql_keywords.py b/orator/dbal/platforms/keywords/postgresql_keywords.py
index afd7b1fa..3bf1d6b0 100644
--- a/orator/dbal/platforms/keywords/postgresql_keywords.py
+++ b/orator/dbal/platforms/keywords/postgresql_keywords.py
@@ -6,94 +6,94 @@
class PostgreSQLKeywords(KeywordList):
KEYWORDS = [
- 'ALL',
- 'ANALYSE',
- 'ANALYZE',
- 'AND',
- 'ANY',
- 'AS',
- 'ASC',
- 'AUTHORIZATION',
- 'BETWEEN',
- 'BINARY',
- 'BOTH',
- 'CASE',
- 'CAST',
- 'CHECK',
- 'COLLATE',
- 'COLUMN',
- 'CONSTRAINT',
- 'CREATE',
- 'CURRENT_DATE',
- 'CURRENT_TIME',
- 'CURRENT_TIMESTAMP',
- 'CURRENT_USER',
- 'DEFAULT',
- 'DEFERRABLE',
- 'DESC',
- 'DISTINCT',
- 'DO',
- 'ELSE',
- 'END',
- 'EXCEPT',
- 'FALSE',
- 'FOR',
- 'FOREIGN',
- 'FREEZE',
- 'FROM',
- 'FULL',
- 'GRANT',
- 'GROUP',
- 'HAVING',
- 'ILIKE',
- 'IN',
- 'INITIALLY',
- 'INNER',
- 'INTERSECT',
- 'INTO',
- 'IS',
- 'ISNULL',
- 'JOIN',
- 'LEADING',
- 'LEFT',
- 'LIKE',
- 'LIMIT',
- 'LOCALTIME',
- 'LOCALTIMESTAMP',
- 'NATURAL',
- 'NEW',
- 'NOT',
- 'NOTNULL',
- 'NULL',
- 'OFF',
- 'OFFSET',
- 'OLD',
- 'ON',
- 'ONLY',
- 'OR',
- 'ORDER',
- 'OUTER',
- 'OVERLAPS',
- 'PLACING',
- 'PRIMARY',
- 'REFERENCES',
- 'SELECT',
- 'SESSION_USER',
- 'SIMILAR',
- 'SOME',
- 'TABLE',
- 'THEN',
- 'TO',
- 'TRAILING',
- 'TRUE',
- 'UNION',
- 'UNIQUE',
- 'USER',
- 'USING',
- 'VERBOSE',
- 'WHEN',
- 'WHERE'
+ "ALL",
+ "ANALYSE",
+ "ANALYZE",
+ "AND",
+ "ANY",
+ "AS",
+ "ASC",
+ "AUTHORIZATION",
+ "BETWEEN",
+ "BINARY",
+ "BOTH",
+ "CASE",
+ "CAST",
+ "CHECK",
+ "COLLATE",
+ "COLUMN",
+ "CONSTRAINT",
+ "CREATE",
+ "CURRENT_DATE",
+ "CURRENT_TIME",
+ "CURRENT_TIMESTAMP",
+ "CURRENT_USER",
+ "DEFAULT",
+ "DEFERRABLE",
+ "DESC",
+ "DISTINCT",
+ "DO",
+ "ELSE",
+ "END",
+ "EXCEPT",
+ "FALSE",
+ "FOR",
+ "FOREIGN",
+ "FREEZE",
+ "FROM",
+ "FULL",
+ "GRANT",
+ "GROUP",
+ "HAVING",
+ "ILIKE",
+ "IN",
+ "INITIALLY",
+ "INNER",
+ "INTERSECT",
+ "INTO",
+ "IS",
+ "ISNULL",
+ "JOIN",
+ "LEADING",
+ "LEFT",
+ "LIKE",
+ "LIMIT",
+ "LOCALTIME",
+ "LOCALTIMESTAMP",
+ "NATURAL",
+ "NEW",
+ "NOT",
+ "NOTNULL",
+ "NULL",
+ "OFF",
+ "OFFSET",
+ "OLD",
+ "ON",
+ "ONLY",
+ "OR",
+ "ORDER",
+ "OUTER",
+ "OVERLAPS",
+ "PLACING",
+ "PRIMARY",
+ "REFERENCES",
+ "SELECT",
+ "SESSION_USER",
+ "SIMILAR",
+ "SOME",
+ "TABLE",
+ "THEN",
+ "TO",
+ "TRAILING",
+ "TRUE",
+ "UNION",
+ "UNIQUE",
+ "USER",
+ "USING",
+ "VERBOSE",
+ "WHEN",
+ "WHERE",
]
def get_name(self):
- return 'PostgreSQL'
+ return "PostgreSQL"
diff --git a/orator/dbal/platforms/keywords/sqlite_keywords.py b/orator/dbal/platforms/keywords/sqlite_keywords.py
index bb62f587..390896c0 100644
--- a/orator/dbal/platforms/keywords/sqlite_keywords.py
+++ b/orator/dbal/platforms/keywords/sqlite_keywords.py
@@ -6,128 +6,128 @@
class SQLiteKeywords(KeywordList):
KEYWORDS = [
- 'ABORT',
- 'ACTION',
- 'ADD',
- 'AFTER',
- 'ALL',
- 'ALTER',
- 'ANALYZE',
- 'AND',
- 'AS',
- 'ASC',
- 'ATTACH',
- 'AUTOINCREMENT',
- 'BEFORE',
- 'BEGIN',
- 'BETWEEN',
- 'BY',
- 'CASCADE',
- 'CASE',
- 'CAST',
- 'CHECK',
- 'COLLATE',
- 'COLUMN',
- 'COMMIT',
- 'CONFLICT',
- 'CONSTRAINT',
- 'CREATE',
- 'CROSS',
- 'CURRENT_DATE',
- 'CURRENT_TIME',
- 'CURRENT_TIMESTAMP',
- 'DATABASE',
- 'DEFAULT',
- 'DEFERRABLE',
- 'DEFERRED',
- 'DELETE',
- 'DESC',
- 'DETACH',
- 'DISTINCT',
- 'DROP',
- 'EACH',
- 'ELSE',
- 'END',
- 'ESCAPE',
- 'EXCEPT',
- 'EXCLUSIVE',
- 'EXISTS',
- 'EXPLAIN',
- 'FAIL',
- 'FOR',
- 'FOREIGN',
- 'FROM',
- 'FULL',
- 'GLOB',
- 'GROUP',
- 'HAVING',
- 'IF',
- 'IGNORE',
- 'IMMEDIATE',
- 'IN',
- 'INDEX',
- 'INDEXED',
- 'INITIALLY',
- 'INNER',
- 'INSERT',
- 'INSTEAD',
- 'INTERSECT',
- 'INTO',
- 'IS',
- 'ISNULL',
- 'JOIN',
- 'KEY',
- 'LEFT',
- 'LIKE',
- 'LIMIT',
- 'MATCH',
- 'NATURAL',
- 'NO',
- 'NOT',
- 'NOTNULL',
- 'NULL',
- 'OF',
- 'OFFSET',
- 'ON',
- 'OR',
- 'ORDER',
- 'OUTER',
- 'PLAN',
- 'PRAGMA',
- 'PRIMARY',
- 'QUERY',
- 'RAISE',
- 'REFERENCES',
- 'REGEXP',
- 'REINDEX',
- 'RELEASE',
- 'RENAME',
- 'REPLACE',
- 'RESTRICT',
- 'RIGHT',
- 'ROLLBACK',
- 'ROW',
- 'SAVEPOINT',
- 'SELECT',
- 'SET',
- 'TABLE',
- 'TEMP',
- 'TEMPORARY',
- 'THEN',
- 'TO',
- 'TRANSACTION',
- 'TRIGGER',
- 'UNION',
- 'UNIQUE',
- 'UPDATE',
- 'USING',
- 'VACUUM',
- 'VALUES',
- 'VIEW',
- 'VIRTUAL',
- 'WHEN',
- 'WHERE'
+ "ABORT",
+ "ACTION",
+ "ADD",
+ "AFTER",
+ "ALL",
+ "ALTER",
+ "ANALYZE",
+ "AND",
+ "AS",
+ "ASC",
+ "ATTACH",
+ "AUTOINCREMENT",
+ "BEFORE",
+ "BEGIN",
+ "BETWEEN",
+ "BY",
+ "CASCADE",
+ "CASE",
+ "CAST",
+ "CHECK",
+ "COLLATE",
+ "COLUMN",
+ "COMMIT",
+ "CONFLICT",
+ "CONSTRAINT",
+ "CREATE",
+ "CROSS",
+ "CURRENT_DATE",
+ "CURRENT_TIME",
+ "CURRENT_TIMESTAMP",
+ "DATABASE",
+ "DEFAULT",
+ "DEFERRABLE",
+ "DEFERRED",
+ "DELETE",
+ "DESC",
+ "DETACH",
+ "DISTINCT",
+ "DROP",
+ "EACH",
+ "ELSE",
+ "END",
+ "ESCAPE",
+ "EXCEPT",
+ "EXCLUSIVE",
+ "EXISTS",
+ "EXPLAIN",
+ "FAIL",
+ "FOR",
+ "FOREIGN",
+ "FROM",
+ "FULL",
+ "GLOB",
+ "GROUP",
+ "HAVING",
+ "IF",
+ "IGNORE",
+ "IMMEDIATE",
+ "IN",
+ "INDEX",
+ "INDEXED",
+ "INITIALLY",
+ "INNER",
+ "INSERT",
+ "INSTEAD",
+ "INTERSECT",
+ "INTO",
+ "IS",
+ "ISNULL",
+ "JOIN",
+ "KEY",
+ "LEFT",
+ "LIKE",
+ "LIMIT",
+ "MATCH",
+ "NATURAL",
+ "NO",
+ "NOT",
+ "NOTNULL",
+ "NULL",
+ "OF",
+ "OFFSET",
+ "ON",
+ "OR",
+ "ORDER",
+ "OUTER",
+ "PLAN",
+ "PRAGMA",
+ "PRIMARY",
+ "QUERY",
+ "RAISE",
+ "REFERENCES",
+ "REGEXP",
+ "REINDEX",
+ "RELEASE",
+ "RENAME",
+ "REPLACE",
+ "RESTRICT",
+ "RIGHT",
+ "ROLLBACK",
+ "ROW",
+ "SAVEPOINT",
+ "SELECT",
+ "SET",
+ "TABLE",
+ "TEMP",
+ "TEMPORARY",
+ "THEN",
+ "TO",
+ "TRANSACTION",
+ "TRIGGER",
+ "UNION",
+ "UNIQUE",
+ "UPDATE",
+ "USING",
+ "VACUUM",
+ "VALUES",
+ "VIEW",
+ "VIRTUAL",
+ "WHEN",
+ "WHERE",
]
def get_name(self):
- return 'SQLite'
+ return "SQLite"
diff --git a/orator/dbal/platforms/mysql57_platform.py b/orator/dbal/platforms/mysql57_platform.py
index 3e0d8a10..f10638fd 100644
--- a/orator/dbal/platforms/mysql57_platform.py
+++ b/orator/dbal/platforms/mysql57_platform.py
@@ -6,45 +6,45 @@
class MySQL57Platform(MySQLPlatform):
INTERNAL_TYPE_MAPPING = {
- 'tinyint': 'boolean',
- 'smallint': 'smallint',
- 'mediumint': 'integer',
- 'int': 'integer',
- 'integer': 'integer',
- 'bigint': 'bigint',
- 'int8': 'bigint',
- 'bool': 'boolean',
- 'boolean': 'boolean',
- 'tinytext': 'text',
- 'mediumtext': 'text',
- 'longtext': 'text',
- 'text': 'text',
- 'varchar': 'string',
- 'string': 'string',
- 'char': 'string',
- 'date': 'date',
- 'datetime': 'datetime',
- 'timestamp': 'datetime',
- 'time': 'time',
- 'float': 'float',
- 'double': 'float',
- 'real': 'float',
- 'decimal': 'decimal',
- 'numeric': 'decimal',
- 'year': 'date',
- 'longblob': 'blob',
- 'blob': 'blob',
- 'mediumblob': 'blob',
- 'tinyblob': 'blob',
- 'binary': 'binary',
- 'varbinary': 'binary',
- 'set': 'simple_array',
- 'enum': 'enum',
- 'json': 'json',
+ "tinyint": "boolean",
+ "smallint": "smallint",
+ "mediumint": "integer",
+ "int": "integer",
+ "integer": "integer",
+ "bigint": "bigint",
+ "int8": "bigint",
+ "bool": "boolean",
+ "boolean": "boolean",
+ "tinytext": "text",
+ "mediumtext": "text",
+ "longtext": "text",
+ "text": "text",
+ "varchar": "string",
+ "string": "string",
+ "char": "string",
+ "date": "date",
+ "datetime": "datetime",
+ "timestamp": "datetime",
+ "time": "time",
+ "float": "float",
+ "double": "float",
+ "real": "float",
+ "decimal": "decimal",
+ "numeric": "decimal",
+ "year": "date",
+ "longblob": "blob",
+ "blob": "blob",
+ "mediumblob": "blob",
+ "tinyblob": "blob",
+ "binary": "binary",
+ "varbinary": "binary",
+ "set": "simple_array",
+ "enum": "enum",
+ "json": "json",
}
def get_json_type_declaration_sql(self, column):
- return 'JSON'
+ return "JSON"
def has_native_json_type(self):
return True
diff --git a/orator/dbal/platforms/mysql_platform.py b/orator/dbal/platforms/mysql_platform.py
index 8de03921..75a763ff 100644
--- a/orator/dbal/platforms/mysql_platform.py
+++ b/orator/dbal/platforms/mysql_platform.py
@@ -16,53 +16,55 @@ class MySQLPlatform(Platform):
LENGTH_LIMIT_MEDIUMBLOB = 16777215
INTERNAL_TYPE_MAPPING = {
- 'tinyint': 'boolean',
- 'smallint': 'smallint',
- 'mediumint': 'integer',
- 'int': 'integer',
- 'integer': 'integer',
- 'bigint': 'bigint',
- 'int8': 'bigint',
- 'bool': 'boolean',
- 'boolean': 'boolean',
- 'tinytext': 'text',
- 'mediumtext': 'text',
- 'longtext': 'text',
- 'text': 'text',
- 'varchar': 'string',
- 'string': 'string',
- 'char': 'string',
- 'date': 'date',
- 'datetime': 'datetime',
- 'timestamp': 'datetime',
- 'time': 'time',
- 'float': 'float',
- 'double': 'float',
- 'real': 'float',
- 'decimal': 'decimal',
- 'numeric': 'decimal',
- 'year': 'date',
- 'longblob': 'blob',
- 'blob': 'blob',
- 'mediumblob': 'blob',
- 'tinyblob': 'blob',
- 'binary': 'binary',
- 'varbinary': 'binary',
- 'set': 'simple_array',
- 'enum': 'enum',
+ "tinyint": "boolean",
+ "smallint": "smallint",
+ "mediumint": "integer",
+ "int": "integer",
+ "integer": "integer",
+ "bigint": "bigint",
+ "int8": "bigint",
+ "bool": "boolean",
+ "boolean": "boolean",
+ "tinytext": "text",
+ "mediumtext": "text",
+ "longtext": "text",
+ "text": "text",
+ "varchar": "string",
+ "string": "string",
+ "char": "string",
+ "date": "date",
+ "datetime": "datetime",
+ "timestamp": "datetime",
+ "time": "time",
+ "float": "float",
+ "double": "float",
+ "real": "float",
+ "decimal": "decimal",
+ "numeric": "decimal",
+ "year": "date",
+ "longblob": "blob",
+ "blob": "blob",
+ "mediumblob": "blob",
+ "tinyblob": "blob",
+ "binary": "binary",
+ "varbinary": "binary",
+ "set": "simple_array",
+ "enum": "enum",
}
def get_list_table_columns_sql(self, table, database=None):
if database:
database = "'%s'" % database
else:
- database = 'DATABASE()'
+ database = "DATABASE()"
- return 'SELECT COLUMN_NAME AS field, COLUMN_TYPE AS type, IS_NULLABLE AS `null`, ' \
- 'COLUMN_KEY AS `key`, COLUMN_DEFAULT AS `default`, EXTRA AS extra, COLUMN_COMMENT AS comment, ' \
- 'CHARACTER_SET_NAME AS character_set, COLLATION_NAME AS collation ' \
- 'FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = %s AND TABLE_NAME = \'%s\''\
- % (database, table)
+ return (
+ "SELECT COLUMN_NAME AS field, COLUMN_TYPE AS type, IS_NULLABLE AS `null`, "
+ "COLUMN_KEY AS `key`, COLUMN_DEFAULT AS `default`, EXTRA AS extra, COLUMN_COMMENT AS comment, "
+ "CHARACTER_SET_NAME AS character_set, COLLATION_NAME AS collation "
+ "FROM information_schema.COLUMNS WHERE TABLE_SCHEMA = %s AND TABLE_NAME = '%s'"
+ % (database, table)
+ )
def get_list_table_indexes_sql(self, table, current_database=None):
sql = """
@@ -74,21 +76,25 @@ def get_list_table_indexes_sql(self, table, current_database=None):
"""
if current_database:
- sql += ' AND TABLE_SCHEMA = \'%s\'' % current_database
+ sql += " AND TABLE_SCHEMA = '%s'" % current_database
return sql % table
def get_list_table_foreign_keys_sql(self, table, database=None):
- sql = ("SELECT DISTINCT k.`CONSTRAINT_NAME`, k.`COLUMN_NAME`, k.`REFERENCED_TABLE_NAME`, "
- "k.`REFERENCED_COLUMN_NAME` /*!50116 , c.update_rule, c.delete_rule */ "
- "FROM information_schema.key_column_usage k /*!50116 "
- "INNER JOIN information_schema.referential_constraints c ON "
- " c.constraint_name = k.constraint_name AND "
- " c.table_name = '%s' */ WHERE k.table_name = '%s'" % (table, table))
+ sql = (
+ "SELECT DISTINCT k.`CONSTRAINT_NAME`, k.`COLUMN_NAME`, k.`REFERENCED_TABLE_NAME`, "
+ "k.`REFERENCED_COLUMN_NAME` /*!50116 , c.update_rule, c.delete_rule */ "
+ "FROM information_schema.key_column_usage k /*!50116 "
+ "INNER JOIN information_schema.referential_constraints c ON "
+ " c.constraint_name = k.constraint_name AND "
+ " c.table_name = '%s' */ WHERE k.table_name = '%s'" % (table, table)
+ )
if database:
- sql += " AND k.table_schema = '%s' /*!50116 AND c.constraint_schema = '%s' */"\
- % (database, database)
+ sql += (
+ " AND k.table_schema = '%s' /*!50116 AND c.constraint_schema = '%s' */"
+ % (database, database)
+ )
sql += " AND k.`REFERENCED_COLUMN_NAME` IS NOT NULL"
@@ -107,7 +113,9 @@ def get_alter_table_sql(self, diff):
query_parts = []
if diff.new_name is not False:
- query_parts.append('RENAME TO %s' % diff.get_new_name().get_quoted_name(self))
+ query_parts.append(
+ "RENAME TO %s" % diff.get_new_name().get_quoted_name(self)
+ )
# Added columns?
@@ -118,30 +126,43 @@ def get_alter_table_sql(self, diff):
column_dict = column.to_dict()
# Don't propagate default value changes for unsupported column types.
- if column_diff.has_changed('default') \
- and len(column_diff.changed_properties) == 1 \
- and (column_dict['type'] == 'text' or column_dict['type'] == 'blob'):
+ if (
+ column_diff.has_changed("default")
+ and len(column_diff.changed_properties) == 1
+ and (column_dict["type"] == "text" or column_dict["type"] == "blob")
+ ):
continue
- query_parts.append('CHANGE %s %s'
- % (column_diff.get_old_column_name().get_quoted_name(self),
- self.get_column_declaration_sql(column.get_quoted_name(self), column_dict)))
+ query_parts.append(
+ "CHANGE %s %s"
+ % (
+ column_diff.get_old_column_name().get_quoted_name(self),
+ self.get_column_declaration_sql(
+ column.get_quoted_name(self), column_dict
+ ),
+ )
+ )
for old_column_name, column in diff.renamed_columns.items():
column_dict = column.to_dict()
old_column_name = Identifier(old_column_name)
- query_parts.append('CHANGE %s %s'
- % (self.quote(old_column_name.get_quoted_name(self)),
- self.get_column_declaration_sql(
- self.quote(column.get_quoted_name(self)),
- column_dict)))
+ query_parts.append(
+ "CHANGE %s %s"
+ % (
+ self.quote(old_column_name.get_quoted_name(self)),
+ self.get_column_declaration_sql(
+ self.quote(column.get_quoted_name(self)), column_dict
+ ),
+ )
+ )
sql = []
if len(query_parts) > 0:
- sql.append('ALTER TABLE %s %s'
- % (diff.get_name(self).get_quoted_name(self),
- ', '.join(query_parts)))
+ sql.append(
+ "ALTER TABLE %s %s"
+ % (diff.get_name(self).get_quoted_name(self), ", ".join(query_parts))
+ )
return sql
@@ -156,85 +177,85 @@ def convert_booleans(self, item):
return item
def get_boolean_type_declaration_sql(self, column):
- return 'TINYINT(1)'
+ return "TINYINT(1)"
def get_integer_type_declaration_sql(self, column):
- return 'INT ' + self._get_common_integer_type_declaration_sql(column)
+ return "INT " + self._get_common_integer_type_declaration_sql(column)
def get_bigint_type_declaration_sql(self, column):
- return 'BIGINT ' + self._get_common_integer_type_declaration_sql(column)
+ return "BIGINT " + self._get_common_integer_type_declaration_sql(column)
def get_smallint_type_declaration_sql(self, column):
- return 'SMALLINT ' + self._get_common_integer_type_declaration_sql(column)
+ return "SMALLINT " + self._get_common_integer_type_declaration_sql(column)
def get_guid_type_declaration_sql(self, column):
- return 'UUID'
+ return "UUID"
def get_datetime_type_declaration_sql(self, column):
- if 'version' in column and column['version'] == True:
- return 'TIMESTAMP'
+ if "version" in column and column["version"] == True:
+ return "TIMESTAMP"
- return 'DATETIME'
+ return "DATETIME"
def get_date_type_declaration_sql(self, column):
- return 'DATE'
+ return "DATE"
def get_time_type_declaration_sql(self, column):
- return 'TIME'
+ return "TIME"
def get_varchar_type_declaration_sql_snippet(self, length, fixed):
if fixed:
- return 'CHAR(%s)' % length if length else 'CHAR(255)'
+ return "CHAR(%s)" % length if length else "CHAR(255)"
else:
- return 'VARCHAR(%s)' % length if length else 'VARCHAR(255)'
+ return "VARCHAR(%s)" % length if length else "VARCHAR(255)"
def get_binary_type_declaration_sql_snippet(self, length, fixed):
if fixed:
- return 'BINARY(%s)' % (length or 255)
+ return "BINARY(%s)" % (length or 255)
else:
- return 'VARBINARY(%s)' % (length or 255)
+ return "VARBINARY(%s)" % (length or 255)
def get_text_type_declaration_sql(self, column):
- length = column.get('length')
+ length = column.get("length")
if length:
if length <= self.LENGTH_LIMIT_TINYTEXT:
- return 'TINYTEXT'
+ return "TINYTEXT"
if length <= self.LENGTH_LIMIT_TEXT:
- return 'TEXT'
+ return "TEXT"
if length <= self.LENGTH_LIMIT_MEDIUMTEXT:
- return 'MEDIUMTEXT'
+ return "MEDIUMTEXT"
- return 'LONGTEXT'
+ return "LONGTEXT"
def get_blob_type_declaration_sql(self, column):
- length = column.get('length')
+ length = column.get("length")
if length:
if length <= self.LENGTH_LIMIT_TINYBLOB:
- return 'TINYBLOB'
+ return "TINYBLOB"
if length <= self.LENGTH_LIMIT_BLOB:
- return 'BLOB'
+ return "BLOB"
if length <= self.LENGTH_LIMIT_MEDIUMBLOB:
- return 'MEDIUMBLOB'
+ return "MEDIUMBLOB"
- return 'LONGBLOB'
+ return "LONGBLOB"
def get_clob_type_declaration_sql(self, column):
- length = column.get('length')
+ length = column.get("length")
if length:
if length <= self.LENGTH_LIMIT_TINYTEXT:
- return 'TINYTEXT'
+ return "TINYTEXT"
if length <= self.LENGTH_LIMIT_TEXT:
- return 'TEXT'
+ return "TEXT"
if length <= self.LENGTH_LIMIT_MEDIUMTEXT:
- return 'MEDIUMTEXT'
+ return "MEDIUMTEXT"
- return 'LONGTEXT'
+ return "LONGTEXT"
def get_decimal_type_declaration_sql(self, column):
decl = super(MySQLPlatform, self).get_decimal_type_declaration_sql(column)
@@ -242,23 +263,23 @@ def get_decimal_type_declaration_sql(self, column):
return decl + self.get_unsigned_declaration(column)
def get_unsigned_declaration(self, column):
- if column.get('unsigned'):
- return ' UNSIGNED'
+ if column.get("unsigned"):
+ return " UNSIGNED"
- return ''
+ return ""
def _get_common_integer_type_declaration_sql(self, column):
- autoinc = ''
- if column.get('autoincrement'):
- autoinc = ' AUTO_INCREMENT'
+ autoinc = ""
+ if column.get("autoincrement"):
+ autoinc = " AUTO_INCREMENT"
return self.get_unsigned_declaration(column) + autoinc
def get_float_type_declaration_sql(self, column):
- return 'DOUBLE PRECISION' + self.get_unsigned_declaration(column)
+ return "DOUBLE PRECISION" + self.get_unsigned_declaration(column)
def get_enum_type_declaration_sql(self, column):
- return 'ENUM{}'.format(column['extra']['definition'])
+ return "ENUM{}".format(column["extra"]["definition"])
def supports_foreign_key_constraints(self):
return True
@@ -267,10 +288,10 @@ def supports_column_collation(self):
return False
def quote(self, name):
- return '`%s`' % name.replace('`', '``')
+ return "`%s`" % name.replace("`", "``")
def _get_reserved_keywords_class(self):
return MySQLKeywords
def get_identifier_quote_character(self):
- return '`'
+ return "`"
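The text and blob declarations above pick the narrowest MySQL type whose limit still holds the requested length, falling back to the LONG variant when no length is given or the length exceeds the MEDIUM limit. A compact sketch of that tiering; the numeric limits are the standard MySQL ones the class constants are expected to carry:

LENGTH_LIMIT_TINYTEXT = 255
LENGTH_LIMIT_TEXT = 65535
LENGTH_LIMIT_MEDIUMTEXT = 16777215

def text_type_for(length):
    # No length (or 0) means "as large as needed": LONGTEXT.
    if length:
        if length <= LENGTH_LIMIT_TINYTEXT:
            return "TINYTEXT"
        if length <= LENGTH_LIMIT_TEXT:
            return "TEXT"
        if length <= LENGTH_LIMIT_MEDIUMTEXT:
            return "MEDIUMTEXT"
    return "LONGTEXT"

assert text_type_for(200) == "TINYTEXT"
assert text_type_for(70000) == "MEDIUMTEXT"
assert text_type_for(None) == "LONGTEXT"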
diff --git a/orator/dbal/platforms/platform.py b/orator/dbal/platforms/platform.py
index 438c13a4..9814a69d 100644
--- a/orator/dbal/platforms/platform.py
+++ b/orator/dbal/platforms/platform.py
@@ -22,30 +22,39 @@ def __init__(self, version=None):
self._version = None
def get_default_value_declaration_sql(self, field):
- default = ''
-
- if not field.get('notnull'):
- default = ' DEFAULT NULL'
-
- if 'default' in field and field['default'] is not None:
- default = ' DEFAULT \'%s\'' % field['default']
-
- if 'type' in field:
- type = field['type']
-
- if type in ['integer', 'bigint', 'smallint']:
- default = ' DEFAULT %s' % field['default']
- elif type in ['datetime', 'datetimetz'] \
- and field['default'] in [self.get_current_timestamp_sql(), 'NOW', 'now']:
- default = ' DEFAULT %s' % self.get_current_timestamp_sql()
- elif type in ['time'] \
- and field['default'] in [self.get_current_time_sql(), 'NOW', 'now']:
- default = ' DEFAULT %s' % self.get_current_time_sql()
- elif type in ['date'] \
- and field['default'] in [self.get_current_date_sql(), 'NOW', 'now']:
- default = ' DEFAULT %s' % self.get_current_date_sql()
- elif type in ['boolean']:
- default = ' DEFAULT \'%s\'' % self.convert_booleans(field['default'])
+ default = ""
+
+ if not field.get("notnull"):
+ default = " DEFAULT NULL"
+
+ if "default" in field and field["default"] is not None:
+ default = " DEFAULT '%s'" % field["default"]
+
+ if "type" in field:
+ type = field["type"]
+
+ if type in ["integer", "bigint", "smallint"]:
+ default = " DEFAULT %s" % field["default"]
+ elif type in ["datetime", "datetimetz"] and field["default"] in [
+ self.get_current_timestamp_sql(),
+ "NOW",
+ "now",
+ ]:
+ default = " DEFAULT %s" % self.get_current_timestamp_sql()
+ elif type in ["time"] and field["default"] in [
+ self.get_current_time_sql(),
+ "NOW",
+ "now",
+ ]:
+ default = " DEFAULT %s" % self.get_current_time_sql()
+ elif type in ["date"] and field["default"] in [
+ self.get_current_date_sql(),
+ "NOW",
+ "now",
+ ]:
+ default = " DEFAULT %s" % self.get_current_date_sql()
+ elif type in ["boolean"]:
+ default = " DEFAULT '%s'" % self.convert_booleans(field["default"])
return default
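The default-value builder reformatted in the hunk above emits " DEFAULT ..." clauses whose quoting depends on the column type: integers go in bare, NOW-style defaults on date/time columns map to the platform's CURRENT_* function, and everything else is single-quoted. A reduced sketch of that dispatch; default_clause is a standalone stand-in and only the integer, datetime, and fallback branches are shown:

def default_clause(field):
    default = ""
    if not field.get("notnull"):
        default = " DEFAULT NULL"
    if field.get("default") is not None:
        if field.get("type") in ["integer", "bigint", "smallint"]:
            default = " DEFAULT %s" % field["default"]
        elif field.get("type") in ["datetime", "datetimetz"] and field["default"] in ["NOW", "now"]:
            default = " DEFAULT CURRENT_TIMESTAMP"
        else:
            default = " DEFAULT '%s'" % field["default"]
    return default

assert default_clause({"type": "integer", "default": 0, "notnull": True}) == " DEFAULT 0"
assert default_clause({"type": "datetime", "default": "NOW"}) == " DEFAULT CURRENT_TIMESTAMP"
assert default_clause({"type": "string", "default": "a", "notnull": True}) == " DEFAULT 'a'"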
@@ -73,15 +82,15 @@ def get_check_declaration_sql(self, definition):
constraints = []
for field, def_ in definition.items():
if isinstance(def_, basestring):
- constraints.append('CHECK (%s)' % def_)
+ constraints.append("CHECK (%s)" % def_)
else:
- if 'min' in def_:
- constraints.append('CHECK (%s >= %s)' % (field, def_['min']))
+ if "min" in def_:
+ constraints.append("CHECK (%s >= %s)" % (field, def_["min"]))
- if 'max' in def_:
- constraints.append('CHECK (%s <= %s)' % (field, def_['max']))
+ if "max" in def_:
+ constraints.append("CHECK (%s <= %s)" % (field, def_["max"]))
- return ', '.join(constraints)
+ return ", ".join(constraints)
def get_unique_constraint_declaration_sql(self, name, index):
"""
@@ -103,10 +112,11 @@ def get_unique_constraint_declaration_sql(self, name, index):
if not columns:
raise DBALException('Incomplete definition. "columns" required.')
- return 'CONSTRAINT %s UNIQUE (%s)%s'\
- % (name.get_quoted_name(self),
- self.get_index_field_declaration_list_sql(columns),
- self.get_partial_index_sql(index))
+ return "CONSTRAINT %s UNIQUE (%s)%s" % (
+ name.get_quoted_name(self),
+ self.get_index_field_declaration_list_sql(columns),
+ self.get_partial_index_sql(index),
+ )
def get_index_declaration_sql(self, name, index):
"""
@@ -128,11 +138,12 @@ def get_index_declaration_sql(self, name, index):
if not columns:
raise DBALException('Incomplete definition. "columns" required.')
- return '%sINDEX %s (%s)%s'\
- % (self.get_create_index_sql_flags(index),
- name.get_quoted_name(self),
- self.get_index_field_declaration_list_sql(columns),
- self.get_partial_index_sql(index))
+ return "%sINDEX %s (%s)%s" % (
+ self.get_create_index_sql_flags(index),
+ name.get_quoted_name(self),
+ self.get_index_field_declaration_list_sql(columns),
+ self.get_partial_index_sql(index),
+ )
def get_foreign_key_declaration_sql(self, foreign_key):
"""
@@ -159,12 +170,18 @@ def get_advanced_foreign_key_options_sql(self, foreign_key):
:rtype: str
"""
- query = ''
- if self.supports_foreign_key_on_update() and foreign_key.has_option('on_update'):
- query += ' ON UPDATE %s' % self.get_foreign_key_referential_action_sql(foreign_key.get_option('on_update'))
+ query = ""
+ if self.supports_foreign_key_on_update() and foreign_key.has_option(
+ "on_update"
+ ):
+ query += " ON UPDATE %s" % self.get_foreign_key_referential_action_sql(
+ foreign_key.get_option("on_update")
+ )
- if foreign_key.has_option('on_delete'):
- query += ' ON DELETE %s' % self.get_foreign_key_referential_action_sql(foreign_key.get_option('on_delete'))
+ if foreign_key.has_option("on_delete"):
+ query += " ON DELETE %s" % self.get_foreign_key_referential_action_sql(
+ foreign_key.get_option("on_delete")
+ )
return query
@@ -178,8 +195,14 @@ def get_foreign_key_referential_action_sql(self, action):
:rtype: str
"""
action = action.upper()
- if action not in ['CASCADE', 'SET NULL', 'NO ACTION', 'RESTRICT', 'SET DEFAULT']:
- raise DBALException('Invalid foreign key action: %s' % action)
+ if action not in [
+ "CASCADE",
+ "SET NULL",
+ "NO ACTION",
+ "RESTRICT",
+ "SET DEFAULT",
+ ]:
+ raise DBALException("Invalid foreign key action: %s" % action)
return action
@@ -193,11 +216,11 @@ def get_foreign_key_base_declaration_sql(self, foreign_key):
:rtype: str
"""
- sql = ''
+ sql = ""
if foreign_key.get_name():
- sql += 'CONSTRAINT %s ' % foreign_key.get_quoted_name(self)
+ sql += "CONSTRAINT %s " % foreign_key.get_quoted_name(self)
- sql += 'FOREIGN KEY ('
+ sql += "FOREIGN KEY ("
if not foreign_key.get_local_columns():
raise DBALException('Incomplete definition. "local" required.')
@@ -208,26 +231,27 @@ def get_foreign_key_base_declaration_sql(self, foreign_key):
if not foreign_key.get_foreign_table_name():
raise DBALException('Incomplete definition. "foreign_table" required.')
- sql += '%s) REFERENCES %s (%s)'\
- % (', '.join(foreign_key.get_quoted_local_columns(self)),
- foreign_key.get_quoted_foreign_table_name(self),
- ', '.join(foreign_key.get_quoted_foreign_columns(self)))
+ sql += "%s) REFERENCES %s (%s)" % (
+ ", ".join(foreign_key.get_quoted_local_columns(self)),
+ foreign_key.get_quoted_foreign_table_name(self),
+ ", ".join(foreign_key.get_quoted_foreign_columns(self)),
+ )
return sql
def get_current_date_sql(self):
- return 'CURRENT_DATE'
+ return "CURRENT_DATE"
def get_current_time_sql(self):
- return 'CURRENT_TIME'
+ return "CURRENT_TIME"
def get_current_timestamp_sql(self):
- return 'CURRENT_TIMESTAMP'
+ return "CURRENT_TIMESTAMP"
def get_sql_type_declaration(self, column):
- internal_type = column['type']
+ internal_type = column["type"]
- return getattr(self, 'get_%s_type_declaration_sql' % internal_type)(column)
+ return getattr(self, "get_%s_type_declaration_sql" % internal_type)(column)
def get_column_declaration_list_sql(self, fields):
"""
@@ -238,95 +262,97 @@ def get_column_declaration_list_sql(self, fields):
for name, field in fields.items():
query_fields.append(self.get_column_declaration_sql(name, field))
- return ', '.join(query_fields)
+ return ", ".join(query_fields)
def get_column_declaration_sql(self, name, field):
- if 'column_definition' in field:
+ if "column_definition" in field:
column_def = self.get_custom_type_declaration_sql(field)
else:
default = self.get_default_value_declaration_sql(field)
- charset = field.get('charset', '')
+ charset = field.get("charset", "")
if charset:
- charset = ' ' + self.get_column_charset_declaration_sql(charset)
+ charset = " " + self.get_column_charset_declaration_sql(charset)
- collation = field.get('collation', '')
+ collation = field.get("collation", "")
if charset:
- charset = ' ' + self.get_column_collation_declaration_sql(charset)
+ charset = " " + self.get_column_collation_declaration_sql(charset)
- notnull = field.get('notnull', '')
+ notnull = field.get("notnull", "")
if notnull:
- notnull = ' NOT NULL'
+ notnull = " NOT NULL"
else:
- notnull = ''
+ notnull = ""
- unique = field.get('unique', '')
+ unique = field.get("unique", "")
if unique:
- unique = ' ' + self.get_unique_field_declaration_sql()
+ unique = " " + self.get_unique_field_declaration_sql()
else:
- unique = ''
+ unique = ""
- check = field.get('check', '')
+ check = field.get("check", "")
type_decl = self.get_sql_type_declaration(field)
- column_def = type_decl + charset + default + notnull + unique + check + collation
+ column_def = (
+ type_decl + charset + default + notnull + unique + check + collation
+ )
- return name + ' ' + column_def
+ return name + " " + column_def
def get_custom_type_declaration_sql(self, column_def):
- return column_def['column_definition']
+ return column_def["column_definition"]
def get_column_charset_declaration_sql(self, charset):
- return ''
+ return ""
def get_column_collation_declaration_sql(self, collation):
if self.supports_column_collation():
- return 'COLLATE %s' % collation
+ return "COLLATE %s" % collation
- return ''
+ return ""
def supports_column_collation(self):
return False
def get_unique_field_declaration_sql(self):
- return 'UNIQUE'
+ return "UNIQUE"
def get_string_type_declaration_sql(self, column):
- if 'length' not in column:
- column['length'] = self.get_varchar_default_length()
+ if "length" not in column:
+ column["length"] = self.get_varchar_default_length()
- fixed = column.get('fixed', False)
+ fixed = column.get("fixed", False)
- if column['length'] > self.get_varchar_max_length():
+ if column["length"] > self.get_varchar_max_length():
return self.get_clob_type_declaration_sql(column)
- return self.get_varchar_type_declaration_sql_snippet(column['length'], fixed)
+ return self.get_varchar_type_declaration_sql_snippet(column["length"], fixed)
def get_binary_type_declaration_sql(self, column):
- if 'length' not in column:
- column['length'] = self.get_binary_default_length()
+ if "length" not in column:
+ column["length"] = self.get_binary_default_length()
- fixed = column.get('fixed', False)
+ fixed = column.get("fixed", False)
- if column['length'] > self.get_binary_max_length():
+ if column["length"] > self.get_binary_max_length():
return self.get_blob_type_declaration_sql(column)
- return self.get_binary_type_declaration_sql_snippet(column['length'], fixed)
+ return self.get_binary_type_declaration_sql_snippet(column["length"], fixed)
def get_varchar_type_declaration_sql_snippet(self, length, fixed):
- raise NotImplementedError('VARCHARS not supported by Platform')
+ raise NotImplementedError("VARCHARS not supported by Platform")
def get_binary_type_declaration_sql_snippet(self, length, fixed):
- raise NotImplementedError('BINARY/VARBINARY not supported by Platform')
+ raise NotImplementedError("BINARY/VARBINARY not supported by Platform")
def get_decimal_type_declaration_sql(self, column):
- if 'precision' not in column or not column['precision']:
- column['precision'] = 10
+ if "precision" not in column or not column["precision"]:
+ column["precision"] = 10
- if 'scale' not in column or not column['scale']:
- column['precision'] = 0
+ if "scale" not in column or not column["scale"]:
+ column["precision"] = 0
- return 'NUMERIC(%s, %s)' % (column['precision'], column['scale'])
+ return "NUMERIC(%s, %s)" % (column["precision"], column["scale"])
def get_json_type_declaration_sql(self, column):
return self.get_clob_type_declaration_sql(column)
@@ -387,7 +413,7 @@ def get_index_field_declaration_list_sql(self, fields):
for field in fields:
ret.append(field)
- return ', '.join(ret)
+ return ", ".join(ret)
def get_create_index_sql(self, index, table):
"""
@@ -413,9 +439,15 @@ def get_create_index_sql(self, index, table):
if index.is_primary():
return self.get_create_primary_key_sql(index, table)
- query = 'CREATE %sINDEX %s ON %s' % (self.get_create_index_sql_flags(index), name, table)
- query += ' (%s)%s' % (self.get_index_field_declaration_list_sql(columns),
- self.get_partial_index_sql(index))
+ query = "CREATE %sINDEX %s ON %s" % (
+ self.get_create_index_sql_flags(index),
+ name,
+ table,
+ )
+ query += " (%s)%s" % (
+ self.get_index_field_declaration_list_sql(columns),
+ self.get_partial_index_sql(index),
+ )
return query
@@ -428,10 +460,10 @@ def get_partial_index_sql(self, index):
:rtype: str
"""
- if self.supports_partial_indexes() and index.has_option('where'):
- return ' WHERE %s' % index.get_option('where')
+ if self.supports_partial_indexes() and index.has_option("where"):
+ return " WHERE %s" % index.get_option("where")
- return ''
+ return ""
def get_create_index_sql_flags(self, index):
"""
@@ -443,9 +475,9 @@ def get_create_index_sql_flags(self, index):
:rtype: str
"""
if index.is_unique():
- return 'UNIQUE '
+ return "UNIQUE "
- return ''
+ return ""
def get_create_primary_key_sql(self, index, table):
"""
@@ -459,9 +491,10 @@ def get_create_primary_key_sql(self, index, table):
:rtype: str
"""
- return 'ALTER TABLE %s ADD PRIMARY KEY (%s)'\
- % (table,
- self.get_index_field_declaration_list_sql(index.get_quoted_columns(self)))
+ return "ALTER TABLE %s ADD PRIMARY KEY (%s)" % (
+ table,
+ self.get_index_field_declaration_list_sql(index.get_quoted_columns(self)),
+ )
def get_create_foreign_key_sql(self, foreign_key, table):
"""
@@ -472,7 +505,10 @@ def get_create_foreign_key_sql(self, foreign_key, table):
if isinstance(table, Table):
table = table.get_quoted_name(self)
- query = 'ALTER TABLE %s ADD %s' % (table, self.get_foreign_key_declaration_sql(foreign_key))
+ query = "ALTER TABLE %s ADD %s" % (
+ table,
+ self.get_foreign_key_declaration_sql(foreign_key),
+ )
return query
@@ -488,7 +524,7 @@ def get_drop_table_sql(self, table):
if isinstance(table, Table):
table = table.get_quoted_name(self)
- return 'DROP TABLE %s' % table
+ return "DROP TABLE %s" % table
def get_drop_index_sql(self, index, table=None):
"""
@@ -505,7 +541,7 @@ def get_drop_index_sql(self, index, table=None):
if isinstance(index, Index):
index = index.get_quoted_name(self)
- return 'DROP INDEX %s' % index
+ return "DROP INDEX %s" % index
def get_create_table_sql(self, table, create_flags=CREATE_INDEXES):
"""
@@ -523,42 +559,42 @@ def get_create_table_sql(self, table, create_flags=CREATE_INDEXES):
table_name = table.get_quoted_name(self)
options = dict((k, v) for k, v in table.get_options().items())
- options['unique_constraints'] = OrderedDict()
- options['indexes'] = OrderedDict()
- options['primary'] = []
+ options["unique_constraints"] = OrderedDict()
+ options["indexes"] = OrderedDict()
+ options["primary"] = []
if create_flags & self.CREATE_INDEXES > 0:
for index in table.get_indexes().values():
if index.is_primary():
- options['primary'] = index.get_quoted_columns(self)
- options['primary_index'] = index
+ options["primary"] = index.get_quoted_columns(self)
+ options["primary_index"] = index
else:
- options['indexes'][index.get_quoted_name(self)] = index
+ options["indexes"][index.get_quoted_name(self)] = index
columns = OrderedDict()
for column in table.get_columns().values():
column_data = column.to_dict()
- column_data['name'] = column.get_quoted_name(self)
- if column.has_platform_option('version'):
- column_data['version'] = column.get_platform_option('version')
+ column_data["name"] = column.get_quoted_name(self)
+ if column.has_platform_option("version"):
+ column_data["version"] = column.get_platform_option("version")
else:
- column_data['version'] = False
+ column_data["version"] = False
# column_data['comment'] = self.get_column_comment(column)
- if column_data['type'] == 'string' and column_data['length'] is None:
- column_data['length'] = 255
+ if column_data["type"] == "string" and column_data["length"] is None:
+ column_data["length"] = 255
- if column.get_name() in options['primary']:
- column_data['primary'] = True
+ if column.get_name() in options["primary"]:
+ column_data["primary"] = True
- columns[column_data['name']] = column_data
+ columns[column_data["name"]] = column_data
if create_flags & self.CREATE_FOREIGNKEYS > 0:
- options['foreign_keys'] = []
+ options["foreign_keys"] = []
for fk in table.get_foreign_keys().values():
- options['foreign_keys'].append(fk)
+ options["foreign_keys"].append(fk)
sql = self._get_create_table_sql(table_name, columns, options)
@@ -585,34 +621,37 @@ def _get_create_table_sql(self, table_name, columns, options=None):
column_list_sql = self.get_column_declaration_list_sql(columns)
- if options.get('unique_constraints'):
- for name, definition in options['unique_constraints'].items():
- column_list_sql += ', %s' % self.get_unique_constraint_declaration_sql(name, definition)
+ if options.get("unique_constraints"):
+ for name, definition in options["unique_constraints"].items():
+ column_list_sql += ", %s" % self.get_unique_constraint_declaration_sql(
+ name, definition
+ )
- if options.get('primary'):
- column_list_sql += ', PRIMARY KEY(%s)' % ', '.join(options['primary'])
+ if options.get("primary"):
+ column_list_sql += ", PRIMARY KEY(%s)" % ", ".join(options["primary"])
- if options.get('indexes'):
- for index, definition in options['indexes']:
- column_list_sql += ', %s' % self.get_index_declaration_sql(index, definition)
+ if options.get("indexes"):
+ for index, definition in options["indexes"].items():
+ column_list_sql += ", %s" % self.get_index_declaration_sql(
+ index, definition
+ )
- query = 'CREATE TABLE %s (%s' % (table_name, column_list_sql)
+ query = "CREATE TABLE %s (%s" % (table_name, column_list_sql)
check = self.get_check_declaration_sql(columns)
if check:
- query += ', %s' % check
+ query += ", %s" % check
- query += ')'
+ query += ")"
sql = [query]
- if options.get('foreign_keys'):
- for definition in options['foreign_keys']:
+ if options.get("foreign_keys"):
+ for definition in options["foreign_keys"]:
sql.append(self.get_create_foreign_key_sql(definition, table_name))
return sql
-
def quote_identifier(self, string):
"""
Quotes a string so that it can be safely used as a table or column name,
@@ -625,10 +664,10 @@ def quote_identifier(self, string):
:return: The quoted identifier string.
:rtype: str
"""
- if '.' in string:
- parts = list(map(self.quote_single_identifier, string.split('.')))
+ if "." in string:
+ parts = list(map(self.quote_single_identifier, string.split(".")))
- return '.'.join(parts)
+ return ".".join(parts)
return self.quote_single_identifier(string)
@@ -644,7 +683,7 @@ def quote_single_identifier(self, string):
"""
c = self.get_identifier_quote_character()
- return '%s%s%s' % (c, string.replace(c, c+c), c)
+ return "%s%s%s" % (c, string.replace(c, c + c), c)
def get_identifier_quote_character(self):
return '"'
diff --git a/orator/dbal/platforms/postgres_platform.py b/orator/dbal/platforms/postgres_platform.py
index 0901e27d..af4c6d60 100644
--- a/orator/dbal/platforms/postgres_platform.py
+++ b/orator/dbal/platforms/postgres_platform.py
@@ -10,46 +10,46 @@
class PostgresPlatform(Platform):
INTERNAL_TYPE_MAPPING = {
- 'smallint': 'smallint',
- 'int2': 'smallint',
- 'serial': 'integer',
- 'serial4': 'integer',
- 'int': 'integer',
- 'int4': 'integer',
- 'integer': 'integer',
- 'bigserial': 'bigint',
- 'serial8': 'bigint',
- 'bigint': 'bigint',
- 'int8': 'bigint',
- 'bool': 'boolean',
- 'boolean': 'boolean',
- 'text': 'text',
- 'tsvector': 'text',
- 'varchar': 'string',
- 'interval': 'string',
- '_varchar': 'string',
- 'char': 'string',
- 'bpchar': 'string',
- 'inet': 'string',
- 'date': 'date',
- 'datetime': 'datetime',
- 'timestamp': 'datetime',
- 'timestamptz': 'datetimez',
- 'time': 'time',
- 'timetz': 'time',
- 'float': 'float',
- 'float4': 'float',
- 'float8': 'float',
- 'double': 'float',
- 'double precision': 'float',
- 'real': 'float',
- 'decimal': 'decimal',
- 'money': 'decimal',
- 'numeric': 'decimal',
- 'year': 'date',
- 'uuid': 'guid',
- 'bytea': 'blob',
- 'json': 'json'
+ "smallint": "smallint",
+ "int2": "smallint",
+ "serial": "integer",
+ "serial4": "integer",
+ "int": "integer",
+ "int4": "integer",
+ "integer": "integer",
+ "bigserial": "bigint",
+ "serial8": "bigint",
+ "bigint": "bigint",
+ "int8": "bigint",
+ "bool": "boolean",
+ "boolean": "boolean",
+ "text": "text",
+ "tsvector": "text",
+ "varchar": "string",
+ "interval": "string",
+ "_varchar": "string",
+ "char": "string",
+ "bpchar": "string",
+ "inet": "string",
+ "date": "date",
+ "datetime": "datetime",
+ "timestamp": "datetime",
+ "timestamptz": "datetimez",
+ "time": "time",
+ "timetz": "time",
+ "float": "float",
+ "float4": "float",
+ "float8": "float",
+ "double": "float",
+ "double precision": "float",
+ "real": "float",
+ "decimal": "decimal",
+ "money": "decimal",
+ "numeric": "decimal",
+ "year": "date",
+ "uuid": "guid",
+ "bytea": "blob",
+ "json": "json",
}
def get_list_table_columns_sql(self, table):
@@ -82,7 +82,9 @@ def get_list_table_columns_sql(self, table):
AND a.attrelid = c.oid
AND a.atttypid = t.oid
AND n.oid = c.relnamespace
- ORDER BY a.attnum""" % self.get_table_where_clause(table)
+ ORDER BY a.attnum""" % self.get_table_where_clause(
+ table
+ )
return sql
@@ -99,57 +101,79 @@ def get_list_table_indexes_sql(self, table):
AND sc.oid=si.indrelid AND sc.relnamespace = sn.oid
) AND pg_index.indexrelid = oid"""
- sql = sql % self.get_table_where_clause(table, 'sc', 'sn')
+ sql = sql % self.get_table_where_clause(table, "sc", "sn")
return sql
def get_list_table_foreign_keys_sql(self, table):
- return 'SELECT quote_ident(r.conname) as conname, ' \
- 'pg_catalog.pg_get_constraintdef(r.oid, true) AS condef ' \
- 'FROM pg_catalog.pg_constraint r ' \
- 'WHERE r.conrelid = ' \
- '(' \
- 'SELECT c.oid ' \
- 'FROM pg_catalog.pg_class c, pg_catalog.pg_namespace n ' \
- 'WHERE ' + self.get_table_where_clause(table) + ' AND n.oid = c.relnamespace' \
- ')' \
- ' AND r.contype = \'f\''
-
- def get_table_where_clause(self, table, class_alias='c', namespace_alias='n'):
- where_clause = namespace_alias + '.nspname NOT IN (\'pg_catalog\', \'information_schema\', \'pg_toast\') AND '
- if table.find('.') >= 0:
- split = table.split('.')
+ return (
+ "SELECT quote_ident(r.conname) as conname, "
+ "pg_catalog.pg_get_constraintdef(r.oid, true) AS condef "
+ "FROM pg_catalog.pg_constraint r "
+ "WHERE r.conrelid = "
+ "("
+ "SELECT c.oid "
+ "FROM pg_catalog.pg_class c, pg_catalog.pg_namespace n "
+ "WHERE "
+ + self.get_table_where_clause(table)
+ + " AND n.oid = c.relnamespace"
+ ")"
+ " AND r.contype = 'f'"
+ )
+
+ def get_table_where_clause(self, table, class_alias="c", namespace_alias="n"):
+ where_clause = (
+ namespace_alias
+ + ".nspname NOT IN ('pg_catalog', 'information_schema', 'pg_toast') AND "
+ )
+ if table.find(".") >= 0:
+ split = table.split(".")
schema, table = split[0], split[1]
schema = "'%s'" % schema
else:
- schema = 'ANY(string_to_array((select replace(replace(setting, \'"$user"\', user), \' \', \'\')' \
- ' from pg_catalog.pg_settings where name = \'search_path\'),\',\'))'
-
- where_clause += '%s.relname = \'%s\' AND %s.nspname = %s' % (class_alias, table, namespace_alias, schema)
+ schema = (
+ "ANY(string_to_array((select replace(replace(setting, '\"$user\"', user), ' ', '')"
+ " from pg_catalog.pg_settings where name = 'search_path'),','))"
+ )
+
+ where_clause += "%s.relname = '%s' AND %s.nspname = %s" % (
+ class_alias,
+ table,
+ namespace_alias,
+ schema,
+ )
return where_clause
def get_advanced_foreign_key_options_sql(self, foreign_key):
- query = ''
+ query = ""
- if foreign_key.has_option('match'):
- query += ' MATCH %s' % foreign_key.get_option('match')
+ if foreign_key.has_option("match"):
+ query += " MATCH %s" % foreign_key.get_option("match")
- query += super(PostgresPlatform, self).get_advanced_foreign_key_options_sql(foreign_key)
+ query += super(PostgresPlatform, self).get_advanced_foreign_key_options_sql(
+ foreign_key
+ )
- deferrable = foreign_key.has_option('deferrable') and foreign_key.get_option('deferrable') is not False
+ deferrable = (
+ foreign_key.has_option("deferrable")
+ and foreign_key.get_option("deferrable") is not False
+ )
if deferrable:
- query += ' DEFERRABLE'
+ query += " DEFERRABLE"
else:
- query += ' NOT DEFERRABLE'
+ query += " NOT DEFERRABLE"
- query += ' INITIALLY'
+ query += " INITIALLY"
- deferred = foreign_key.has_option('deferred') and foreign_key.get_option('deferred') is not False
+ deferred = (
+ foreign_key.has_option("deferred")
+ and foreign_key.get_option("deferred") is not False
+ )
if deferred:
- query += ' DEFERRED'
+ query += " DEFERRED"
else:
- query += ' IMMEDIATE'
+ query += " IMMEDIATE"
return query
@@ -171,58 +195,118 @@ def get_alter_table_sql(self, diff):
old_column_name = column_diff.get_old_column_name().get_quoted_name(self)
column = column_diff.column
- if any([column_diff.has_changed('type'),
- column_diff.has_changed('precision'),
- column_diff.has_changed('scale'),
- column_diff.has_changed('fixed')]):
- query = 'ALTER ' + old_column_name + ' TYPE ' + self.get_sql_type_declaration(column.to_dict())
- sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query)
-
- if column_diff.has_changed('default') or column_diff.has_changed('type'):
+ if any(
+ [
+ column_diff.has_changed("type"),
+ column_diff.has_changed("precision"),
+ column_diff.has_changed("scale"),
+ column_diff.has_changed("fixed"),
+ ]
+ ):
+ query = (
+ "ALTER "
+ + old_column_name
+ + " TYPE "
+ + self.get_sql_type_declaration(column.to_dict())
+ )
+ sql.append(
+ "ALTER TABLE "
+ + diff.get_name(self).get_quoted_name(self)
+ + " "
+ + query
+ )
+
+ if column_diff.has_changed("default") or column_diff.has_changed("type"):
if column.get_default() is None:
- default_clause = ' DROP DEFAULT'
+ default_clause = " DROP DEFAULT"
else:
- default_clause = ' SET' + self.get_default_value_declaration_sql(column.to_dict())
-
- query = 'ALTER ' + old_column_name + default_clause
- sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query)
-
- if column_diff.has_changed('notnull'):
- op = 'DROP'
+ default_clause = " SET" + self.get_default_value_declaration_sql(
+ column.to_dict()
+ )
+
+ query = "ALTER " + old_column_name + default_clause
+ sql.append(
+ "ALTER TABLE "
+ + diff.get_name(self).get_quoted_name(self)
+ + " "
+ + query
+ )
+
+ if column_diff.has_changed("notnull"):
+ op = "DROP"
if column.get_notnull():
- op = 'SET'
+ op = "SET"
- query = 'ALTER ' + old_column_name + ' ' + op + ' NOT NULL'
- sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query)
+ query = "ALTER " + old_column_name + " " + op + " NOT NULL"
+ sql.append(
+ "ALTER TABLE "
+ + diff.get_name(self).get_quoted_name(self)
+ + " "
+ + query
+ )
- if column_diff.has_changed('autoincrement'):
+ if column_diff.has_changed("autoincrement"):
if column.get_autoincrement():
- seq_name = self.get_identity_sequence_name(diff.name, old_column_name)
-
- sql.append('CREATE SEQUENCE ' + seq_name)
- sql.append('SELECT setval(\'' + seq_name + '\', '
- '(SELECT MAX(' + old_column_name + ') FROM ' + diff.name + '))')
- query = 'ALTER ' + old_column_name + ' SET DEFAULT nextval(\'' + seq_name + '\')'
- sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query)
+ seq_name = self.get_identity_sequence_name(
+ diff.name, old_column_name
+ )
+
+ sql.append("CREATE SEQUENCE " + seq_name)
+ sql.append(
+ "SELECT setval('" + seq_name + "', "
+ "(SELECT MAX(" + old_column_name + ") FROM " + diff.name + "))"
+ )
+ query = (
+ "ALTER "
+ + old_column_name
+ + " SET DEFAULT nextval('"
+ + seq_name
+ + "')"
+ )
+ sql.append(
+ "ALTER TABLE "
+ + diff.get_name(self).get_quoted_name(self)
+ + " "
+ + query
+ )
else:
- query = 'ALTER ' + old_column_name + ' DROP DEFAULT'
- sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query)
-
- if column_diff.has_changed('length'):
- query = 'ALTER ' + old_column_name + ' TYPE ' + self.get_sql_type_declaration(column.to_dict())
- sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' ' + query)
+ query = "ALTER " + old_column_name + " DROP DEFAULT"
+ sql.append(
+ "ALTER TABLE "
+ + diff.get_name(self).get_quoted_name(self)
+ + " "
+ + query
+ )
+
+ if column_diff.has_changed("length"):
+ query = (
+ "ALTER "
+ + old_column_name
+ + " TYPE "
+ + self.get_sql_type_declaration(column.to_dict())
+ )
+ sql.append(
+ "ALTER TABLE "
+ + diff.get_name(self).get_quoted_name(self)
+ + " "
+ + query
+ )
for old_column_name, column in diff.renamed_columns.items():
- sql.append('ALTER TABLE ' + diff.get_name(self).get_quoted_name(self) + ' '
- 'RENAME COLUMN ' + Identifier(old_column_name).get_quoted_name(self) +
- ' TO ' + column.get_quoted_name(self))
+ sql.append(
+ "ALTER TABLE " + diff.get_name(self).get_quoted_name(self) + " "
+ "RENAME COLUMN "
+ + Identifier(old_column_name).get_quoted_name(self)
+ + " TO "
+ + column.get_quoted_name(self)
+ )
return sql
def is_unchanged_binary_column(self, column_diff):
column_type = column_diff.column.get_type()
- if column_type not in ['blob', 'binary']:
+ if column_type not in ["blob", "binary"]:
return False
if isinstance(column_diff.from_column, Column):
@@ -233,15 +317,33 @@ def is_unchanged_binary_column(self, column_diff):
if from_column:
from_column_type = self.INTERNAL_TYPE_MAPPING[from_column.get_type()]
- if from_column_type in ['blob', 'binary']:
+ if from_column_type in ["blob", "binary"]:
return False
- return len([x for x in column_diff.changed_properties if x not in ['type', 'length', 'fixed']]) == 0
-
- if column_diff.has_changed('type'):
+ return (
+ len(
+ [
+ x
+ for x in column_diff.changed_properties
+ if x not in ["type", "length", "fixed"]
+ ]
+ )
+ == 0
+ )
+
+ if column_diff.has_changed("type"):
return False
- return len([x for x in column_diff.changed_properties if x not in ['length', 'fixed']]) == 0
+ return (
+ len(
+ [
+ x
+ for x in column_diff.changed_properties
+ if x not in ["length", "fixed"]
+ ]
+ )
+ == 0
+ )
def convert_booleans(self, item):
if isinstance(item, list):
@@ -254,73 +356,73 @@ def convert_booleans(self, item):
return item
def get_boolean_type_declaration_sql(self, column):
- return 'BOOLEAN'
+ return "BOOLEAN"
def get_integer_type_declaration_sql(self, column):
- if column.get('autoincrement'):
- return 'SERIAL'
+ if column.get("autoincrement"):
+ return "SERIAL"
- return 'INT'
+ return "INT"
def get_bigint_type_declaration_sql(self, column):
- if column.get('autoincrement'):
- return 'BIGSERIAL'
+ if column.get("autoincrement"):
+ return "BIGSERIAL"
- return 'BIGINT'
+ return "BIGINT"
def get_smallint_type_declaration_sql(self, column):
- return 'SMALLINT'
+ return "SMALLINT"
def get_guid_type_declaration_sql(self, column):
- return 'UUID'
+ return "UUID"
def get_datetime_type_declaration_sql(self, column):
- return 'TIMESTAMP(0) WITHOUT TIME ZONE'
+ return "TIMESTAMP(0) WITHOUT TIME ZONE"
def get_datetimetz_type_declaration_sql(self, column):
- return 'TIMESTAMP(0) WITH TIME ZONE'
+ return "TIMESTAMP(0) WITH TIME ZONE"
def get_date_type_declaration_sql(self, column):
- return 'DATE'
+ return "DATE"
def get_time_type_declaration_sql(self, column):
- return 'TIME(0) WITHOUT TIME ZONE'
+ return "TIME(0) WITHOUT TIME ZONE"
def get_string_type_declaration_sql(self, column):
- length = column.get('length', '255')
- fixed = column.get('fixed')
+ length = column.get("length", "255")
+ fixed = column.get("fixed")
if fixed:
- return 'CHAR(%s)' % length
+ return "CHAR(%s)" % length
else:
- return 'VARCHAR(%s)' % length
+ return "VARCHAR(%s)" % length
def get_binary_type_declaration_sql(self, column):
- return 'BYTEA'
+ return "BYTEA"
def get_blob_type_declaration_sql(self, column):
- return 'BYTEA'
+ return "BYTEA"
def get_clob_type_declaration_sql(self, column):
- return 'TEXT'
+ return "TEXT"
def get_text_type_declaration_sql(self, column):
- return 'TEXT'
+ return "TEXT"
def get_json_type_declaration_sql(self, column):
- return 'JSON'
+ return "JSON"
def get_decimal_type_declaration_sql(self, column):
- if 'precision' not in column or not column['precision']:
- column['precision'] = 10
+ if "precision" not in column or not column["precision"]:
+ column["precision"] = 10
- if 'scale' not in column or not column['scale']:
- column['precision'] = 0
+ if "scale" not in column or not column["scale"]:
+ column["precision"] = 0
- return 'DECIMAL(%s, %s)' % (column['precision'], column['scale'])
+ return "DECIMAL(%s, %s)" % (column["precision"], column["scale"])
def get_float_type_declaration_sql(self, column):
- return 'DOUBLE PRECISION'
+ return "DOUBLE PRECISION"
def supports_foreign_key_constraints(self):
return True
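As a rough illustration of the reformatted get_advanced_foreign_key_options_sql above: the MATCH, DEFERRABLE and INITIALLY clauses are appended from the constraint options. The sketch below is standalone and illustrative only; it omits the ON DELETE / ON UPDATE part contributed by the base class via super().

def advanced_fk_options(options):
    # Compose the trailing option clauses of a foreign key declaration.
    query = ""
    if options.get("match"):
        query += " MATCH %s" % options["match"]
    query += " DEFERRABLE" if options.get("deferrable") else " NOT DEFERRABLE"
    query += " INITIALLY"
    query += " DEFERRED" if options.get("deferred") else " IMMEDIATE"
    return query

assert advanced_fk_options({}) == " NOT DEFERRABLE INITIALLY IMMEDIATE"
assert (
    advanced_fk_options({"match": "FULL", "deferrable": True, "deferred": True})
    == " MATCH FULL DEFERRABLE INITIALLY DEFERRED"
)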
diff --git a/orator/dbal/platforms/sqlite_platform.py b/orator/dbal/platforms/sqlite_platform.py
index 05540d50..b5698859 100644
--- a/orator/dbal/platforms/sqlite_platform.py
+++ b/orator/dbal/platforms/sqlite_platform.py
@@ -14,54 +14,54 @@
class SQLitePlatform(Platform):
INTERNAL_TYPE_MAPPING = {
- 'boolean': 'boolean',
- 'tinyint': 'boolean',
- 'smallint': 'smallint',
- 'mediumint': 'integer',
- 'int': 'integer',
- 'integer': 'integer',
- 'serial': 'integer',
- 'bigint': 'bigint',
- 'bigserial': 'bigint',
- 'clob': 'text',
- 'tinytext': 'text',
- 'mediumtext': 'text',
- 'longtext': 'text',
- 'text': 'text',
- 'varchar': 'string',
- 'longvarchar': 'string',
- 'varchar2': 'string',
- 'nvarchar': 'string',
- 'image': 'string',
- 'ntext': 'string',
- 'char': 'string',
- 'date': 'date',
- 'datetime': 'datetime',
- 'timestamp': 'datetime',
- 'time': 'time',
- 'float': 'float',
- 'double': 'float',
- 'double precision': 'float',
- 'real': 'float',
- 'decimal': 'decimal',
- 'numeric': 'decimal',
- 'blob': 'blob',
+ "boolean": "boolean",
+ "tinyint": "boolean",
+ "smallint": "smallint",
+ "mediumint": "integer",
+ "int": "integer",
+ "integer": "integer",
+ "serial": "integer",
+ "bigint": "bigint",
+ "bigserial": "bigint",
+ "clob": "text",
+ "tinytext": "text",
+ "mediumtext": "text",
+ "longtext": "text",
+ "text": "text",
+ "varchar": "string",
+ "longvarchar": "string",
+ "varchar2": "string",
+ "nvarchar": "string",
+ "image": "string",
+ "ntext": "string",
+ "char": "string",
+ "date": "date",
+ "datetime": "datetime",
+ "timestamp": "datetime",
+ "time": "time",
+ "float": "float",
+ "double": "float",
+ "double precision": "float",
+ "real": "float",
+ "decimal": "decimal",
+ "numeric": "decimal",
+ "blob": "blob",
}
def get_list_table_columns_sql(self, table):
- table = table.replace('.', '__')
+ table = table.replace(".", "__")
- return 'PRAGMA table_info(\'%s\')' % table
+ return "PRAGMA table_info('%s')" % table
def get_list_table_indexes_sql(self, table):
- table = table.replace('.', '__')
+ table = table.replace(".", "__")
- return 'PRAGMA index_list(\'%s\')' % table
+ return "PRAGMA index_list('%s')" % table
def get_list_table_foreign_keys_sql(self, table):
- table = table.replace('.', '__')
+ table = table.replace(".", "__")
- return 'PRAGMA foreign_key_list(\'%s\')' % table
+ return "PRAGMA foreign_key_list('%s')" % table
def get_pre_alter_table_index_foreign_key_sql(self, diff):
"""
@@ -71,8 +71,10 @@ def get_pre_alter_table_index_foreign_key_sql(self, diff):
:rtype: list
"""
if not isinstance(diff.from_table, Table):
- raise DBALException('Sqlite platform requires for alter table the table'
- 'diff with reference to original table schema')
+ raise DBALException(
+ "Sqlite platform requires for alter table the table"
+ "diff with reference to original table schema"
+ )
sql = []
for index in diff.from_table.get_indexes().values():
@@ -89,8 +91,10 @@ def get_post_alter_table_index_foreign_key_sql(self, diff):
:rtype: list
"""
if not isinstance(diff.from_table, Table):
- raise DBALException('Sqlite platform requires for alter table the table'
- 'diff with reference to original table schema')
+ raise DBALException(
+ "Sqlite platform requires for alter table the table"
+ "diff with reference to original table schema"
+ )
sql = []
@@ -103,7 +107,9 @@ def get_post_alter_table_index_foreign_key_sql(self, diff):
if index.is_primary():
continue
- sql.append(self.get_create_index_sql(index, table_name.get_quoted_name(self)))
+ sql.append(
+ self.get_create_index_sql(index, table_name.get_quoted_name(self))
+ )
return sql
@@ -114,34 +120,36 @@ def get_create_table_sql(self, table, create_flags=None):
return super(SQLitePlatform, self).get_create_table_sql(table, create_flags)
def _get_create_table_sql(self, table_name, columns, options=None):
- table_name = table_name.replace('.', '__')
+ table_name = table_name.replace(".", "__")
query_fields = self.get_column_declaration_list_sql(columns)
- if options.get('unique_constraints'):
- for name, definition in options['unique_constraints'].items():
- query_fields += ', %s' % self.get_unique_constraint_declaration_sql(name, definition)
+ if options.get("unique_constraints"):
+ for name, definition in options["unique_constraints"].items():
+ query_fields += ", %s" % self.get_unique_constraint_declaration_sql(
+ name, definition
+ )
- if options.get('primary'):
- key_columns = options['primary']
- query_fields += ', PRIMARY KEY(%s)' % ', '.join(key_columns)
+ if options.get("primary"):
+ key_columns = options["primary"]
+ query_fields += ", PRIMARY KEY(%s)" % ", ".join(key_columns)
- if options.get('foreign_keys'):
- for foreign_key in options['foreign_keys']:
- query_fields += ', %s' % self.get_foreign_key_declaration_sql(foreign_key)
+ if options.get("foreign_keys"):
+ for foreign_key in options["foreign_keys"]:
+ query_fields += ", %s" % self.get_foreign_key_declaration_sql(
+ foreign_key
+ )
- query = [
- 'CREATE TABLE %s (%s)' % (table_name, query_fields)
- ]
+ query = ["CREATE TABLE %s (%s)" % (table_name, query_fields)]
- if options.get('alter'):
+ if options.get("alter"):
return query
- if options.get('indexes'):
- for index_def in options['indexes'].values():
+ if options.get("indexes"):
+ for index_def in options["indexes"].values():
query.append(self.get_create_index_sql(index_def, table_name))
- if options.get('unique'):
- for index_def in options['unique'].values():
+ if options.get("unique"):
+ for index_def in options["unique"].values():
query.append(self.get_create_index_sql(index_def, table_name))
return query
@@ -150,29 +158,37 @@ def get_foreign_key_declaration_sql(self, foreign_key):
return super(SQLitePlatform, self).get_foreign_key_declaration_sql(
ForeignKeyConstraint(
foreign_key.get_quoted_local_columns(self),
- foreign_key.get_quoted_foreign_table_name(self).replace('.', '__'),
+ foreign_key.get_quoted_foreign_table_name(self).replace(".", "__"),
foreign_key.get_quoted_foreign_columns(self),
foreign_key.get_name(),
- foreign_key.get_options()
+ foreign_key.get_options(),
)
)
def get_advanced_foreign_key_options_sql(self, foreign_key):
- query = super(SQLitePlatform, self).get_advanced_foreign_key_options_sql(foreign_key)
+ query = super(SQLitePlatform, self).get_advanced_foreign_key_options_sql(
+ foreign_key
+ )
- deferrable = foreign_key.has_option('deferrable') and foreign_key.get_option('deferrable') is not False
+ deferrable = (
+ foreign_key.has_option("deferrable")
+ and foreign_key.get_option("deferrable") is not False
+ )
if deferrable:
- query += ' DEFERRABLE'
+ query += " DEFERRABLE"
else:
- query += ' NOT DEFERRABLE'
+ query += " NOT DEFERRABLE"
- query += ' INITIALLY'
+ query += " INITIALLY"
- deferred = foreign_key.has_option('deferred') and foreign_key.get_option('deferred') is not False
+ deferred = (
+ foreign_key.has_option("deferred")
+ and foreign_key.get_option("deferred") is not False
+ )
if deferred:
- query += ' DEFERRED'
+ query += " DEFERRED"
else:
- query += ' IMMEDIATE'
+ query += " IMMEDIATE"
return query
@@ -192,8 +208,8 @@ def get_alter_table_sql(self, diff):
from_table = diff.from_table
if not isinstance(from_table, Table):
raise DBALException(
- 'SQLite platform requires for the alter table the table diff '
- 'referencing the original table'
+ "SQLite platform requires for the alter table the table diff "
+ "referencing the original table"
)
table = from_table.clone()
@@ -231,33 +247,46 @@ def get_alter_table_sql(self, diff):
columns[column_diff.column.get_name().lower()] = column_diff.column
if old_column_name in new_column_names:
- new_column_names[old_column_name] = column_diff.column.get_quoted_name(self)
+ new_column_names[old_column_name] = column_diff.column.get_quoted_name(
+ self
+ )
for column_name, column in diff.added_columns.items():
columns[column_name.lower()] = column
table_sql = []
- data_table = Table('__temp__' + table.get_name())
- new_table = Table(table.get_quoted_name(self), columns,
- self._get_primary_index_in_altered_table(diff),
- self._get_foreign_keys_in_altered_table(diff),
- table.get_options())
- new_table.add_option('alter', True)
+ data_table = Table("__temp__" + table.get_name())
+ new_table = Table(
+ table.get_quoted_name(self),
+ columns,
+ self._get_primary_index_in_altered_table(diff),
+ self._get_foreign_keys_in_altered_table(diff),
+ table.get_options(),
+ )
+ new_table.add_option("alter", True)
sql = self.get_pre_alter_table_index_foreign_key_sql(diff)
- sql.append('CREATE TEMPORARY TABLE %s AS SELECT %s FROM %s'
- % (data_table.get_quoted_name(self),
- ', '.join(old_column_names.values()),
- table.get_quoted_name(self)))
+ sql.append(
+ "CREATE TEMPORARY TABLE %s AS SELECT %s FROM %s"
+ % (
+ data_table.get_quoted_name(self),
+ ", ".join(old_column_names.values()),
+ table.get_quoted_name(self),
+ )
+ )
sql.append(self.get_drop_table_sql(from_table))
sql += self.get_create_table_sql(new_table)
- sql.append('INSERT INTO %s (%s) SELECT %s FROM %s'
- % (new_table.get_quoted_name(self),
- ', '.join(new_column_names.values()),
- ', '.join(old_column_names.values()),
- data_table.get_name()))
+ sql.append(
+ "INSERT INTO %s (%s) SELECT %s FROM %s"
+ % (
+ new_table.get_quoted_name(self),
+ ", ".join(new_column_names.values()),
+ ", ".join(old_column_names.values()),
+ data_table.get_name(),
+ )
+ )
sql.append(self.get_drop_table_sql(data_table))
sql += self.get_post_alter_table_index_foreign_key_sql(diff)
@@ -266,26 +295,40 @@ def get_alter_table_sql(self, diff):
def _get_simple_alter_table_sql(self, diff):
for old_column_name, column_diff in diff.changed_columns.items():
- if not isinstance(column_diff.from_column, Column)\
- or not isinstance(column_diff.column, Column)\
- or not column_diff.column.get_autoincrement()\
- or column_diff.column.get_type().lower() != 'integer':
+ if (
+ not isinstance(column_diff.from_column, Column)
+ or not isinstance(column_diff.column, Column)
+ or not column_diff.column.get_autoincrement()
+ or column_diff.column.get_type().lower() != "integer"
+ ):
continue
- if not column_diff.has_changed('type') and not column_diff.has_changed('unsigned'):
+ if not column_diff.has_changed("type") and not column_diff.has_changed(
+ "unsigned"
+ ):
del diff.changed_columns[old_column_name]
continue
from_column_type = column_diff.column.get_type()
- if from_column_type == 'smallint' or from_column_type == 'bigint':
+ if from_column_type == "smallint" or from_column_type == "bigint":
del diff.changed_columns[old_column_name]
- if any([not diff.renamed_columns, not diff.added_foreign_keys, not diff.added_indexes,
- not diff.changed_columns, not diff.changed_foreign_keys, not diff.changed_indexes,
- not diff.removed_columns, not diff.removed_foreign_keys, not diff.removed_indexes,
- not diff.renamed_indexes]):
+ if any(
+ [
+ not diff.renamed_columns,
+ not diff.added_foreign_keys,
+ not diff.added_indexes,
+ not diff.changed_columns,
+ not diff.changed_foreign_keys,
+ not diff.changed_indexes,
+ not diff.removed_columns,
+ not diff.removed_foreign_keys,
+ not diff.removed_indexes,
+ not diff.renamed_indexes,
+ ]
+ ):
return False
table = Table(diff.name)
@@ -295,34 +338,45 @@ def _get_simple_alter_table_sql(self, diff):
column_sql = []
for column in diff.added_columns.values():
- field = {
- 'unique': None,
- 'autoincrement': None,
- 'default': None
- }
+ field = {"unique": None, "autoincrement": None, "default": None}
field.update(column.to_dict())
- type_ = field['type']
- if 'column_definition' in field or field['autoincrement'] or field['unique']:
+ type_ = field["type"]
+ if (
+ "column_definition" in field
+ or field["autoincrement"]
+ or field["unique"]
+ ):
return False
- elif type_ == 'datetime' and field['default'] == self.get_current_timestamp_sql():
+ elif (
+ type_ == "datetime"
+ and field["default"] == self.get_current_timestamp_sql()
+ ):
return False
- elif type_ == 'date' and field['default'] == self.get_current_date_sql():
+ elif type_ == "date" and field["default"] == self.get_current_date_sql():
return False
- elif type_ == 'time' and field['default'] == self.get_current_time_sql():
+ elif type_ == "time" and field["default"] == self.get_current_time_sql():
return False
- field['name'] = column.get_quoted_name(self)
- if field['type'].lower() == 'string' and field['length'] is None:
- field['length'] = 255
+ field["name"] = column.get_quoted_name(self)
+ if field["type"].lower() == "string" and field["length"] is None:
+ field["length"] = 255
- sql.append('ALTER TABLE ' + table.get_quoted_name(self) +
- ' ADD COLUMN ' + self.get_column_declaration_sql(field['name'], field))
+ sql.append(
+ "ALTER TABLE "
+ + table.get_quoted_name(self)
+ + " ADD COLUMN "
+ + self.get_column_declaration_sql(field["name"], field)
+ )
if diff.new_name is not False:
new_table = Identifier(diff.new_name)
- sql.append('ALTER TABLE ' + table.get_quoted_name(self) +
- ' RENAME TO ' + new_table.get_quoted_name(self))
+ sql.append(
+ "ALTER TABLE "
+ + table.get_quoted_name(self)
+ + " RENAME TO "
+ + new_table.get_quoted_name(self)
+ )
return sql
@@ -354,9 +408,13 @@ def _get_indexes_in_altered_table(self, diff):
changed = True
if changed:
- indexes[key] = Index(index.get_name(), index_columns,
- index.is_unique(), index.is_primary(),
- index.get_flags())
+ indexes[key] = Index(
+ index.get_name(),
+ index_columns,
+ index.is_unique(),
+ index.is_primary(),
+ index.get_flags(),
+ )
for index in diff.removed_indexes.values():
index_name = index.get_name().lower()
@@ -438,7 +496,7 @@ def _get_foreign_keys_in_altered_table(self, diff):
constraint.get_foreign_table_name(),
constraint.get_foreign_columns(),
constraint.get_name(),
- constraint.get_options()
+ constraint.get_options(),
)
for constraint in diff.removed_foreign_keys:
@@ -475,72 +533,72 @@ def supports_foreign_key_constraints(self):
return True
def get_boolean_type_declaration_sql(self, column):
- return 'BOOLEAN'
+ return "BOOLEAN"
def get_integer_type_declaration_sql(self, column):
- return 'INTEGER' + self._get_common_integer_type_declaration_sql(column)
+ return "INTEGER" + self._get_common_integer_type_declaration_sql(column)
def get_bigint_type_declaration_sql(self, column):
# SQLite autoincrement is implicit for INTEGER PKs, but not for BIGINT fields.
- if not column.get('autoincrement', False):
+ if not column.get("autoincrement", False):
return self.get_integer_type_declaration_sql(column)
- return 'BIGINT' + self._get_common_integer_type_declaration_sql(column)
+ return "BIGINT" + self._get_common_integer_type_declaration_sql(column)
def get_tinyint_type_declaration_sql(self, column):
# SQLite autoincrement is implicit for INTEGER PKs, but not for TINYINT fields.
- if not column.get('autoincrement', False):
+ if not column.get("autoincrement", False):
return self.get_integer_type_declaration_sql(column)
- return 'TINYINT' + self._get_common_integer_type_declaration_sql(column)
+ return "TINYINT" + self._get_common_integer_type_declaration_sql(column)
def get_smallint_type_declaration_sql(self, column):
# SQLite autoincrement is implicit for INTEGER PKs, but not for SMALLINT fields.
- if not column.get('autoincrement', False):
+ if not column.get("autoincrement", False):
return self.get_integer_type_declaration_sql(column)
- return 'SMALLINT' + self._get_common_integer_type_declaration_sql(column)
+ return "SMALLINT" + self._get_common_integer_type_declaration_sql(column)
def get_mediumint_type_declaration_sql(self, column):
# SQLite autoincrement is implicit for INTEGER PKs, but not for MEDIUMINT fields.
- if not column.get('autoincrement', False):
+ if not column.get("autoincrement", False):
return self.get_integer_type_declaration_sql(column)
- return 'MEDIUMINT' + self._get_common_integer_type_declaration_sql(column)
+ return "MEDIUMINT" + self._get_common_integer_type_declaration_sql(column)
def get_datetime_type_declaration_sql(self, column):
- return 'DATETIME'
+ return "DATETIME"
def get_date_type_declaration_sql(self, column):
- return 'DATE'
+ return "DATE"
def get_time_type_declaration_sql(self, column):
- return 'TIME'
+ return "TIME"
def _get_common_integer_type_declaration_sql(self, column):
# sqlite autoincrement is implicit for integer PKs, but not when the field is unsigned
- if not column.get('autoincrement', False):
- return ''
+ if not column.get("autoincrement", False):
+ return ""
- if not column.get('unsigned', False):
- return ' UNSIGNED'
+ if not column.get("unsigned", False):
+ return " UNSIGNED"
- return ''
+ return ""
def get_varchar_type_declaration_sql_snippet(self, length, fixed):
if fixed:
- return 'CHAR(%s)' % length if length else 'CHAR(255)'
+ return "CHAR(%s)" % length if length else "CHAR(255)"
else:
- return 'VARCHAR(%s)' % length if length else 'TEXT'
+ return "VARCHAR(%s)" % length if length else "TEXT"
def get_blob_type_declaration_sql(self, column):
- return 'BLOB'
+ return "BLOB"
def get_clob_type_declaration_sql(self, column):
- return 'CLOB'
+ return "CLOB"
def get_column_options(self):
- return ['pk']
+ return ["pk"]
def _get_reserved_keywords_class(self):
return SQLiteKeywords
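The SQLite platform above drives its introspection from PRAGMA statements and replaces dots in table names with double underscores. A standalone sketch of the strings produced by the column-listing helper (not the orator API itself):

def list_table_columns_sql(table):
    # Mirrors get_list_table_columns_sql: dotted names collapse to double underscores.
    return "PRAGMA table_info('%s')" % table.replace(".", "__")

assert list_table_columns_sql("users") == "PRAGMA table_info('users')"
assert list_table_columns_sql("main.users") == "PRAGMA table_info('main__users')"

The index and foreign-key queries follow the same pattern with PRAGMA index_list and PRAGMA foreign_key_list.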
diff --git a/orator/dbal/postgres_schema_manager.py b/orator/dbal/postgres_schema_manager.py
index 1390bcc0..e2724a3d 100644
--- a/orator/dbal/postgres_schema_manager.py
+++ b/orator/dbal/postgres_schema_manager.py
@@ -7,98 +7,110 @@
class PostgresSchemaManager(SchemaManager):
-
def _get_portable_table_column_definition(self, table_column):
- if table_column['type'].lower() == 'varchar' or table_column['type'] == 'bpchar':
- length = re.sub('.*\(([0-9]*)\).*', '\\1', table_column['complete_type'])
- table_column['length'] = length
+ if (
+ table_column["type"].lower() == "varchar"
+ or table_column["type"] == "bpchar"
+ ):
+ length = re.sub(".*\(([0-9]*)\).*", "\\1", table_column["complete_type"])
+ table_column["length"] = length
autoincrement = False
- match = re.match("^nextval\('?(.*)'?(::.*)?\)$", str(table_column['default']))
+ match = re.match("^nextval\('?(.*)'?(::.*)?\)$", str(table_column["default"]))
if match:
- table_column['sequence'] = match.group(1)
- table_column['default'] = None
+ table_column["sequence"] = match.group(1)
+ table_column["default"] = None
autoincrement = True
- match = re.match("^'?([^']*)'?::.*$", str(table_column['default']))
+ match = re.match("^'?([^']*)'?::.*$", str(table_column["default"]))
if match:
- table_column['default'] = match.group(1)
+ table_column["default"] = match.group(1)
- if str(table_column['default']).find('NULL') == 0:
- table_column['default'] = None
+ if str(table_column["default"]).find("NULL") == 0:
+ table_column["default"] = None
- if 'length' in table_column:
- length = table_column['length']
+ if "length" in table_column:
+ length = table_column["length"]
else:
length = None
- if length == '-1' and 'atttypmod' in table_column:
- length = table_column['atttypmod'] - 4
+ if length == "-1" and "atttypmod" in table_column:
+ length = table_column["atttypmod"] - 4
if length is None or not length.isdigit() or int(length) <= 0:
length = None
fixed = None
- if 'name' not in table_column:
- table_column['name'] = ''
+ if "name" not in table_column:
+ table_column["name"] = ""
precision = None
scale = None
- db_type = table_column['type'].lower()
+ db_type = table_column["type"].lower()
type = self._platform.get_type_mapping(db_type)
- if db_type in ['smallint', 'int2']:
+ if db_type in ["smallint", "int2"]:
length = None
- elif db_type in ['int', 'int4', 'integer']:
+ elif db_type in ["int", "int4", "integer"]:
length = None
- elif db_type in ['int8', 'bigint']:
+ elif db_type in ["int8", "bigint"]:
length = None
- elif db_type in ['bool', 'boolean']:
- if table_column['default'] == 'true':
- table_column['default'] = True
+ elif db_type in ["bool", "boolean"]:
+ if table_column["default"] == "true":
+ table_column["default"] = True
- if table_column['default'] == 'false':
- table_column['default'] = False
+ if table_column["default"] == "false":
+ table_column["default"] = False
length = None
- elif db_type == 'text':
+ elif db_type == "text":
fixed = False
- elif db_type in ['varchar', 'interval', '_varchar']:
+ elif db_type in ["varchar", "interval", "_varchar"]:
fixed = False
- elif db_type in ['char', 'bpchar']:
+ elif db_type in ["char", "bpchar"]:
fixed = True
- elif db_type in ['float', 'float4', 'float8',
- 'double', 'double precision',
- 'real', 'decimal', 'money', 'numeric']:
- match = re.match('([A-Za-z]+\(([0-9]+),([0-9]+)\))', table_column['complete_type'])
+ elif db_type in [
+ "float",
+ "float4",
+ "float8",
+ "double",
+ "double precision",
+ "real",
+ "decimal",
+ "money",
+ "numeric",
+ ]:
+ match = re.match(
+ "([A-Za-z]+\(([0-9]+),([0-9]+)\))", table_column["complete_type"]
+ )
if match:
precision = match.group(1)
scale = match.group(2)
length = None
- elif db_type == 'year':
+ elif db_type == "year":
length = None
- if table_column['default']:
- match = re.match("('?([^']+)'?::)", str(table_column['default']))
+ if table_column["default"]:
+ match = re.match("('?([^']+)'?::)", str(table_column["default"]))
if match:
- table_column['default'] = match.group(1)
+ table_column["default"] = match.group(1)
options = {
- 'length': length,
- 'notnull': table_column['isnotnull'],
- 'default': table_column['default'],
- 'primary': table_column['pri'] == 't',
- 'precision': precision,
- 'scale': scale,
- 'fixed': fixed,
- 'unsigned': False,
- 'autoincrement': autoincrement
+ "length": length,
+ "notnull": table_column["isnotnull"],
+ "default": table_column["default"],
+ "primary": table_column["pri"] == "t",
+ "precision": precision,
+ "scale": scale,
+ "fixed": fixed,
+ "unsigned": False,
+ "autoincrement": autoincrement,
}
- column = Column(table_column['field'], type, options)
+ column = Column(table_column["field"], type, options)
return column
@@ -106,48 +118,64 @@ def _get_portable_table_indexes_list(self, table_indexes, table_name):
buffer = []
for row in table_indexes:
- col_numbers = row['indkey'].split(' ')
- col_numbers_sql = 'IN (%s)' % ', '.join(col_numbers)
- column_name_sql = 'SELECT attnum, attname FROM pg_attribute ' \
- 'WHERE attrelid=%s AND attnum %s ORDER BY attnum ASC;'\
- % (row['indrelid'], col_numbers_sql)
+ col_numbers = row["indkey"].split(" ")
+ col_numbers_sql = "IN (%s)" % ", ".join(col_numbers)
+ column_name_sql = (
+ "SELECT attnum, attname FROM pg_attribute "
+ "WHERE attrelid=%s AND attnum %s ORDER BY attnum ASC;"
+ % (row["indrelid"], col_numbers_sql)
+ )
index_columns = self._connection.select(column_name_sql)
# required for getting the order of the columns right.
for col_num in col_numbers:
for col_row in index_columns:
- if int(col_num) == col_row['attnum']:
- buffer.append({
- 'key_name': row['relname'],
- 'column_name': col_row['attname'].strip(),
- 'non_unique': not row['indisunique'],
- 'primary': row['indisprimary'],
- 'where': row['where']
- })
-
- return super(PostgresSchemaManager, self)._get_portable_table_indexes_list(buffer, table_name)
+ if int(col_num) == col_row["attnum"]:
+ buffer.append(
+ {
+ "key_name": row["relname"],
+ "column_name": col_row["attname"].strip(),
+ "non_unique": not row["indisunique"],
+ "primary": row["indisprimary"],
+ "where": row["where"],
+ }
+ )
+
+ return super(PostgresSchemaManager, self)._get_portable_table_indexes_list(
+ buffer, table_name
+ )
def _get_portable_table_foreign_key_definition(self, table_foreign_key):
- on_update = ''
- on_delete = ''
+ on_update = ""
+ on_delete = ""
- match = re.match('ON UPDATE ([a-zA-Z0-9]+( (NULL|ACTION|DEFAULT))?)', table_foreign_key['condef'])
+ match = re.match(
+ "ON UPDATE ([a-zA-Z0-9]+( (NULL|ACTION|DEFAULT))?)",
+ table_foreign_key["condef"],
+ )
if match:
on_update = match.group(1)
- match = re.match('ON DELETE ([a-zA-Z0-9]+( (NULL|ACTION|DEFAULT))?)', table_foreign_key['condef'])
+ match = re.match(
+ "ON DELETE ([a-zA-Z0-9]+( (NULL|ACTION|DEFAULT))?)",
+ table_foreign_key["condef"],
+ )
if match:
on_delete = match.group(1)
- values = re.match('FOREIGN KEY \((.+)\) REFERENCES (.+)\((.+)\)', table_foreign_key['condef'])
+ values = re.match(
+ "FOREIGN KEY \((.+)\) REFERENCES (.+)\((.+)\)", table_foreign_key["condef"]
+ )
if values:
- local_columns = [c.strip() for c in values.group(1).split(',')]
- foreign_columns = [c.strip() for c in values.group(3).split(',')]
+ local_columns = [c.strip() for c in values.group(1).split(",")]
+ foreign_columns = [c.strip() for c in values.group(3).split(",")]
foreign_table = values.group(2)
return ForeignKeyConstraint(
- local_columns, foreign_table, foreign_columns,
- table_foreign_key['conname'],
- {'on_update': on_update, 'on_delete': on_delete}
+ local_columns,
+ foreign_table,
+ foreign_columns,
+ table_foreign_key["conname"],
+ {"on_update": on_update, "on_delete": on_delete},
)
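The foreign-key parsing reformatted above pulls the local columns, foreign table and foreign columns out of the condef string returned by pg_catalog.pg_get_constraintdef. A small standalone check of that regex against a typical definition (the sample string is illustrative, not taken from a real database):

import re

condef = "FOREIGN KEY (user_id, org_id) REFERENCES users(id, org_id)"

values = re.match(r"FOREIGN KEY \((.+)\) REFERENCES (.+)\((.+)\)", condef)
local_columns = [c.strip() for c in values.group(1).split(",")]
foreign_columns = [c.strip() for c in values.group(3).split(",")]

assert local_columns == ["user_id", "org_id"]
assert values.group(2) == "users"
assert foreign_columns == ["id", "org_id"]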
diff --git a/orator/dbal/schema_manager.py b/orator/dbal/schema_manager.py
index 83a464d3..df9d913b 100644
--- a/orator/dbal/schema_manager.py
+++ b/orator/dbal/schema_manager.py
@@ -7,7 +7,6 @@
class SchemaManager(object):
-
def __init__(self, connection, platform=None):
"""
:param connection: The connection to use
@@ -78,35 +77,38 @@ def _get_portable_table_indexes_list(self, table_indexes, table_name):
result = OrderedDict()
for table_index in table_indexes:
- index_name = table_index['key_name']
- key_name = table_index['key_name']
- if table_index['primary']:
- key_name = 'primary'
+ index_name = table_index["key_name"]
+ key_name = table_index["key_name"]
+ if table_index["primary"]:
+ key_name = "primary"
key_name = key_name.lower()
if key_name not in result:
options = {}
- if 'where' in table_index:
- options['where'] = table_index['where']
+ if "where" in table_index:
+ options["where"] = table_index["where"]
result[key_name] = {
- 'name': index_name,
- 'columns': [table_index['column_name']],
- 'unique': not table_index['non_unique'],
- 'primary': table_index['primary'],
- 'flags': table_index.get('flags') or None,
- 'options': options
+ "name": index_name,
+ "columns": [table_index["column_name"]],
+ "unique": not table_index["non_unique"],
+ "primary": table_index["primary"],
+ "flags": table_index.get("flags") or None,
+ "options": options,
}
else:
- result[key_name]['columns'].append(table_index['column_name'])
+ result[key_name]["columns"].append(table_index["column_name"])
indexes = OrderedDict()
for index_key, data in result.items():
index = Index(
- data['name'], data['columns'],
- data['unique'], data['primary'],
- data['flags'], data['options']
+ data["name"],
+ data["columns"],
+ data["unique"],
+ data["primary"],
+ data["flags"],
+ data["options"],
)
indexes[index_key] = index
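For context on the reformatted _get_portable_table_indexes_list above: the driver-specific managers hand it one row per (index, column), and it folds those rows into one entry per index while keeping the column order. A standalone sketch of that grouping with made-up sample rows (field names as in the diff):

from collections import OrderedDict

rows = [
    {"key_name": "PRIMARY", "column_name": "id", "non_unique": False, "primary": True},
    {"key_name": "users_email_idx", "column_name": "email", "non_unique": True, "primary": False},
    {"key_name": "users_email_idx", "column_name": "tenant_id", "non_unique": True, "primary": False},
]

result = OrderedDict()
for row in rows:
    # Primary keys are always filed under the "primary" key; other indexes keep their name.
    key = "primary" if row["primary"] else row["key_name"].lower()
    if key not in result:
        result[key] = {
            "name": row["key_name"],
            "columns": [row["column_name"]],
            "unique": not row["non_unique"],
            "primary": row["primary"],
        }
    else:
        result[key]["columns"].append(row["column_name"])

assert result["primary"]["columns"] == ["id"]
assert result["users_email_idx"]["columns"] == ["email", "tenant_id"]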
diff --git a/orator/dbal/sqlite_schema_manager.py b/orator/dbal/sqlite_schema_manager.py
index 0db7eb1f..844dc105 100644
--- a/orator/dbal/sqlite_schema_manager.py
+++ b/orator/dbal/sqlite_schema_manager.py
@@ -8,67 +8,68 @@
class SQLiteSchemaManager(SchemaManager):
-
def _get_portable_table_column_definition(self, table_column):
- parts = table_column['type'].split('(')
- table_column['type'] = parts[0]
+ parts = table_column["type"].split("(")
+ table_column["type"] = parts[0]
if len(parts) > 1:
- length = parts[1].strip(')')
- table_column['length'] = length
+ length = parts[1].strip(")")
+ table_column["length"] = length
- db_type = table_column['type'].lower()
- length = table_column.get('length', None)
+ db_type = table_column["type"].lower()
+ length = table_column.get("length", None)
unsigned = False
- if ' unsigned' in db_type:
- db_type = db_type.replace(' unsigned', '')
+ if " unsigned" in db_type:
+ db_type = db_type.replace(" unsigned", "")
unsigned = True
fixed = False
type = self._platform.get_type_mapping(db_type)
- default = table_column['dflt_value']
- if default == 'NULL':
+ default = table_column["dflt_value"]
+ if default == "NULL":
default = None
if default is not None:
# SQLite returns strings wrapped in single quotes, so we need to strip them
- default = re.sub("^'(.*)'$", '\\1', default)
+ default = re.sub("^'(.*)'$", "\\1", default)
- notnull = bool(table_column['notnull'])
+ notnull = bool(table_column["notnull"])
- if 'name' not in table_column:
- table_column['name'] = ''
+ if "name" not in table_column:
+ table_column["name"] = ""
precision = None
scale = None
- if db_type in ['char']:
+ if db_type in ["char"]:
fixed = True
- elif db_type in ['varchar']:
+ elif db_type in ["varchar"]:
length = length or 255
- elif db_type in ['float', 'double', 'real', 'decimal', 'numeric']:
- if 'length' in table_column:
- if ',' not in table_column['length']:
- table_column['length'] += ',0'
+ elif db_type in ["float", "double", "real", "decimal", "numeric"]:
+ if "length" in table_column:
+ if "," not in table_column["length"]:
+ table_column["length"] += ",0"
- precision, scale = tuple(map(lambda x: x.strip(), table_column['length'].split(',')))
+ precision, scale = tuple(
+ map(lambda x: x.strip(), table_column["length"].split(","))
+ )
length = None
options = {
- 'length': length,
- 'unsigned': bool(unsigned),
- 'fixed': fixed,
- 'notnull': notnull,
- 'default': default,
- 'precision': precision,
- 'scale': scale,
- 'autoincrement': False
+ "length": length,
+ "unsigned": bool(unsigned),
+ "fixed": fixed,
+ "notnull": notnull,
+ "default": default,
+ "precision": precision,
+ "scale": scale,
+ "autoincrement": False,
}
- column = Column(table_column['name'], type, options)
- column.set_platform_option('pk', table_column['pk'])
+ column = Column(table_column["name"], type, options)
+ column.set_platform_option("pk", table_column["pk"])
return column
@@ -76,78 +77,84 @@ def _get_portable_table_indexes_list(self, table_indexes, table_name):
index_buffer = []
# Fetch primary
- info = self._connection.select('PRAGMA TABLE_INFO (%s)' % table_name)
+ info = self._connection.select("PRAGMA TABLE_INFO (%s)" % table_name)
for row in info:
- if row['pk'] != 0:
- index_buffer.append({
- 'key_name': 'primary',
- 'primary': True,
- 'non_unique': False,
- 'column_name': row['name']
- })
+ if row["pk"] != 0:
+ index_buffer.append(
+ {
+ "key_name": "primary",
+ "primary": True,
+ "non_unique": False,
+ "column_name": row["name"],
+ }
+ )
# Fetch regular indexes
for index in table_indexes:
# Ignore indexes with reserved names, e.g. autoindexes
- if index['name'].find('sqlite_') == -1:
- key_name = index['name']
+ if index["name"].find("sqlite_") == -1:
+ key_name = index["name"]
idx = {
- 'key_name': key_name,
- 'primary': False,
- 'non_unique': not bool(index['unique'])
+ "key_name": key_name,
+ "primary": False,
+ "non_unique": not bool(index["unique"]),
}
- info = self._connection.select('PRAGMA INDEX_INFO (\'%s\')' % key_name)
+ info = self._connection.select("PRAGMA INDEX_INFO ('%s')" % key_name)
for row in info:
- idx['column_name'] = row['name']
+ idx["column_name"] = row["name"]
index_buffer.append(idx)
- return super(SQLiteSchemaManager, self)._get_portable_table_indexes_list(index_buffer, table_name)
+ return super(SQLiteSchemaManager, self)._get_portable_table_indexes_list(
+ index_buffer, table_name
+ )
def _get_portable_table_foreign_keys_list(self, table_foreign_keys):
foreign_keys = OrderedDict()
for value in table_foreign_keys:
value = dict((k.lower(), v) for k, v in value.items())
- name = value.get('constraint_name', None)
+ name = value.get("constraint_name", None)
if name is None:
- name = '%s_%s_%s' % (value['from'], value['table'], value['to'])
+ name = "%s_%s_%s" % (value["from"], value["table"], value["to"])
if name not in foreign_keys:
- if 'on_delete' not in value or value['on_delete'] == 'RESTRICT':
- value['on_delete'] = None
+ if "on_delete" not in value or value["on_delete"] == "RESTRICT":
+ value["on_delete"] = None
- if 'on_update' not in value or value['on_update'] == 'RESTRICT':
- value['on_update'] = None
+ if "on_update" not in value or value["on_update"] == "RESTRICT":
+ value["on_update"] = None
foreign_keys[name] = {
- 'name': name,
- 'local': [],
- 'foreign': [],
- 'foreign_table': value['table'],
- 'on_delete': value['on_delete'],
- 'on_update': value['on_update'],
- 'deferrable': value.get('deferrable', False),
- 'deferred': value.get('deferred', False)
+ "name": name,
+ "local": [],
+ "foreign": [],
+ "foreign_table": value["table"],
+ "on_delete": value["on_delete"],
+ "on_update": value["on_update"],
+ "deferrable": value.get("deferrable", False),
+ "deferred": value.get("deferred", False),
}
- foreign_keys[name]['local'].append(value['from'])
- foreign_keys[name]['foreign'].append(value['to'])
+ foreign_keys[name]["local"].append(value["from"])
+ foreign_keys[name]["foreign"].append(value["to"])
result = []
for constraint in foreign_keys.values():
result.append(
ForeignKeyConstraint(
- constraint['local'], constraint['foreign_table'],
- constraint['foreign'], constraint['name'],
+ constraint["local"],
+ constraint["foreign_table"],
+ constraint["foreign"],
+ constraint["name"],
{
- 'on_delete': constraint['on_delete'],
- 'on_update': constraint['on_update'],
- 'deferrable': constraint['deferrable'],
- 'deferred': constraint['deferred']
- }
+ "on_delete": constraint["on_delete"],
+ "on_update": constraint["on_update"],
+ "deferrable": constraint["deferrable"],
+ "deferred": constraint["deferred"],
+ },
)
)
diff --git a/orator/dbal/table.py b/orator/dbal/table.py
index 887c56a0..9eaefe1f 100644
--- a/orator/dbal/table.py
+++ b/orator/dbal/table.py
@@ -8,15 +8,19 @@
from .foreign_key_constraint import ForeignKeyConstraint
from .exceptions import (
DBALException,
- IndexDoesNotExist, IndexAlreadyExists, IndexNameInvalid,
- ColumnDoesNotExist, ColumnAlreadyExists,
- ForeignKeyDoesNotExist
+ IndexDoesNotExist,
+ IndexAlreadyExists,
+ IndexNameInvalid,
+ ColumnDoesNotExist,
+ ColumnAlreadyExists,
+ ForeignKeyDoesNotExist,
)
class Table(AbstractAsset):
-
- def __init__(self, table_name, columns=None, indexes=None, fk_constraints=None, options=None):
+ def __init__(
+ self, table_name, columns=None, indexes=None, fk_constraints=None, options=None
+ ):
self._set_name(table_name)
self._primary_key_name = False
self._columns = OrderedDict()
@@ -37,7 +41,11 @@ def __init__(self, table_name, columns=None, indexes=None, fk_constraints=None,
for index in indexes:
self._add_index(index)
- fk_constraints = fk_constraints.values() if isinstance(fk_constraints, dict) else fk_constraints
+ fk_constraints = (
+ fk_constraints.values()
+ if isinstance(fk_constraints, dict)
+ else fk_constraints
+ )
for constraint in fk_constraints:
self._add_foreign_key_constraint(constraint)
@@ -53,7 +61,9 @@ def set_primary_key(self, columns, index_name=False):
:rtype: Table
"""
- self._add_index(self._create_index(columns, index_name or 'primary', True, True))
+ self._add_index(
+ self._create_index(columns, index_name or "primary", True, True)
+ )
for column_name in columns:
column = self.get_column(column_name)
@@ -64,10 +74,12 @@ def set_primary_key(self, columns, index_name=False):
def add_index(self, columns, name=None, flags=None, options=None):
if not name:
name = self._generate_identifier_name(
- [self.get_name()] + columns, 'idx', self._get_max_identifier_length()
+ [self.get_name()] + columns, "idx", self._get_max_identifier_length()
)
- return self._add_index(self._create_index(columns, name, False, False, flags, options))
+ return self._add_index(
+ self._create_index(columns, name, False, False, flags, options)
+ )
def drop_primary_key(self):
"""
@@ -92,10 +104,12 @@ def drop_index(self, name):
def add_unique_index(self, columns, name=None, options=None):
if not name:
name = self._generate_identifier_name(
- [self.get_name()] + columns, 'uniq', self._get_max_identifier_length()
+ [self.get_name()] + columns, "uniq", self._get_max_identifier_length()
)
- return self._add_index(self._create_index(columns, name, True, False, None, options))
+ return self._add_index(
+ self._create_index(columns, name, True, False, None, options)
+ )
def rename_index(self, old_name, new_name=None):
"""
@@ -149,7 +163,9 @@ def columns_are_indexed(self, columns):
return False
- def _create_index(self, columns, name, is_unique, is_primary, flags=None, options=None):
+ def _create_index(
+ self, columns, name, is_unique, is_primary, flags=None, options=None
+ ):
"""
Creates an Index instance.
@@ -173,7 +189,7 @@ def _create_index(self, columns, name, is_unique, is_primary, flags=None, option
:rtype: Index
"""
- if re.match('[^a-zA-Z0-9_]+', self._normalize_identifier(name)):
+ if re.match("[^a-zA-Z0-9_]+", self._normalize_identifier(name)):
raise IndexNameInvalid(name)
for column in columns:
@@ -237,8 +253,14 @@ def drop_column(self, name):
return self
- def add_foreign_key_constraint(self, foreign_table, local_columns,
- foreign_columns, options=None, constraint_name=None):
+ def add_foreign_key_constraint(
+ self,
+ foreign_table,
+ local_columns,
+ foreign_columns,
+ options=None,
+ constraint_name=None,
+ ):
"""
Adds a foreign key constraint.
@@ -259,16 +281,18 @@ def add_foreign_key_constraint(self, foreign_table, local_columns,
"""
if not constraint_name:
constraint_name = self._generate_identifier_name(
- [self.get_name()] + local_columns, 'fk', self._get_max_identifier_length()
+ [self.get_name()] + local_columns,
+ "fk",
+ self._get_max_identifier_length(),
)
return self.add_named_foreign_key_constraint(
- constraint_name, foreign_table,
- local_columns, foreign_columns, options
+ constraint_name, foreign_table, local_columns, foreign_columns, options
)
- def add_named_foreign_key_constraint(self, name, foreign_table,
- local_columns, foreign_columns, options):
+ def add_named_foreign_key_constraint(
+ self, name, foreign_table, local_columns, foreign_columns, options
+ ):
"""
Adds a foreign key constraint with a given name.
@@ -334,8 +358,10 @@ def _add_index(self, index):
replaced_implicit_indexes.append(name)
already_exists = (
- index_name in self._indexes and index_name not in replaced_implicit_indexes
- or self._primary_key_name is not False and index.is_primary()
+ index_name in self._indexes
+ and index_name not in replaced_implicit_indexes
+ or self._primary_key_name is not False
+ and index.is_primary()
)
if already_exists:
raise IndexAlreadyExists(index_name, self._name)
@@ -366,7 +392,9 @@ def _add_foreign_key_constraint(self, constraint):
name = constraint.get_name()
else:
name = self._generate_identifier_name(
- [self.get_name()] + constraint.get_local_columns(), 'fk', self._get_max_identifier_length()
+ [self.get_name()] + constraint.get_local_columns(),
+ "fk",
+ self._get_max_identifier_length(),
)
name = self._normalize_identifier(name)
@@ -380,16 +408,20 @@ def _add_foreign_key_constraint(self, constraint):
# This creates computation overhead in this case, however no duplicate indexes
# are ever added (based on columns).
index_name = self._generate_identifier_name(
- [self.get_name()] + constraint.get_columns(), 'idx', self._get_max_identifier_length()
+ [self.get_name()] + constraint.get_columns(),
+ "idx",
+ self._get_max_identifier_length(),
+ )
+ index_candidate = self._create_index(
+ constraint.get_columns(), index_name, False, False
)
- index_candidate = self._create_index(constraint.get_columns(), index_name, False, False)
for existing_index in self._indexes.values():
if index_candidate.is_fullfilled_by(existing_index):
return
- #self._add_index(index_candidate)
- #self._implicit_indexes[self._normalize_identifier(index_name)] = index_candidate
+ # self._add_index(index_candidate)
+ # self._implicit_indexes[self._normalize_identifier(index_name)] = index_candidate
return self
@@ -550,19 +582,27 @@ def clone(self):
table._primary_key_name = self._primary_key_name
for k, column in self._columns.items():
- table._columns[k] = Column(column.get_name(), column.get_type(), column.to_dict())
+ table._columns[k] = Column(
+ column.get_name(), column.get_type(), column.to_dict()
+ )
for k, index in self._indexes.items():
table._indexes[k] = Index(
- index.get_name(), index.get_columns(),
- index.is_unique(), index.is_primary(),
- index.get_flags(), index.get_options()
+ index.get_name(),
+ index.get_columns(),
+ index.is_unique(),
+ index.is_primary(),
+ index.get_flags(),
+ index.get_options(),
)
for k, fk in self._fk_constraints.items():
table._fk_constraints[k] = ForeignKeyConstraint(
- fk.get_local_columns(), fk.get_foreign_table_name(),
- fk.get_foreign_columns(), fk.get_name(), fk.get_options()
+ fk.get_local_columns(),
+ fk.get_foreign_table_name(),
+ fk.get_foreign_columns(),
+ fk.get_name(),
+ fk.get_options(),
)
table._fk_constraints[k].set_local_table(table)
diff --git a/orator/dbal/table_diff.py b/orator/dbal/table_diff.py
index c06a7647..dc012fd5 100644
--- a/orator/dbal/table_diff.py
+++ b/orator/dbal/table_diff.py
@@ -6,10 +6,17 @@
class TableDiff(object):
-
- def __init__(self, table_name, added_columns=None,
- changed_columns=None, removed_columns=None, added_indexes=None,
- changed_indexes=None, removed_indexes=None, from_table=None):
+ def __init__(
+ self,
+ table_name,
+ added_columns=None,
+ changed_columns=None,
+ removed_columns=None,
+ added_indexes=None,
+ changed_indexes=None,
+ removed_indexes=None,
+ from_table=None,
+ ):
self.name = table_name
self.new_name = False
self.added_columns = added_columns or OrderedDict()
diff --git a/orator/dbal/types/__init__.py b/orator/dbal/types/__init__.py
index 633f8661..40a96afc 100644
--- a/orator/dbal/types/__init__.py
+++ b/orator/dbal/types/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/orator/events/__init__.py b/orator/events/__init__.py
index fe7f5601..c4ad62e5 100644
--- a/orator/events/__init__.py
+++ b/orator/events/__init__.py
@@ -9,7 +9,7 @@ class Event(object):
@classmethod
def fire(cls, name, *args, **kwargs):
- name = 'orator.%s' % name
+ name = "orator.%s" % name
signal = cls.events.signal(name)
for response in signal.send(*args, **kwargs):
@@ -18,14 +18,14 @@ def fire(cls, name, *args, **kwargs):
@classmethod
def listen(cls, name, callback, *args, **kwargs):
- name = 'orator.%s' % name
+ name = "orator.%s" % name
signal = cls.events.signal(name)
signal.connect(callback, weak=False, *args, **kwargs)
@classmethod
def forget(cls, name, *args, **kwargs):
- name = 'orator.%s' % name
+ name = "orator.%s" % name
signal = cls.events.signal(name)
for receiver in signal.receivers:
diff --git a/orator/exceptions/connection.py b/orator/exceptions/connection.py
index c8581c5d..d7ff7869 100644
--- a/orator/exceptions/connection.py
+++ b/orator/exceptions/connection.py
@@ -2,7 +2,6 @@
class TransactionError(ConnectionError):
-
def __init__(self, previous, message=None):
self.previous = previous
- self.message = 'Transaction Error: '
+ self.message = "Transaction Error: "
diff --git a/orator/exceptions/connectors.py b/orator/exceptions/connectors.py
index b3a0a831..5cbdbb05 100644
--- a/orator/exceptions/connectors.py
+++ b/orator/exceptions/connectors.py
@@ -7,7 +7,6 @@ class ConnectorException(Exception):
class UnsupportedDriver(ConnectorException):
-
def __init__(self, driver):
message = 'Driver "%s" is not supported' % driver
@@ -15,7 +14,6 @@ def __init__(self, driver):
class MissingPackage(ConnectorException):
-
def __init__(self, driver, supported_packages):
if not isinstance(supported_packages, list):
supported_packages = [supported_packages]
@@ -24,6 +22,8 @@ def __init__(self, driver, supported_packages):
if len(supported_packages) == 1:
message += '"%s" package' % supported_packages[0]
else:
- message += 'one of the following packages: "%s"' % ('", "'.join(supported_packages))
-
+ message += 'one of the following packages: "%s"' % (
+ '", "'.join(supported_packages)
+ )
+
super(MissingPackage, self).__init__(message)
diff --git a/orator/exceptions/orm.py b/orator/exceptions/orm.py
index e52dbf4b..ef7d8a13 100644
--- a/orator/exceptions/orm.py
+++ b/orator/exceptions/orm.py
@@ -2,11 +2,10 @@
class ModelNotFound(RuntimeError):
-
def __init__(self, model):
self._model = model
- self.message = 'No query results found for model [%s]' % self._model.__name__
+ self.message = "No query results found for model [%s]" % self._model.__name__
def __str__(self):
return self.message
@@ -17,7 +16,6 @@ class MassAssignmentError(RuntimeError):
class RelatedClassNotFound(RuntimeError):
-
def __init__(self, related):
self._related = related
diff --git a/orator/exceptions/query.py b/orator/exceptions/query.py
index 1360b541..bafa1408 100644
--- a/orator/exceptions/query.py
+++ b/orator/exceptions/query.py
@@ -2,7 +2,6 @@
class QueryException(Exception):
-
def __init__(self, sql, bindings, previous):
self.sql = sql
self.bindings = bindings
@@ -10,7 +9,7 @@ def __init__(self, sql, bindings, previous):
self.message = self.format_message(sql, bindings, previous)
def format_message(self, sql, bindings, previous):
- return '%s (SQL: %s (%s))' % (str(previous), sql, bindings)
+ return "%s (SQL: %s (%s))" % (str(previous), sql, bindings)
def __repr__(self):
return self.message
diff --git a/orator/migrations/database_migration_repository.py b/orator/migrations/database_migration_repository.py
index 8a7b1699..fc9e9777 100644
--- a/orator/migrations/database_migration_repository.py
+++ b/orator/migrations/database_migration_repository.py
@@ -4,7 +4,6 @@
class DatabaseMigrationRepository(object):
-
def __init__(self, resolver, table):
"""
:type resolver: orator.database_manager.DatabaseManager
@@ -20,7 +19,7 @@ def get_ran(self):
:rtype: list
"""
- return self.table().lists('migration')
+ return self.table().lists("migration")
def get_last(self):
"""
@@ -28,9 +27,9 @@ def get_last(self):
:rtype: list
"""
- query = self.table().where('batch', self.get_last_batch_number())
+ query = self.table().where("batch", self.get_last_batch_number())
- return query.order_by('migration', 'desc').get()
+ return query.order_by("migration", "desc").get()
def log(self, file, batch):
"""
@@ -39,10 +38,7 @@ def log(self, file, batch):
:type file: str
:type batch: int
"""
- record = {
- 'migration': file,
- 'batch': batch
- }
+ record = {"migration": file, "batch": batch}
self.table().insert(**record)
@@ -52,7 +48,7 @@ def delete(self, migration):
:type migration: dict
"""
- self.table().where('migration', migration['migration']).delete()
+ self.table().where("migration", migration["migration"]).delete()
def get_next_batch_number(self):
"""
@@ -68,7 +64,7 @@ def get_last_batch_number(self):
:rtype: int
"""
- return self.table().max('batch') or 0
+ return self.table().max("batch") or 0
def create_repository(self):
"""
@@ -80,8 +76,8 @@ def create_repository(self):
# The migrations table is responsible for keeping track of which of the
# migrations have actually run for the application. We'll create the
# table to hold the migration file's path as well as the batch ID.
- table.string('migration')
- table.integer('batch')
+ table.string("migration")
+ table.integer("batch")
def repository_exists(self):
"""
diff --git a/orator/migrations/migration_creator.py b/orator/migrations/migration_creator.py
index 2b267acb..b421b46e 100644
--- a/orator/migrations/migration_creator.py
+++ b/orator/migrations/migration_creator.py
@@ -18,7 +18,6 @@ def mkdir_p(path):
class MigrationCreator(object):
-
def create(self, name, path, table=None, create=False):
"""
Create a new migration at the given path.
@@ -38,14 +37,14 @@ def create(self, name, path, table=None, create=False):
if not os.path.exists(os.path.dirname(path)):
mkdir_p(os.path.dirname(path))
- parent = os.path.join(os.path.dirname(path), '__init__.py')
+ parent = os.path.join(os.path.dirname(path), "__init__.py")
if not os.path.exists(parent):
- with open(parent, 'w'):
+ with open(parent, "w"):
pass
stub = self._get_stub(table, create)
- with open(path, 'w') as fh:
+ with open(path, "w") as fh:
fh.write(self._populate_stub(name, stub, table))
return path
@@ -87,10 +86,10 @@ def _populate_stub(self, name, stub, table):
:rtype: str
"""
- stub = stub.replace('DummyClass', self._get_class_name(name))
+ stub = stub.replace("DummyClass", self._get_class_name(name))
if table is not None:
- stub = stub.replace('dummy_table', table)
+ stub = stub.replace("dummy_table", table)
return stub
@@ -98,7 +97,7 @@ def _get_class_name(self, name):
return inflection.camelize(name)
def _get_path(self, name, path):
- return os.path.join(path, self._get_date_prefix() + '_' + name + '.py')
+ return os.path.join(path, self._get_date_prefix() + "_" + name + ".py")
def _get_date_prefix(self):
- return datetime.datetime.utcnow().strftime('%Y_%m_%d_%H%M%S')
+ return datetime.datetime.utcnow().strftime("%Y_%m_%d_%H%M%S")
diff --git a/orator/migrations/migrator.py b/orator/migrations/migrator.py
index 432b4ddb..006205e3 100644
--- a/orator/migrations/migrator.py
+++ b/orator/migrations/migrator.py
@@ -11,7 +11,6 @@
class MigratorHandler(logging.NullHandler):
-
def __init__(self, level=logging.DEBUG):
super(MigratorHandler, self).__init__(level)
@@ -22,7 +21,6 @@ def handle(self, record):
class Migrator(object):
-
def __init__(self, repository, resolver):
"""
:type repository: DatabaseMigrationRepository
@@ -61,7 +59,7 @@ def run_migration_list(self, path, migrations, pretend=False):
:type pretend: bool
"""
if not migrations:
- self._note('Nothing to migrate')
+ self._note("Nothing to migrate")
return
@@ -83,7 +81,7 @@ def _run_up(self, path, migration_file, batch, pretend=False):
migration = self._resolve(path, migration_file)
if pretend:
- return self._pretend_to_run(migration, 'up')
+ return self._pretend_to_run(migration, "up")
if migration.transactional:
with migration.db.transaction():
@@ -93,7 +91,10 @@ def _run_up(self, path, migration_file, batch, pretend=False):
self._repository.log(migration_file, batch)
-            self._note(decode('[OK] Migrated ') + '%s' % migration_file)
+            self._note(
+                decode("[OK] Migrated ")
+                + "%s" % migration_file
+            )
def rollback(self, path, pretend=False):
"""
@@ -112,7 +113,7 @@ def rollback(self, path, pretend=False):
migrations = self._repository.get_last()
if not migrations:
- self._note('Nothing to rollback.')
+ self._note("Nothing to rollback.")
return len(migrations)
@@ -140,10 +141,10 @@ def reset(self, path, pretend=False):
count = len(migrations)
if count == 0:
- self._note('Nothing to rollback.')
+ self._note("Nothing to rollback.")
else:
for migration in migrations:
- self._run_down(path, {'migration': migration}, pretend)
+ self._run_down(path, {"migration": migration}, pretend)
return count
@@ -151,12 +152,12 @@ def _run_down(self, path, migration, pretend=False):
"""
Run "down" a migration instance.
"""
- migration_file = migration['migration']
+ migration_file = migration["migration"]
instance = self._resolve(path, migration_file)
if pretend:
- return self._pretend_to_run(instance, 'down')
+ return self._pretend_to_run(instance, "down")
if instance.transactional:
with instance.db.transaction():
@@ -166,7 +167,10 @@ def _run_down(self, path, migration, pretend=False):
self._repository.delete(migration)
-            self._note(decode('[OK] Rolled back ') + '%s' % migration_file)
+            self._note(
+                decode("[OK] Rolled back ")
+                + "%s" % migration_file
+            )
def _get_migration_files(self, path):
"""
@@ -176,12 +180,12 @@ def _get_migration_files(self, path):
:rtype: list
"""
- files = glob.glob(os.path.join(path, '[0-9]*_*.py'))
+ files = glob.glob(os.path.join(path, "[0-9]*_*.py"))
if not files:
return []
- files = list(map(lambda f: os.path.basename(f).replace('.py', ''), files))
+ files = list(map(lambda f: os.path.basename(f).replace(".py", ""), files))
files = sorted(files)
@@ -197,7 +201,7 @@ def _pretend_to_run(self, migration, method):
:param method: The method to execute
:type method: str
"""
- self._note('')
+ self._note("")
names = []
for query in self._get_queries(migration, method):
name = migration.__class__.__name__
@@ -206,17 +210,13 @@ def _pretend_to_run(self, migration, method):
if isinstance(query, tuple):
query, bindings = query
- query = highlight(
- query,
- SqlLexer(),
- CommandFormatter()
- ).strip()
+ query = highlight(query, SqlLexer(), CommandFormatter()).strip()
if bindings:
query = (query, bindings)
if name not in names:
- self._note('[{}]'.format(name))
+ self._note("[{}]".format(name))
names.append(name)
self._note(query)
@@ -250,19 +250,19 @@ def _resolve(self, path, migration_file):
:rtype: orator.migrations.migration.Migration
"""
- name = '_'.join(migration_file.split('_')[4:])
- migration_file = os.path.join(path, '%s.py' % migration_file)
+ name = "_".join(migration_file.split("_")[4:])
+ migration_file = os.path.join(path, "%s.py" % migration_file)
# Loading parent module
- parent = os.path.join(path, '__init__.py')
+ parent = os.path.join(path, "__init__.py")
if not os.path.exists(parent):
- with open(parent, 'w'):
+ with open(parent, "w"):
pass
- load_module('migrations', parent)
+ load_module("migrations", parent)
# Loading module
- mod = load_module('migrations.%s' % name, migration_file)
+ mod = load_module("migrations.%s" % name, migration_file)
klass = getattr(mod, inflection.camelize(name))
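
Note on the _resolve() hunk above: the migration file name carries the timestamp
prefix produced by MigrationCreator._get_date_prefix() (four underscore-separated
fields), which split("_")[4:] discards before camelizing the remainder into the
class name. A small sketch with a made-up file name:

    migration_file = "2019_07_15_120000_create_users_table"
    name = "_".join(migration_file.split("_")[4:])   # "create_users_table"
    # load_module("migrations.create_users_table",
    #             os.path.join(path, "2019_07_15_120000_create_users_table.py"))
    # then getattr(mod, inflection.camelize(name)) resolves the CreateUsersTable class.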
diff --git a/orator/orm/__init__.py b/orator/orm/__init__.py
index b9a222c2..fc37731f 100644
--- a/orator/orm/__init__.py
+++ b/orator/orm/__init__.py
@@ -6,10 +6,18 @@
from .collection import Collection
from .factory import Factory
from .utils import (
- mutator, accessor, column,
- has_one, morph_one,
- belongs_to, morph_to,
- has_many, has_many_through, morph_many,
- belongs_to_many, morph_to_many, morphed_by_many,
- scope
+ mutator,
+ accessor,
+ column,
+ has_one,
+ morph_one,
+ belongs_to,
+ morph_to,
+ has_many,
+ has_many_through,
+ morph_many,
+ belongs_to_many,
+ morph_to_many,
+ morphed_by_many,
+ scope,
)
diff --git a/orator/orm/builder.py b/orator/orm/builder.py
index d86d80aa..72a7441e 100644
--- a/orator/orm/builder.py
+++ b/orator/orm/builder.py
@@ -13,8 +13,19 @@
class Builder(object):
_passthru = [
- 'to_sql', 'lists', 'insert', 'insert_get_id', 'pluck', 'count',
- 'min', 'max', 'avg', 'sum', 'exists', 'get_bindings', 'raw'
+ "to_sql",
+ "lists",
+ "insert",
+ "insert_get_id",
+ "pluck",
+ "count",
+ "min",
+ "max",
+ "avg",
+ "sum",
+ "exists",
+ "get_bindings",
+ "raw",
]
def __init__(self, query):
@@ -97,12 +108,12 @@ def find(self, id, columns=None):
:rtype: orator.Model
"""
if columns is None:
- columns = ['*']
+ columns = ["*"]
if isinstance(id, list):
return self.find_many(id, columns)
- self._query.where(self._model.get_qualified_key_name(), '=', id)
+ self._query.where(self._model.get_qualified_key_name(), "=", id)
return self.first(columns)
@@ -120,7 +131,7 @@ def find_many(self, id, columns=None):
:rtype: orator.Collection
"""
if columns is None:
- columns = ['*']
+ columns = ["*"]
if not id:
return self._model.new_collection()
@@ -165,7 +176,7 @@ def first(self, columns=None):
:rtype: mixed
"""
if columns is None:
- columns = ['*']
+ columns = ["*"]
return self.take(1).get(columns).first()
@@ -232,15 +243,19 @@ def chunk(self, count):
:return: The current chunk
:rtype: list
"""
- page = 1
- results = self.for_page(page, count).get()
+ connection = self._model.get_connection_name()
+ for results in self.apply_scopes().get_query().chunk(count):
+ models = self._model.hydrate(results, connection)
- while not results.is_empty():
- yield results
+ # If we actually found models we will also eager load any relationships that
+ # have been specified as needing to be eager loaded, which will solve the
+ # n+1 query issue for the developers to avoid running a lot of queries.
+ if len(models) > 0:
+ models = self.eager_load_relations(models)
- page += 1
+ collection = self._model.new_collection(models)
- results = self.for_page(page, count).get()
+ yield collection
def lists(self, column, key=None):
"""
@@ -287,7 +302,7 @@ def paginate(self, per_page=None, current_page=None, columns=None):
:return: The paginator
"""
if columns is None:
- columns = ['*']
+ columns = ["*"]
total = self.to_base().get_count_for_pagination()
@@ -313,7 +328,7 @@ def simple_paginate(self, per_page=None, current_page=None, columns=None):
:return: The paginator
"""
if columns is None:
- columns = ['*']
+ columns = ["*"]
page = current_page or Paginator.resolve_current_page()
per_page = per_page or self._model.get_per_page()
@@ -398,7 +413,8 @@ def _add_updated_at_column(self, values):
column = self._model.get_updated_at_column()
- values.update({column: self._model.fresh_timestamp()})
+ if "updated_at" not in values:
+ values.update({column: self._model.fresh_timestamp_string()})
return values
@@ -455,7 +471,7 @@ def eager_load_relations(self, models):
:rtype: list
"""
for name, constraints in self._eager_load.items():
- if name.find('.') == -1:
+ if name.find(".") == -1:
models = self._load_relation(models, name, constraints)
return models
@@ -509,7 +525,7 @@ def _nested_relations(self, relation):
for name, constraints in self._eager_load.items():
if self._is_nested(name, relation):
- nested[name[len(relation + '.'):]] = constraints
+ nested[name[len(relation + ".") :]] = constraints
return nested
@@ -522,11 +538,11 @@ def _is_nested(self, name, relation):
:rtype: bool
"""
- dots = name.find('.')
+ dots = name.find(".")
- return dots and name.startswith(relation + '.')
+ return dots and name.startswith(relation + ".")
- def where(self, column, operator=Null(), value=None, boolean='and'):
+ def where(self, column, operator=Null(), value=None, boolean="and"):
"""
Add a where clause to the query
@@ -568,9 +584,9 @@ def or_where(self, column, operator=None, value=None):
:return: The current Builder instance
:rtype: Builder
"""
- return self.where(column, operator, value, 'or')
+ return self.where(column, operator, value, "or")
- def where_exists(self, query, boolean='and', negate=False):
+ def where_exists(self, query, boolean="and", negate=False):
"""
Add an exists clause to the query.
@@ -601,9 +617,9 @@ def or_where_exists(self, query, negate=False):
:rtype: Builder
"""
- return self.where_exists(query, 'or', negate)
+ return self.where_exists(query, "or", negate)
- def where_not_exists(self, query, boolean='and'):
+ def where_not_exists(self, query, boolean="and"):
"""
Add a where not exists clause to the query.
@@ -627,7 +643,7 @@ def or_where_not_exists(self, query):
"""
return self.or_where_exists(query, True)
- def has(self, relation, operator='>=', count=1, boolean='and', extra=None):
+ def has(self, relation, operator=">=", count=1, boolean="and", extra=None):
"""
Add a relationship count condition to the query.
@@ -648,21 +664,25 @@ def has(self, relation, operator='>=', count=1, boolean='and', extra=None):
:type: Builder
"""
- if relation.find('.') >= 0:
+ if relation.find(".") >= 0:
return self._has_nested(relation, operator, count, boolean, extra)
relation = self._get_has_relation_query(relation)
- query = relation.get_relation_count_query(relation.get_related().new_query(), self)
+ query = relation.get_relation_count_query(
+ relation.get_related().new_query(), self
+ )
# TODO: extra query
if extra:
if callable(extra):
extra(query)
- return self._add_has_where(query.apply_scopes(), relation, operator, count, boolean)
+ return self._add_has_where(
+ query.apply_scopes(), relation, operator, count, boolean
+ )
- def _has_nested(self, relations, operator='>=', count=1, boolean='and', extra=None):
+ def _has_nested(self, relations, operator=">=", count=1, boolean="and", extra=None):
"""
Add nested relationship count conditions to the query.
@@ -683,7 +703,7 @@ def _has_nested(self, relations, operator='>=', count=1, boolean='and', extra=No
:rtype: Builder
"""
- relations = relations.split('.')
+ relations = relations.split(".")
def closure(q):
if len(relations) > 1:
@@ -693,7 +713,7 @@ def closure(q):
return self.where_has(relations.pop(0), closure)
- def doesnt_have(self, relation, boolean='and', extra=None):
+ def doesnt_have(self, relation, boolean="and", extra=None):
"""
Add a relationship count to the query.
@@ -708,9 +728,9 @@ def doesnt_have(self, relation, boolean='and', extra=None):
:rtype: Builder
"""
- return self.has(relation, '<', 1, boolean, extra)
+ return self.has(relation, "<", 1, boolean, extra)
- def where_has(self, relation, extra, operator='>=', count=1):
+ def where_has(self, relation, extra, operator=">=", count=1):
"""
Add a relationship count condition to the query with where clauses.
@@ -728,7 +748,7 @@ def where_has(self, relation, extra, operator='>=', count=1):
:rtype: Builder
"""
- return self.has(relation, operator, count, 'and', extra)
+ return self.has(relation, operator, count, "and", extra)
def where_doesnt_have(self, relation, extra=None):
"""
@@ -742,9 +762,9 @@ def where_doesnt_have(self, relation, extra=None):
:rtype: Builder
"""
- return self.doesnt_have(relation, 'and', extra)
+ return self.doesnt_have(relation, "and", extra)
- def or_has(self, relation, operator='>=', count=1):
+ def or_has(self, relation, operator=">=", count=1):
"""
Add a relationship count condition to the query with an "or".
@@ -759,9 +779,9 @@ def or_has(self, relation, operator='>=', count=1):
:rtype: Builder
"""
- return self.has(relation, operator, count, 'or')
+ return self.has(relation, operator, count, "or")
- def or_where_has(self, relation, extra, operator='>=', count=1):
+ def or_where_has(self, relation, extra, operator=">=", count=1):
"""
Add a relationship count condition to the query with where clauses and an "or".
@@ -779,7 +799,7 @@ def or_where_has(self, relation, extra, operator='>=', count=1):
:rtype: Builder
"""
- return self.has(relation, operator, count, 'or', extra)
+ return self.has(relation, operator, count, "or", extra)
def _add_has_where(self, has_query, relation, operator, count, boolean):
"""
@@ -807,7 +827,9 @@ def _add_has_where(self, has_query, relation, operator, count, boolean):
if isinstance(count, basestring) and count.isdigit():
count = QueryExpression(count)
- return self.where(QueryExpression('(%s)' % has_query.to_sql()), operator, count, boolean)
+ return self.where(
+ QueryExpression("(%s)" % has_query.to_sql()), operator, count, boolean
+ )
def _merge_model_defined_relation_wheres_to_has_query(self, has_query, relation):
"""
@@ -821,11 +843,9 @@ def _merge_model_defined_relation_wheres_to_has_query(self, has_query, relation)
"""
relation_query = relation.get_base_query()
- has_query.merge_wheres(
- relation_query.wheres, relation_query.get_bindings()
- )
+ has_query.merge_wheres(relation_query.wheres, relation_query.get_bindings())
- self._query.add_binding(has_query.get_query().get_bindings(), 'where')
+ self._query.add_binding(has_query.get_query().get_bindings(), "where")
def _get_has_relation_query(self, relation):
"""
@@ -869,8 +889,12 @@ def _parse_with_relations(self, relations):
for relation in relations:
if isinstance(relation, dict):
- name = list(relation.keys())[0]
- constraints = relation[name]
+ for name, constraints in relation.items():
+ results = self._parse_nested_with(name, results)
+
+ results[name] = constraints
+
+ continue
else:
name = relation
constraints = self.__class__(self.get_query().new_query())
@@ -894,10 +918,10 @@ def _parse_nested_with(self, name, results):
"""
progress = []
- for segment in name.split('.'):
+ for segment in name.split("."):
progress.append(segment)
- last = '.'.join(progress)
+ last = ".".join(progress)
if last not in results:
results[last] = self.__class__(self.get_query().new_query())
@@ -920,7 +944,9 @@ def _call_scope(self, scope, *args, **kwargs):
result = getattr(self._model, scope)(self, *args, **kwargs)
if self._should_nest_wheres_for_scope(query, original_where_count):
- self._nest_wheres_for_scope(query, [0, original_where_count, len(query.wheres)])
+ self._nest_wheres_for_scope(
+ query, [0, original_where_count, len(query.wheres)]
+ )
return result or self
@@ -1022,13 +1048,9 @@ def _slice_where_conditions(self, wheres, offset, length):
:rtype: list
"""
where_group = self.get_query().for_nested_where()
- where_group.wheres = wheres[offset:(offset + length)]
+ where_group.wheres = wheres[offset : (offset + length)]
- return {
- 'type': 'nested',
- 'query': where_group,
- 'boolean': 'and'
- }
+ return {"type": "nested", "query": where_group, "boolean": "and"}
def get_query(self):
"""
@@ -1124,12 +1146,14 @@ def get_macro(self, name):
def __dynamic(self, method):
from .utils import scope
- scope_method = 'scope_%s' % method
+ scope_method = "scope_%s" % method
is_scope = False
is_macro = False
# New scope definition check
- if hasattr(self._model, method) and isinstance(getattr(self._model, method), scope):
+ if hasattr(self._model, method) and isinstance(
+ getattr(self._model, method), scope
+ ):
is_scope = True
attribute = getattr(self._model, method)
scope_method = method
@@ -1172,4 +1196,3 @@ def __copy__(self):
new.set_model(self._model)
return new
-
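
Note on the chunk() rewrite above: instead of paging with for_page(), the builder now
iterates the query builder's own chunk() generator, hydrates each batch, and applies
eager loads per chunk before yielding a Collection. A minimal usage sketch (the User
model, its "posts" relation and handle() are hypothetical stand-ins):

    for users in User.with_("posts").chunk(100):
        # each iteration yields an orator Collection of hydrated User models,
        # with "posts" eager-loaded for just this chunk
        for user in users:
            handle(user)

Relatedly, the _add_updated_at_column() hunk earlier in this file now leaves an
explicitly supplied "updated_at" value untouched during mass updates and only falls
back to fresh_timestamp_string() when the caller did not provide one.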
diff --git a/orator/orm/collection.py b/orator/orm/collection.py
index 2024295a..72ac873b 100644
--- a/orator/orm/collection.py
+++ b/orator/orm/collection.py
@@ -4,7 +4,6 @@
class Collection(BaseCollection):
-
def load(self, *relations):
"""
Load a set of relationships onto the collection.
diff --git a/orator/orm/factory.py b/orator/orm/factory.py
index d02f62e3..560914da 100644
--- a/orator/orm/factory.py
+++ b/orator/orm/factory.py
@@ -8,7 +8,6 @@
class Factory(object):
-
def __init__(self, faker=None, resolver=None):
"""
:param faker: A faker generator instance
@@ -56,7 +55,7 @@ def define_as(self, klass, name):
"""
return self.define(klass, name)
- def define(self, klass, name='default'):
+ def define(self, klass, name="default"):
"""
Define a class with a given set of attributes.
@@ -66,6 +65,7 @@ def define(self, klass, name='default'):
:param name: The short name
:type name: str
"""
+
def decorate(func):
@wraps(func)
def wrapped(*args, **kwargs):
@@ -77,7 +77,7 @@ def wrapped(*args, **kwargs):
return decorate
- def register(self, klass, callback, name='default'):
+ def register(self, klass, callback, name="default"):
"""
Register a class with a function.
@@ -189,7 +189,7 @@ def raw_of(self, klass, name, **attributes):
"""
return self.raw(klass, _name=name, **attributes)
- def raw(self, klass, _name='default', **attributes):
+ def raw(self, klass, _name="default", **attributes):
"""
Get the raw attribute dict for a given named model.
@@ -210,7 +210,7 @@ def raw(self, klass, _name='default', **attributes):
return raw
- def of(self, klass, name='default'):
+ def of(self, klass, name="default"):
"""
Create a builder for the given model.
@@ -222,9 +222,11 @@ def of(self, klass, name='default'):
:return: orator.orm.factory_builder.FactoryBuilder
"""
- return FactoryBuilder(klass, name, self._definitions, self._faker, self._resolver)
+ return FactoryBuilder(
+ klass, name, self._definitions, self._faker, self._resolver
+ )
- def build(self, klass, name='default', amount=None):
+ def build(self, klass, name="default", amount=None):
"""
Makes a factory builder with a specified amount.
@@ -242,7 +244,7 @@ def build(self, klass, name='default', amount=None):
if amount is None:
if isinstance(name, int):
amount = name
- name = 'default'
+ name = "default"
else:
amount = 1
@@ -287,7 +289,7 @@ def __setitem__(self, key, value):
def __contains__(self, item):
return item in self._definitions
- def __call__(self, klass, name='default', amount=None):
+ def __call__(self, klass, name="default", amount=None):
"""
Makes a factory builder with a specified amount.
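
Note on the renamed keyword arguments above: they all feed the same decorator-based
registration flow. A brief usage sketch (the User model, the faker fields and the
make() call on the returned FactoryBuilder are illustrative assumptions, not part of
this diff):

    factory = Factory()

    @factory.define(User)                  # registered under name="default"
    def users_factory(faker):
        return {"name": faker.name(), "email": faker.email()}

    users = factory(User, 3).make()        # the int second argument is treated as the
                                           # amount, mirroring build(); name stays "default"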
diff --git a/orator/orm/factory_builder.py b/orator/orm/factory_builder.py
index 93e285ec..960cd07a 100644
--- a/orator/orm/factory_builder.py
+++ b/orator/orm/factory_builder.py
@@ -4,7 +4,6 @@
class FactoryBuilder(object):
-
def __init__(self, klass, name, definitions, faker, resolver=None):
"""
:param klass: The class
diff --git a/orator/orm/mixins/soft_deletes.py b/orator/orm/mixins/soft_deletes.py
index 45be01aa..57949e15 100644
--- a/orator/orm/mixins/soft_deletes.py
+++ b/orator/orm/mixins/soft_deletes.py
@@ -35,7 +35,11 @@ def _do_perform_delete_on_model(self):
Perform the actual delete query on this model instance.
"""
if self.__force_deleting__:
- return self.with_trashed().where(self.get_key_name(), self.get_key()).force_delete()
+ return (
+ self.with_trashed()
+ .where(self.get_key_name(), self.get_key())
+ .force_delete()
+ )
return self._run_soft_delete()
@@ -48,15 +52,13 @@ def _run_soft_delete(self):
time = self.fresh_timestamp()
setattr(self, self.get_deleted_at_column(), time)
- query.update({
- self.get_deleted_at_column(): self.from_datetime(time)
- })
+ query.update({self.get_deleted_at_column(): self.from_datetime(time)})
def restore(self):
"""
Restore a soft-deleted model instance.
"""
- if self._fire_model_event('restoring') is False:
+ if self._fire_model_event("restoring") is False:
return False
setattr(self, self.get_deleted_at_column(), None)
@@ -65,7 +67,7 @@ def restore(self):
result = self.save()
- self._fire_model_event('restored')
+ self._fire_model_event("restored")
return result
@@ -99,7 +101,9 @@ def only_trashed(cls):
column = instance.get_qualified_deleted_at_column()
- return instance.new_query_without_scope(SoftDeletingScope()).where_not_null(column)
+ return instance.new_query_without_scope(SoftDeletingScope()).where_not_null(
+ column
+ )
@classmethod
def restoring(cls, callback):
@@ -108,7 +112,7 @@ def restoring(cls, callback):
:type callback: callable
"""
- cls._register_model_event('restoring', callback)
+ cls._register_model_event("restoring", callback)
@classmethod
def restored(cls, callback):
@@ -117,7 +121,7 @@ def restored(cls, callback):
:type callback: callable
"""
- cls._register_model_event('restored', callback)
+ cls._register_model_event("restored", callback)
def get_deleted_at_column(self):
"""
@@ -125,7 +129,7 @@ def get_deleted_at_column(self):
:rtype: str
"""
- return getattr(self, 'DELETED_AT', 'deleted_at')
+ return getattr(self, "DELETED_AT", "deleted_at")
def get_qualified_deleted_at_column(self):
"""
@@ -133,4 +137,4 @@ def get_qualified_deleted_at_column(self):
:rtype: str
"""
- return '%s.%s' % (self.get_table(), self.get_deleted_at_column())
+ return "%s.%s" % (self.get_table(), self.get_deleted_at_column())
diff --git a/orator/orm/model.py b/orator/orm/model.py
index 06b2fedb..f5e62887 100644
--- a/orator/orm/model.py
+++ b/orator/orm/model.py
@@ -5,6 +5,7 @@
import inflection
import inspect
import uuid
+import datetime
from warnings import warn
from six import add_metaclass
from collections import OrderedDict
@@ -14,8 +15,16 @@
from .builder import Builder
from .collection import Collection
from .relations import (
- Relation, HasOne, HasMany, BelongsTo, BelongsToMany, HasManyThrough,
- MorphOne, MorphMany, MorphTo, MorphToMany
+ Relation,
+ HasOne,
+ HasMany,
+ BelongsTo,
+ BelongsToMany,
+ HasManyThrough,
+ MorphOne,
+ MorphMany,
+ MorphTo,
+ MorphToMany,
)
from .relations.wrapper import Wrapper, BelongsToManyWrapper
from .utils import mutator, accessor
@@ -24,7 +33,6 @@
class ModelRegister(dict):
-
def __init__(self, *args, **kwargs):
self.inverse = {}
@@ -67,12 +75,12 @@ class Model(object):
__table__ = None
- __primary_key__ = 'id'
+ __primary_key__ = "id"
__incrementing__ = True
__fillable__ = []
- __guarded__ = ['*']
+ __guarded__ = ["*"]
__unguarded__ = False
__hidden__ = []
@@ -109,10 +117,10 @@ class Model(object):
__attributes__ = {}
- many_methods = ['belongs_to_many', 'morph_to_many', 'morphed_by_many']
+ many_methods = ["belongs_to_many", "morph_to_many", "morphed_by_many"]
- CREATED_AT = 'created_at'
- UPDATED_AT = 'updated_at'
+ CREATED_AT = "created_at"
+ UPDATED_AT = "updated_at"
def __init__(self, _attributes=None, **attributes):
"""
@@ -143,11 +151,11 @@ def _boot_if_not_booted(self):
if not klass._booted.get(klass):
klass._booted[klass] = True
- self._fire_model_event('booting')
+ self._fire_model_event("booting")
klass._boot()
- self._fire_model_event('booted')
+ self._fire_model_event("booted")
@classmethod
def _boot(cls):
@@ -168,7 +176,9 @@ def _boot(cls):
@classmethod
def _boot_columns(cls):
connection = cls.resolve_connection()
- columns = connection.get_schema_manager().list_table_columns(cls.__table__ or inflection.tableize(cls.__name__))
+ columns = connection.get_schema_manager().list_table_columns(
+ cls.__table__ or inflection.tableize(cls.__name__)
+ )
cls.__columns__ = list(columns.keys())
@classmethod
@@ -177,15 +187,15 @@ def _boot_mixins(cls):
Boot the mixins
"""
for mixin in cls.__bases__:
- #if mixin == Model:
+ # if mixin == Model:
# continue
- method = 'boot_%s' % inflection.underscore(mixin.__name__)
+ method = "boot_%s" % inflection.underscore(mixin.__name__)
if hasattr(mixin, method):
getattr(mixin, method)(cls)
@classmethod
- def add_global_scope(cls, scope, implementation = None):
+ def add_global_scope(cls, scope, implementation=None):
"""
Register a new global scope on the model.
@@ -205,7 +215,7 @@ def add_global_scope(cls, scope, implementation = None):
elif isinstance(scope, Scope):
cls._global_scopes[cls][scope.__class__] = scope
else:
- raise Exception('Global scope must be an instance of Scope or a callable')
+ raise Exception("Global scope must be an instance of Scope or a callable")
@classmethod
def has_global_scope(cls, scope):
@@ -565,7 +575,7 @@ def find(cls, id, columns=None):
return instance.new_collection()
if columns is None:
- columns = ['*']
+ columns = ["*"]
return instance.new_query().find(id, columns)
@@ -639,9 +649,9 @@ def with_(cls, *relations):
return instance.new_query().with_(*relations)
- def has_one(self, related,
- foreign_key=None, local_key=None, relation=None,
- _wrapped=True):
+ def has_one(
+ self, related, foreign_key=None, local_key=None, relation=None, _wrapped=True
+ ):
"""
Define a one to one relationship.
@@ -675,15 +685,19 @@ def has_one(self, related,
if not local_key:
local_key = self.get_key_name()
- rel = HasOne(instance.new_query(),
- self,
- '%s.%s' % (instance.get_table(), foreign_key),
- local_key)
+ rel = HasOne(
+ instance.new_query(),
+ self,
+ "%s.%s" % (instance.get_table(), foreign_key),
+ local_key,
+ )
if _wrapped:
- warn('Using has_one method directly is deprecated. '
- 'Use the appropriate decorator instead.',
- category=DeprecationWarning)
+ warn(
+ "Using has_one method directly is deprecated. "
+ "Use the appropriate decorator instead.",
+ category=DeprecationWarning,
+ )
rel = Wrapper(rel)
@@ -691,10 +705,16 @@ def has_one(self, related,
return rel
- def morph_one(self, related, name,
- type_column=None, id_column=None,
- local_key=None, relation=None,
- _wrapped=True):
+ def morph_one(
+ self,
+ related,
+ name,
+ type_column=None,
+ id_column=None,
+ local_key=None,
+ relation=None,
+ _wrapped=True,
+ ):
"""
Define a polymorphic one to one relationship.
@@ -730,14 +750,20 @@ def morph_one(self, related, name,
if not local_key:
local_key = self.get_key_name()
- rel = MorphOne(instance.new_query(), self,
- '%s.%s' % (table, type_column),
- '%s.%s' % (table, id_column), local_key)
+ rel = MorphOne(
+ instance.new_query(),
+ self,
+ "%s.%s" % (table, type_column),
+ "%s.%s" % (table, id_column),
+ local_key,
+ )
if _wrapped:
- warn('Using morph_one method directly is deprecated. '
- 'Use the appropriate decorator instead.',
- category=DeprecationWarning)
+ warn(
+ "Using morph_one method directly is deprecated. "
+ "Use the appropriate decorator instead.",
+ category=DeprecationWarning,
+ )
rel = Wrapper(rel)
@@ -745,9 +771,9 @@ def morph_one(self, related, name,
return rel
- def belongs_to(self, related,
- foreign_key=None, other_key=None, relation=None,
- _wrapped=True):
+ def belongs_to(
+ self, related, foreign_key=None, other_key=None, relation=None, _wrapped=True
+ ):
"""
Define an inverse one to one or many relationship.
@@ -771,7 +797,7 @@ def belongs_to(self, related,
return self._relations[relation]
if foreign_key is None:
- foreign_key = '%s_id' % inflection.underscore(relation)
+ foreign_key = "%s_id" % inflection.underscore(relation)
instance = self._get_related(related, True)
@@ -783,9 +809,11 @@ def belongs_to(self, related,
rel = BelongsTo(query, self, foreign_key, other_key, relation)
if _wrapped:
- warn('Using belongs_to method directly is deprecated. '
- 'Use the appropriate decorator instead.',
- category=DeprecationWarning)
+ warn(
+ "Using belongs_to method directly is deprecated. "
+ "Use the appropriate decorator instead.",
+ category=DeprecationWarning,
+ )
rel = Wrapper(rel)
@@ -793,8 +821,7 @@ def belongs_to(self, related,
return rel
- def morph_to(self, name=None, type_column=None, id_column=None,
- _wrapped=True):
+ def morph_to(self, name=None, type_column=None, id_column=None, _wrapped=True):
"""
Define a polymorphic, inverse one-to-one or many relationship.
@@ -831,14 +858,21 @@ def morph_to(self, name=None, type_column=None, id_column=None,
instance = klass()
instance.set_connection(self.get_connection_name())
- rel = MorphTo(instance.new_query(),
- self, id_column,
- instance.get_key_name(), type_column, name)
+ rel = MorphTo(
+ instance.new_query(),
+ self,
+ id_column,
+ instance.get_key_name(),
+ type_column,
+ name,
+ )
if _wrapped:
- warn('Using morph_to method directly is deprecated. '
- 'Use the appropriate decorator instead.',
- category=DeprecationWarning)
+ warn(
+ "Using morph_to method directly is deprecated. "
+ "Use the appropriate decorator instead.",
+ category=DeprecationWarning,
+ )
rel = Wrapper(rel)
@@ -858,8 +892,9 @@ def get_actual_class_for_morph(self, slug):
if morph_name == slug:
return cls
- def has_many(self, related, foreign_key=None, local_key=None,
- relation=None, _wrapped=True):
+ def has_many(
+ self, related, foreign_key=None, local_key=None, relation=None, _wrapped=True
+ ):
"""
Define a one to many relationship.
@@ -893,15 +928,19 @@ def has_many(self, related, foreign_key=None, local_key=None,
if not local_key:
local_key = self.get_key_name()
- rel = HasMany(instance.new_query(),
- self,
- '%s.%s' % (instance.get_table(), foreign_key),
- local_key)
+ rel = HasMany(
+ instance.new_query(),
+ self,
+ "%s.%s" % (instance.get_table(), foreign_key),
+ local_key,
+ )
if _wrapped:
- warn('Using has_many method directly is deprecated. '
- 'Use the appropriate decorator instead.',
- category=DeprecationWarning)
+ warn(
+ "Using has_many method directly is deprecated. "
+ "Use the appropriate decorator instead.",
+ category=DeprecationWarning,
+ )
rel = Wrapper(rel)
@@ -909,9 +948,15 @@ def has_many(self, related, foreign_key=None, local_key=None,
return rel
- def has_many_through(self, related, through,
- first_key=None, second_key=None, relation=None,
- _wrapped=True):
+ def has_many_through(
+ self,
+ related,
+ through,
+ first_key=None,
+ second_key=None,
+ relation=None,
+ _wrapped=True,
+ ):
"""
Define a has-many-through relationship.
@@ -948,13 +993,20 @@ def has_many_through(self, related, through,
if not second_key:
second_key = through.get_foreign_key()
- rel = HasManyThrough(self._get_related(related)().new_query(),
- self, through, first_key, second_key)
+ rel = HasManyThrough(
+ self._get_related(related)().new_query(),
+ self,
+ through,
+ first_key,
+ second_key,
+ )
if _wrapped:
- warn('Using has_many_through method directly is deprecated. '
- 'Use the appropriate decorator instead.',
- category=DeprecationWarning)
+ warn(
+ "Using has_many_through method directly is deprecated. "
+ "Use the appropriate decorator instead.",
+ category=DeprecationWarning,
+ )
rel = Wrapper(rel)
@@ -962,10 +1014,16 @@ def has_many_through(self, related, through,
return rel
- def morph_many(self, related, name,
- type_column=None, id_column=None,
- local_key=None, relation=None,
- _wrapped=True):
+ def morph_many(
+ self,
+ related,
+ name,
+ type_column=None,
+ id_column=None,
+ local_key=None,
+ relation=None,
+ _wrapped=True,
+ ):
"""
Define a polymorphic one to many relationship.
@@ -1001,14 +1059,20 @@ def morph_many(self, related, name,
if not local_key:
local_key = self.get_key_name()
- rel = MorphMany(instance.new_query(), self,
- '%s.%s' % (table, type_column),
- '%s.%s' % (table, id_column), local_key)
+ rel = MorphMany(
+ instance.new_query(),
+ self,
+ "%s.%s" % (table, type_column),
+ "%s.%s" % (table, id_column),
+ local_key,
+ )
if _wrapped:
- warn('Using morph_many method directly is deprecated. '
- 'Use the appropriate decorator instead.',
- category=DeprecationWarning)
+ warn(
+ "Using morph_many method directly is deprecated. "
+ "Use the appropriate decorator instead.",
+ category=DeprecationWarning,
+ )
rel = Wrapper(rel)
@@ -1016,9 +1080,15 @@ def morph_many(self, related, name,
return rel
- def belongs_to_many(self, related, table=None,
- foreign_key=None, other_key=None,
- relation=None, _wrapped=True):
+ def belongs_to_many(
+ self,
+ related,
+ table=None,
+ foreign_key=None,
+ other_key=None,
+ relation=None,
+ _wrapped=True,
+ ):
"""
Define a many-to-many relationship.
@@ -1060,9 +1130,11 @@ def belongs_to_many(self, related, table=None,
rel = BelongsToMany(query, self, table, foreign_key, other_key, relation)
if _wrapped:
- warn('Using belongs_to_many method directly is deprecated. '
- 'Use the appropriate decorator instead.',
- category=DeprecationWarning)
+ warn(
+ "Using belongs_to_many method directly is deprecated. "
+ "Use the appropriate decorator instead.",
+ category=DeprecationWarning,
+ )
rel = BelongsToManyWrapper(rel)
@@ -1070,10 +1142,17 @@ def belongs_to_many(self, related, table=None,
return rel
- def morph_to_many(self, related, name, table=None,
- foreign_key=None, other_key=None,
- inverse=False, relation=None,
- _wrapped=True):
+ def morph_to_many(
+ self,
+ related,
+ name,
+ table=None,
+ foreign_key=None,
+ other_key=None,
+ inverse=False,
+ relation=None,
+ _wrapped=True,
+ ):
"""
Define a polymorphic many-to-many relationship.
@@ -1106,7 +1185,7 @@ def morph_to_many(self, related, name, table=None,
return self._relations[caller]
if not foreign_key:
- foreign_key = name + '_id'
+ foreign_key = name + "_id"
instance = self._get_related(related, True)
@@ -1118,13 +1197,16 @@ def morph_to_many(self, related, name, table=None,
if not table:
table = inflection.pluralize(name)
- rel = MorphToMany(query, self, name, table,
- foreign_key, other_key, caller, inverse)
+ rel = MorphToMany(
+ query, self, name, table, foreign_key, other_key, caller, inverse
+ )
if _wrapped:
- warn('Using morph_to_many method directly is deprecated. '
- 'Use the appropriate decorator instead.',
- category=DeprecationWarning)
+ warn(
+ "Using morph_to_many method directly is deprecated. "
+ "Use the appropriate decorator instead.",
+ category=DeprecationWarning,
+ )
rel = Wrapper(rel)
@@ -1132,7 +1214,16 @@ def morph_to_many(self, related, name, table=None,
return rel
- def morphed_by_many(self, related, name, table=None, foreign_key=None, other_key=None, relation=None):
+ def morphed_by_many(
+ self,
+ related,
+ name,
+ table=None,
+ foreign_key=None,
+ other_key=None,
+ relation=None,
+ _wrapped=False,
+ ):
"""
Define a polymorphic many-to-many relationship.
@@ -1160,9 +1251,11 @@ def morphed_by_many(self, related, name, table=None, foreign_key=None, other_key
foreign_key = self.get_foreign_key()
if not other_key:
- other_key = name + '_id'
+ other_key = name + "_id"
- return self.morph_to_many(related, name, table, foreign_key, other_key, True, relation)
+ return self.morph_to_many(
+ related, name, table, foreign_key, other_key, True, relation, _wrapped
+ )
def _get_related(self, related, as_instance=False):
"""
@@ -1210,7 +1303,7 @@ def joining_table(self, related):
models = sorted([related, base])
- return '_'.join(models)
+ return "_".join(models)
@classmethod
def destroy(cls, *ids):
@@ -1249,10 +1342,10 @@ def delete(self):
:raises: Exception
"""
if self.__primary_key__ is None:
- raise Exception('No primary key defined on the model.')
+ raise Exception("No primary key defined on the model.")
if self._exists:
- if self._fire_model_event('deleting') is False:
+ if self._fire_model_event("deleting") is False:
return False
self.touch_owners()
@@ -1261,7 +1354,7 @@ def delete(self):
self._exists = False
- self._fire_model_event('deleted')
+ self._fire_model_event("deleted")
return True
@@ -1275,7 +1368,7 @@ def _perform_delete_on_model(self):
"""
Perform the actual delete query on this model instance.
"""
- if hasattr(self, '_do_perform_delete_on_model'):
+ if hasattr(self, "_do_perform_delete_on_model"):
return self._do_perform_delete_on_model()
return self.new_query().where(self.get_key_name(), self.get_key()).delete()
@@ -1287,7 +1380,7 @@ def saving(cls, callback):
:type callback: callable
"""
- cls._register_model_event('saving', callback)
+ cls._register_model_event("saving", callback)
@classmethod
def saved(cls, callback):
@@ -1296,7 +1389,7 @@ def saved(cls, callback):
:type callback: callable
"""
- cls._register_model_event('saved', callback)
+ cls._register_model_event("saved", callback)
@classmethod
def updating(cls, callback):
@@ -1305,7 +1398,7 @@ def updating(cls, callback):
:type callback: callable
"""
- cls._register_model_event('updating', callback)
+ cls._register_model_event("updating", callback)
@classmethod
def updated(cls, callback):
@@ -1314,7 +1407,7 @@ def updated(cls, callback):
:type callback: callable
"""
- cls._register_model_event('updated', callback)
+ cls._register_model_event("updated", callback)
@classmethod
def creating(cls, callback):
@@ -1323,7 +1416,7 @@ def creating(cls, callback):
:type callback: callable
"""
- cls._register_model_event('creating', callback)
+ cls._register_model_event("creating", callback)
@classmethod
def created(cls, callback):
@@ -1332,7 +1425,7 @@ def created(cls, callback):
:type callback: callable
"""
- cls._register_model_event('created', callback)
+ cls._register_model_event("created", callback)
@classmethod
def deleting(cls, callback):
@@ -1341,7 +1434,7 @@ def deleting(cls, callback):
:type callback: callable
"""
- cls._register_model_event('deleting', callback)
+ cls._register_model_event("deleting", callback)
@classmethod
def deleted(cls, callback):
@@ -1350,7 +1443,7 @@ def deleted(cls, callback):
:type callback: callable
"""
- cls._register_model_event('deleted', callback)
+ cls._register_model_event("deleted", callback)
@classmethod
def flush_event_listeners(cls):
@@ -1361,7 +1454,7 @@ def flush_event_listeners(cls):
return
for event in cls.get_observable_events():
- cls.__dispatcher__.forget('%s: %s' % (event, cls.__name__))
+ cls.__dispatcher__.forget("%s: %s" % (event, cls.__name__))
@classmethod
def _register_model_event(cls, event, callback):
@@ -1375,7 +1468,7 @@ def _register_model_event(cls, event, callback):
:type callback: callable
"""
if cls.__dispatcher__:
- cls.__dispatcher__.listen('%s: %s' % (event, cls.__name__), callback)
+ cls.__dispatcher__.listen("%s: %s" % (event, cls.__name__), callback)
@classmethod
def get_observable_events(cls):
@@ -1385,9 +1478,16 @@ def get_observable_events(cls):
:rtype: list
"""
default_events = [
- 'creating', 'created', 'updating', 'updated',
- 'deleting', 'deleted', 'saving', 'saved',
- 'restoring', 'restored'
+ "creating",
+ "created",
+ "updating",
+ "updated",
+ "deleting",
+ "deleted",
+ "saving",
+ "saved",
+ "restoring",
+ "restored",
]
return default_events + cls.__observables__
@@ -1405,7 +1505,7 @@ def _increment(self, column, amount=1):
:return: The new column value
:rtype: int
"""
- return self._increment_or_decrement(column, amount, 'increment')
+ return self._increment_or_decrement(column, amount, "increment")
def _decrement(self, column, amount=1):
"""
@@ -1420,7 +1520,7 @@ def _decrement(self, column, amount=1):
:return: The new column value
:rtype: int
"""
- return self._increment_or_decrement(column, amount, 'decrement')
+ return self._increment_or_decrement(column, amount, "decrement")
def _increment_or_decrement(self, column, amount, method):
"""
@@ -1464,7 +1564,11 @@ def _increment_or_decrement_attribute_value(self, column, amount, method):
:return: None
"""
- setattr(self, column, getattr(self, column) + (amount if method == 'increment' else amount * -1))
+ setattr(
+ self,
+ column,
+ getattr(self, column) + (amount if method == "increment" else amount * -1),
+ )
self.sync_original_attribute(column)
@@ -1517,7 +1621,7 @@ def save(self, options=None):
query = self.new_query()
- if self._fire_model_event('saving') is False:
+ if self._fire_model_event("saving") is False:
return False
if self._exists:
@@ -1534,11 +1638,11 @@ def _finish_save(self, options):
"""
Finish processing on a successful save operation.
"""
- self._fire_model_event('saved')
+ self._fire_model_event("saved")
self.sync_original()
- if options.get('touch', True):
+ if options.get("touch", True):
self.touch_owners()
def _perform_update(self, query, options=None):
@@ -1557,10 +1661,10 @@ def _perform_update(self, query, options=None):
dirty = self.get_dirty()
if len(dirty):
- if self._fire_model_event('updating') is False:
+ if self._fire_model_event("updating") is False:
return False
- if self.__timestamps__ and options.get('timestamps', True):
+ if self.__timestamps__ and options.get("timestamps", True):
self._update_timestamps()
dirty = self.get_dirty()
@@ -1568,7 +1672,7 @@ def _perform_update(self, query, options=None):
if len(dirty):
self._set_keys_for_save_query(query).update(dirty)
- self._fire_model_event('updated')
+ self._fire_model_event("updated")
return True
@@ -1585,10 +1689,10 @@ def _perform_insert(self, query, options=None):
if options is None:
options = {}
- if self._fire_model_event('creating') is False:
+ if self._fire_model_event("creating") is False:
return False
- if self.__timestamps__ and options.get('timestamps', True):
+ if self.__timestamps__ and options.get("timestamps", True):
self._update_timestamps()
attributes = self._attributes
@@ -1600,7 +1704,7 @@ def _perform_insert(self, query, options=None):
self._exists = True
- self._fire_model_event('created')
+ self._fire_model_event("created")
return True
@@ -1627,9 +1731,9 @@ def touch_owners(self):
for relation in self.__touches__:
if hasattr(self, relation):
_relation = getattr(self, relation)
- _relation().touch()
- if _relation is not None:
+ if _relation:
+ _relation.touch()
_relation.touch_owners()
def touches(self, relation):
@@ -1655,7 +1759,7 @@ def _fire_model_event(self, event):
# We will append the names of the class to the event to distinguish it from
# other model events that are fired, allowing us to listen on each model
# event set individually instead of catching event for all the models.
- event = '%s: %s' % (event, self.__class__.__name__)
+ event = "%s: %s" % (event, self.__class__.__name__)
return self.__dispatcher__.fire(event, self)
@@ -1701,10 +1805,16 @@ def _update_timestamps(self):
"""
time = self.fresh_timestamp()
- if not self.is_dirty(self.UPDATED_AT) and self._should_set_timestamp(self.UPDATED_AT):
+ if not self.is_dirty(self.UPDATED_AT) and self._should_set_timestamp(
+ self.UPDATED_AT
+ ):
self.set_updated_at(time)
- if not self._exists and not self.is_dirty(self.CREATED_AT) and self._should_set_timestamp(self.CREATED_AT):
+ if (
+ not self._exists
+ and not self.is_dirty(self.CREATED_AT)
+ and self._should_set_timestamp(self.CREATED_AT)
+ ):
self.set_created_at(time)
def _should_set_timestamp(self, timestamp):
@@ -1763,6 +1873,14 @@ def fresh_timestamp(self):
"""
return pendulum.utcnow()
+ def fresh_timestamp_string(self):
+ """
+ Get a fresh timestamp string for the model.
+
+ :return: str
+ """
+ return self.from_datetime(self.fresh_timestamp())
+
def new_query(self):
"""
Get a new query builder for the model's table
@@ -1795,9 +1913,7 @@ def new_query_without_scopes(self):
:return: A Builder instance
:rtype: Builder
"""
- builder = self.new_orm_builder(
- self._new_base_query_builder()
- )
+ builder = self.new_orm_builder(self._new_base_query_builder())
return builder.set_model(self).with_(*self._with)
@@ -1919,7 +2035,7 @@ def get_qualified_key_name(self):
:rtype: str
"""
- return '%s.%s' % (self.get_table(), self.get_key_name())
+ return "%s.%s" % (self.get_table(), self.get_key_name())
def uses_timestamps(self):
"""
@@ -1934,10 +2050,10 @@ def get_morphs(self, name, type, id):
Get the polymorphic relationship columns.
"""
if not type:
- type = name + '_type'
+ type = name + "_type"
if not id:
- id = name + '_id'
+ id = name + "_id"
return type, id
@@ -1965,7 +2081,7 @@ def get_foreign_key(self):
:rtype: str
"""
- return '%s_id' % inflection.singularize(self.get_table())
+ return "%s_id" % inflection.singularize(self.get_table())
def get_hidden(self):
"""
@@ -2095,7 +2211,7 @@ def is_fillable(self, key):
if self.is_guarded(key):
return False
- return not self.__fillable__ and not key.startswith('_')
+ return not self.__fillable__ and not key.startswith("_")
def is_guarded(self, key):
"""
@@ -2107,7 +2223,7 @@ def is_guarded(self, key):
:return: Whether the attribute is guarded or not
:rtype: bool
"""
- return key in self.__guarded__ or self.__guarded__ == ['*']
+ return key in self.__guarded__ or self.__guarded__ == ["*"]
def totally_guarded(self):
"""
@@ -2115,7 +2231,7 @@ def totally_guarded(self):
:rtype: bool
"""
- return len(self.__fillable__) == 0 and self.__guarded__ == ['*']
+ return len(self.__fillable__) == 0 and self.__guarded__ == ["*"]
def _remove_table_from_key(self, key):
"""
@@ -2126,10 +2242,10 @@ def _remove_table_from_key(self, key):
:rtype: str
"""
- if '.' not in key:
+ if "." not in key:
return key
- return key.split('.')[-1]
+ return key.split(".")[-1]
def get_incrementing(self):
return self.__incrementing__
@@ -2242,9 +2358,9 @@ def relations_to_dict(self):
continue
relation = None
- if hasattr(value, 'serialize'):
+ if hasattr(value, "serialize"):
relation = value.serialize()
- elif hasattr(value, 'to_dict'):
+ elif hasattr(value, "to_dict"):
relation = value.to_dict()
elif value is None:
relation = value
@@ -2272,7 +2388,11 @@ def _get_dictable_items(self, values):
if len(self.__visible__) > 0:
return {x: values[x] for x in values.keys() if x in self.__visible__}
- return {x: values[x] for x in values.keys() if x not in self.__hidden__ and not x.startswith('_')}
+ return {
+ x: values[x]
+ for x in values.keys()
+ if x not in self.__hidden__ and not x.startswith("_")
+ }
def get_attribute(self, key, original=None):
"""
@@ -2337,7 +2457,9 @@ def _get_relationship_from_method(self, method, relations=None):
relations = relations or super(Model, self).__getattribute__(method)
if not isinstance(relations, Relation):
- raise RuntimeError('Relationship method must return an object of type Relation')
+ raise RuntimeError(
+ "Relationship method must return an object of type Relation"
+ )
self._relations[method] = relations
@@ -2352,7 +2474,7 @@ def has_get_mutator(self, key):
:rtype: bool
"""
- return hasattr(self, 'get_%s_attribute' % inflection.underscore(key))
+ return hasattr(self, "get_%s_attribute" % inflection.underscore(key))
def _mutate_attribute_for_dict(self, key):
"""
@@ -2363,7 +2485,7 @@ def _mutate_attribute_for_dict(self, key):
"""
value = getattr(self, key)
- if hasattr(value, 'to_dict'):
+ if hasattr(value, "to_dict"):
return value.to_dict()
if key in self.get_dates():
@@ -2409,7 +2531,7 @@ def _is_json_castable(self, key):
if self._has_cast(key):
type = self._get_cast_type(key)
- return type in ['list', 'dict', 'json', 'object']
+ return type in ["list", "dict", "json", "object"]
return False
@@ -2440,15 +2562,15 @@ def _cast_attribute(self, key, value):
return None
type = self._get_cast_type(key)
- if type in ['int', 'integer']:
+ if type in ["int", "integer"]:
return int(value)
- elif type in ['real', 'float', 'double']:
+ elif type in ["real", "float", "double"]:
return float(value)
- elif type in ['string', 'str']:
+ elif type in ["string", "str"]:
return str(value)
- elif type in ['bool', 'boolean']:
+ elif type in ["bool", "boolean"]:
return bool(value)
- elif type in ['dict', 'list', 'json'] and isinstance(value, basestring):
+ elif type in ["dict", "list", "json"] and isinstance(value, basestring):
return json.loads(value)
else:
return value
@@ -2465,20 +2587,32 @@ def get_dates(self):
def from_datetime(self, value):
"""
- Convert datetime to a datetime object
+ Convert datetime to a storable string.
+
+ :param value: The datetime value
+ :type value: pendulum.Pendulum or datetime.date or datetime.datetime
- :rtype: datetime.datetime
+ :rtype: str
"""
+ date_format = self.get_connection().get_query_grammar().get_date_format()
+
if isinstance(value, pendulum.Pendulum):
- return value
+ return value.format(date_format)
- return pendulum.instance(value)
+ if isinstance(value, datetime.date) and not isinstance(
+ value, (datetime.datetime)
+ ):
+ value = pendulum.date.instance(value)
+
+ return value.format(date_format)
+
+ return pendulum.instance(value).format(date_format)
def as_datetime(self, value):
"""
Return a timestamp as a datetime.
- :rtype: pendulum.Pendulum
+ :rtype: pendulum.Pendulum or pendulum.Date
"""
if isinstance(value, basestring):
return pendulum.parse(value)
@@ -2486,6 +2620,11 @@ def as_datetime(self, value):
if isinstance(value, (int, float)):
return pendulum.from_timestamp(value)
+ if isinstance(value, datetime.date) and not isinstance(
+ value, (datetime.datetime)
+ ):
+ return pendulum.date.instance(value)
+
return pendulum.instance(value)
def get_date_format(self):
@@ -2494,7 +2633,7 @@ def get_date_format(self):
:rtype: str
"""
- return 'iso'
+ return "iso"
def _format_date(self, date):
"""
@@ -2510,7 +2649,7 @@ def _format_date(self, date):
format = self.get_date_format()
- if format == 'iso':
+ if format == "iso":
if isinstance(date, basestring):
return pendulum.parse(date).isoformat()
@@ -2549,10 +2688,12 @@ def replicate(self, except_=None):
except_ = [
self.get_key_name(),
self.get_created_at_column(),
- self.get_updated_at_column()
+ self.get_updated_at_column(),
]
- attributes = {x: self._attributes[x] for x in self._attributes if x not in except_}
+ attributes = {
+ x: self._attributes[x] for x in self._attributes if x not in except_
+ }
instance = self.new_instance(attributes)
@@ -2816,7 +2957,12 @@ def __getattr__(self, item):
return self.get_attribute(item)
def __setattr__(self, key, value):
- if key in ['_attributes', '_exists', '_relations', '_original'] or key.startswith('__'):
+ if key in [
+ "_attributes",
+ "_exists",
+ "_relations",
+ "_original",
+ ] or key.startswith("__"):
return object.__setattr__(self, key, value)
if self._has_set_mutator(key):
@@ -2841,14 +2987,14 @@ def __delattr__(self, item):
def __getstate__(self):
return {
- 'attributes': self._attributes,
- 'relations': self._relations,
- 'exists': self._exists
+ "attributes": self._attributes,
+ "relations": self._relations,
+ "exists": self._exists,
}
def __setstate__(self, state):
self._boot_if_not_booted()
- self.set_raw_attributes(state['attributes'], True)
- self.set_relations(state['relations'])
- self.set_exists(state['exists'])
+ self.set_raw_attributes(state["attributes"], True)
+ self.set_relations(state["relations"])
+ self.set_exists(state["exists"])
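For illustration, the reworked from_datetime() above now returns a string rendered with the connection grammar's date format rather than a datetime object. A minimal standalone sketch of the same branching, assuming pendulum 1.x (the pendulum.Pendulum and pendulum.date.instance calls used in the patch) and a hypothetical fixed format string standing in for get_query_grammar().get_date_format():

    import datetime

    import pendulum

    DATE_FORMAT = "%Y-%m-%d %H:%M:%S"  # hypothetical stand-in for the grammar's date format

    def to_storable_string(value):
        # Pendulum instances are formatted directly.
        if isinstance(value, pendulum.Pendulum):
            return value.format(DATE_FORMAT)

        # Plain dates (but not datetimes) are promoted to pendulum dates first.
        if isinstance(value, datetime.date) and not isinstance(value, datetime.datetime):
            return pendulum.date.instance(value).format(DATE_FORMAT)

        # Anything else (e.g. datetime.datetime) goes through pendulum.instance().
        return pendulum.instance(value).format(DATE_FORMAT)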
diff --git a/orator/orm/relations/belongs_to.py b/orator/orm/relations/belongs_to.py
index 12b7d8de..062e0c5f 100644
--- a/orator/orm/relations/belongs_to.py
+++ b/orator/orm/relations/belongs_to.py
@@ -6,7 +6,6 @@
class BelongsTo(Relation):
-
def __init__(self, query, parent, foreign_key, other_key, relation):
"""
:param query: A Builder instance
@@ -34,6 +33,9 @@ def get_results(self):
"""
Get the results of the relationship.
"""
+ if self._query is None:
+ return None
+
return self._query.first()
def add_constraints(self):
@@ -43,9 +45,15 @@ def add_constraints(self):
:rtype: None
"""
if self._constraints:
- table = self._related.get_table()
+ foreign_key = getattr(self._parent, self._foreign_key, None)
+ if foreign_key is None:
+ self._query = None
+ else:
+ table = self._related.get_table()
- self._query.where('%s.%s' % (table, self._other_key), '=', getattr(self._parent, self._foreign_key))
+ self._query.where(
+ "{}.{}".format(table, self._other_key), "=", foreign_key
+ )
def get_relation_count_query(self, query, parent):
"""
@@ -56,11 +64,15 @@ def get_relation_count_query(self, query, parent):
:rtype: Builder
"""
- query.select(QueryExpression('COUNT(*)'))
+ query.select(QueryExpression("COUNT(*)"))
- other_key = self.wrap('%s.%s' % (query.get_model().get_table(), self._other_key))
+ other_key = self.wrap(
+ "%s.%s" % (query.get_model().get_table(), self._other_key)
+ )
- return query.where(self.get_qualified_foreign_key(), '=', QueryExpression(other_key))
+ return query.where(
+ self.get_qualified_foreign_key(), "=", QueryExpression(other_key)
+ )
def add_eager_constraints(self, models):
"""
@@ -68,7 +80,7 @@ def add_eager_constraints(self, models):
:type models: list
"""
- key = '%s.%s' % (self._related.get_table(), self._other_key)
+ key = "%s.%s" % (self._related.get_table(), self._other_key)
self._query.where_in(key, self._get_eager_model_keys(models))
@@ -142,9 +154,13 @@ def associate(self, model):
:rtype: orator.Model
"""
- self._parent.set_attribute(self._foreign_key, model.get_attribute(self._other_key))
+ self._parent.set_attribute(
+ self._foreign_key, model.get_attribute(self._other_key)
+ )
- return self._parent.set_relation(self._relation, Result(model, self, self._parent))
+ return self._parent.set_relation(
+ self._relation, Result(model, self, self._parent)
+ )
def dissociate(self):
"""
@@ -154,7 +170,9 @@ def dissociate(self):
"""
self._parent.set_attribute(self._foreign_key, None)
- return self._parent.set_relation(self._relation, Result(None, self, self._parent))
+ return self._parent.set_relation(
+ self._relation, Result(None, self, self._parent)
+ )
def update(self, _attributes=None, **attributes):
"""
@@ -176,19 +194,15 @@ def get_foreign_key(self):
return self._foreign_key
def get_qualified_foreign_key(self):
- return '%s.%s' % (self._parent.get_table(), self._foreign_key)
+ return "%s.%s" % (self._parent.get_table(), self._foreign_key)
def get_other_key(self):
return self._other_key
def get_qualified_other_key_name(self):
- return '%s.%s' % (self._related.get_table(), self._other_key)
+ return "%s.%s" % (self._related.get_table(), self._other_key)
def _new_instance(self, model):
return BelongsTo(
- self.new_query(),
- model,
- self._foreign_key,
- self._other_key,
- self._relation
+ self.new_query(), model, self._foreign_key, self._other_key, self._relation
)
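For illustration, the new guard in add_constraints() short-circuits a belongs-to relation whose parent has no foreign key value: _query is set to None and get_results() returns None instead of issuing a WHERE ... = NULL query. A hedged usage sketch with hypothetical Comment/Post models:

    orphan = Comment()           # hypothetical model; Comment belongs_to Post via post_id
    orphan.post                  # post_id is unset, so the relation resolves to None without querying

    attached = Comment(post_id=42)
    attached.post                # constrained with "posts"."id" = 42 and fetched via first()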
diff --git a/orator/orm/relations/belongs_to_many.py b/orator/orm/relations/belongs_to_many.py
index b2273587..8918769b 100644
--- a/orator/orm/relations/belongs_to_many.py
+++ b/orator/orm/relations/belongs_to_many.py
@@ -21,7 +21,9 @@ class BelongsToMany(Relation):
_pivot_columns = []
_pivot_wheres = []
- def __init__(self, query, parent, table, foreign_key, other_key, relation_name=None):
+ def __init__(
+ self, query, parent, table, foreign_key, other_key, relation_name=None
+ ):
"""
:param query: A Builder instance
:type query: Builder
@@ -57,7 +59,7 @@ def get_results(self):
"""
return self.get()
- def where_pivot(self, column, operator=None, value=None, boolean='and'):
+ def where_pivot(self, column, operator=None, value=None, boolean="and"):
"""
Set a where clause for a pivot table column.
@@ -78,7 +80,9 @@ def where_pivot(self, column, operator=None, value=None, boolean='and'):
"""
self._pivot_wheres.append([column, operator, value, boolean])
- return self._query.where('%s.%s' % (self._table, column), operator, value, boolean)
+ return self._query.where(
+ "%s.%s" % (self._table, column), operator, value, boolean
+ )
def or_where_pivot(self, column, operator=None, value=None):
"""
@@ -96,7 +100,7 @@ def or_where_pivot(self, column, operator=None, value=None):
:return: self
:rtype: BelongsToMany
"""
- return self.where_pivot(column, operator, value, 'or')
+ return self.where_pivot(column, operator, value, "or")
def first(self, columns=None):
"""
@@ -136,7 +140,7 @@ def get(self, columns=None):
:rtype: orator.Collection
"""
if columns is None:
- columns = ['*']
+ columns = ["*"]
if self._query.get_query().columns:
columns = []
@@ -161,7 +165,7 @@ def _hydrate_pivot_relation(self, models):
for model in models:
pivot = self.new_existing_pivot(self._clean_pivot_attributes(model))
- model.set_relation('pivot', pivot)
+ model.set_relation("pivot", pivot)
def _clean_pivot_attributes(self, model):
"""
@@ -173,7 +177,7 @@ def _clean_pivot_attributes(self, model):
delete_keys = []
for key, value in model.get_attributes().items():
- if key.find('pivot_') == 0:
+ if key.find("pivot_") == 0:
values[key[6:]] = value
delete_keys.append(key)
@@ -219,16 +223,18 @@ def get_relation_count_query_for_self_join(self, query, parent):
:rtype: orator.orm.Builder
"""
- query.select(QueryExpression('COUNT(*)'))
+ query.select(QueryExpression("COUNT(*)"))
table_prefix = self._query.get_query().get_connection().get_table_prefix()
hash_ = self.get_relation_count_hash()
- query.from_('%s AS %s%s' % (self._table, table_prefix, hash_))
+ query.from_("%s AS %s%s" % (self._table, table_prefix, hash_))
key = self.wrap(self.get_qualified_parent_key_name())
- return query.where('%s.%s' % (hash_, self._foreign_key), '=', QueryExpression(key))
+ return query.where(
+ "%s.%s" % (hash_, self._foreign_key), "=", QueryExpression(key)
+ )
def get_relation_count_hash(self):
"""
@@ -236,7 +242,7 @@ def get_relation_count_hash(self):
:rtype: str
"""
- return 'self_%s' % (hashlib.md5(str(time.time()).encode()).hexdigest())
+ return "self_%s" % (hashlib.md5(str(time.time()).encode()).hexdigest())
def _get_select_columns(self, columns=None):
"""
@@ -247,8 +253,8 @@ def _get_select_columns(self, columns=None):
:rtype: list
"""
- if columns == ['*'] or columns is None:
- columns = ['%s.*' % self._related.get_table()]
+ if columns == ["*"] or columns is None:
+ columns = ["%s.*" % self._related.get_table()]
return columns + self._get_aliased_pivot_columns()
@@ -263,9 +269,9 @@ def _get_aliased_pivot_columns(self):
columns = []
for column in defaults + self._pivot_columns:
- value = '%s.%s AS pivot_%s' % (self._table, column, column)
+ value = "%s.%s AS pivot_%s" % (self._table, column, column)
if value not in columns:
- columns.append('%s.%s AS pivot_%s' % (self._table, column, column))
+ columns.append("%s.%s AS pivot_%s" % (self._table, column, column))
return columns
@@ -295,9 +301,9 @@ def _set_join(self, query=None):
base_table = self._related.get_table()
- key = '%s.%s' % (base_table, self._related.get_key_name())
+ key = "%s.%s" % (base_table, self._related.get_key_name())
- query.join(self._table, key, '=', self.get_other_key())
+ query.join(self._table, key, "=", self.get_other_key())
return self
@@ -310,7 +316,7 @@ def _set_where(self):
"""
foreign = self.get_foreign_key()
- self._query.where(foreign, '=', self._parent.get_key())
+ self._query.where(foreign, "=", self._parent.get_key())
return self
@@ -330,7 +336,9 @@ def init_relation(self, models, relation):
:type relation: str
"""
for model in models:
- model.set_relation(relation, Result(self._related.new_collection(), self, model))
+ model.set_relation(
+ relation, Result(self._related.new_collection(), self, model)
+ )
return models
@@ -348,7 +356,9 @@ def match(self, models, results, relation):
key = model.get_key()
if key in dictionary:
- collection = Result(self._related.new_collection(dictionary[key]), self, model)
+ collection = Result(
+ self._related.new_collection(dictionary[key]), self, model
+ )
else:
collection = Result(self._related.new_collection(), self, model)
@@ -416,7 +426,7 @@ def save(self, model, joining=None, touch=True):
if joining is None:
joining = {}
- model.save({'touch': False})
+ model.save({"touch": False})
self.attach(model.get_key(), joining, touch)
@@ -477,7 +487,9 @@ def first_or_new(self, _attributes=None, **attributes):
return instance
- def first_or_create(self, _attributes=None, _joining=None, _touch=True, **attributes):
+ def first_or_create(
+ self, _attributes=None, _joining=None, _touch=True, **attributes
+ ):
"""
Get the first related model record matching the attributes or create it.
@@ -517,7 +529,7 @@ def update_or_create(self, attributes, values=None, joining=None, touch=True):
instance.fill(**values)
- instance.save({'touch': False})
+ instance.save({"touch": False})
return instance
@@ -535,7 +547,7 @@ def create(self, _attributes=None, _joining=None, _touch=True, **attributes):
instance = self._related.new_instance(attributes)
- instance.save({'touch': False})
+ instance.save({"touch": False})
self.attach(instance.get_key(), _joining, _touch)
@@ -561,11 +573,7 @@ def sync(self, ids, detaching=True):
"""
Sync the intermediate tables with a list of IDs or collection of models
"""
- changes = {
- 'attached': [],
- 'detached': [],
- 'updated': []
- }
+ changes = {"attached": [], "detached": [], "updated": []}
if isinstance(ids, Collection):
ids = ids.model_keys()
@@ -579,11 +587,11 @@ def sync(self, ids, detaching=True):
if detaching and len(detach) > 0:
self.detach(detach)
- changes['detached'] = detach
+ changes["detached"] = detach
changes.update(self._attach_new(records, current, False))
- if len(changes['attached']) or len(changes['updated']):
+ if len(changes["attached"]) or len(changes["updated"]):
self.touch_if_touching()
return changes
@@ -609,18 +617,17 @@ def _attach_new(self, records, current, touch=True):
"""
Attach all of the IDs that aren't in the current dict.
"""
- changes = {
- 'attached': [],
- 'updated': []
- }
+ changes = {"attached": [], "updated": []}
for id, attributes in records.items():
if id not in current:
self.attach(id, attributes, touch)
- changes['attached'].append(id)
- elif len(attributes) > 0 and self.update_existing_pivot(id, attributes, touch):
- changes['updated'].append(id)
+ changes["attached"].append(id)
+ elif len(attributes) > 0 and self.update_existing_pivot(
+ id, attributes, touch
+ ):
+ changes["updated"].append(id)
return changes
@@ -661,8 +668,9 @@ def _create_attach_records(self, ids, attributes):
"""
records = []
- timed = (self._has_pivot_column(self.created_at())
- or self._has_pivot_column(self.updated_at()))
+ timed = self._has_pivot_column(self.created_at()) or self._has_pivot_column(
+ self.updated_at()
+ )
for key, value in enumerate(ids):
records.append(self._attacher(key, value, attributes, timed))
@@ -765,7 +773,9 @@ def _touching_parent(self):
return self.get_related().touches(self._guess_inverse_relation())
def _guess_inverse_relation(self):
- return inflection.camelize(inflection.pluralize(self.get_parent().__class__.__name__))
+ return inflection.camelize(
+ inflection.pluralize(self.get_parent().__class__.__name__)
+ )
def _new_pivot_query(self):
"""
@@ -841,10 +851,10 @@ def get_has_compare_key(self):
return self.get_foreign_key()
def get_foreign_key(self):
- return '%s.%s' % (self._table, self._foreign_key)
+ return "%s.%s" % (self._table, self._foreign_key)
def get_other_key(self):
- return '%s.%s' % (self._table, self._other_key)
+ return "%s.%s" % (self._table, self._other_key)
def get_table(self):
return self._table
@@ -859,7 +869,7 @@ def _new_instance(self, model):
self._table,
self._foreign_key,
self._other_key,
- self._relation_name
+ self._relation_name,
)
relation.with_pivot(*self._pivot_columns)
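For illustration, sync() keeps reporting its work through the same three keys; only the dict literal style changed above. A hedged sketch with hypothetical user/roles models:

    changes = user.roles().sync([1, 2, 3])
    # changes has the shape {"attached": [...], "detached": [...], "updated": [...]}

    if changes["attached"] or changes["updated"]:
        # the same condition sync() uses before calling touch_if_touching()
        print("pivot rows were written")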
diff --git a/orator/orm/relations/has_many.py b/orator/orm/relations/has_many.py
index 63e3b65b..dd9e19cd 100644
--- a/orator/orm/relations/has_many.py
+++ b/orator/orm/relations/has_many.py
@@ -5,7 +5,6 @@
class HasMany(HasOneOrMany):
-
def get_results(self):
"""
Get the results of the relationship.
@@ -20,7 +19,9 @@ def init_relation(self, models, relation):
:type relation: str
"""
for model in models:
- model.set_relation(relation, Result(self._related.new_collection(), self, model))
+ model.set_relation(
+ relation, Result(self._related.new_collection(), self, model)
+ )
return models
@@ -35,9 +36,4 @@ def match(self, models, results, relation):
return self.match_many(models, results, relation)
def _new_instance(self, model):
- return HasMany(
- self.new_query(),
- model,
- self._foreign_key,
- self._local_key
- )
+ return HasMany(self.new_query(), model, self._foreign_key, self._local_key)
diff --git a/orator/orm/relations/has_many_through.py b/orator/orm/relations/has_many_through.py
index a1f6eb81..c17db200 100644
--- a/orator/orm/relations/has_many_through.py
+++ b/orator/orm/relations/has_many_through.py
@@ -6,7 +6,6 @@
class HasManyThrough(Relation):
-
def __init__(self, query, far_parent, parent, first_key, second_key):
"""
:param query: A Builder instance
@@ -38,7 +37,11 @@ def add_constraints(self):
self._set_join()
if self._constraints:
- self._query.where('%s.%s' % (parent_table, self._first_key), '=', self._far_parent.get_key())
+ self._query.where(
+ "%s.%s" % (parent_table, self._first_key),
+ "=",
+ self._far_parent.get_key(),
+ )
def get_relation_count_query(self, query, parent):
"""
@@ -53,11 +56,11 @@ def get_relation_count_query(self, query, parent):
self._set_join(query)
- query.select(QueryExpression('COUNT(*)'))
+ query.select(QueryExpression("COUNT(*)"))
- key = self.wrap('%s.%s' % (parent_table, self._first_key))
+ key = self.wrap("%s.%s" % (parent_table, self._first_key))
- return query.where(self.get_has_compare_key(), '=', QueryExpression(key))
+ return query.where(self.get_has_compare_key(), "=", QueryExpression(key))
def _set_join(self, query=None):
"""
@@ -66,9 +69,14 @@ def _set_join(self, query=None):
if not query:
query = self._query
- foreign_key = '%s.%s' % (self._related.get_table(), self._second_key)
+ foreign_key = "%s.%s" % (self._related.get_table(), self._second_key)
- query.join(self._parent.get_table(), self.get_qualified_parent_key_name(), '=', foreign_key)
+ query.join(
+ self._parent.get_table(),
+ self.get_qualified_parent_key_name(),
+ "=",
+ foreign_key,
+ )
def add_eager_constraints(self, models):
"""
@@ -78,7 +86,7 @@ def add_eager_constraints(self, models):
"""
table = self._parent.get_table()
- self._query.where_in('%s.%s' % (table, self._first_key), self.get_keys(models))
+ self._query.where_in("%s.%s" % (table, self._first_key), self.get_keys(models))
def init_relation(self, models, relation):
"""
@@ -88,7 +96,9 @@ def init_relation(self, models, relation):
:type relation: str
"""
for model in models:
- model.set_relation(relation, Result(self._related.new_collection(), self, model))
+ model.set_relation(
+ relation, Result(self._related.new_collection(), self, model)
+ )
return models
@@ -106,7 +116,9 @@ def match(self, models, results, relation):
key = model.get_key()
if key in dictionary:
- value = Result(self._related.new_collection(dictionary[key]), self, model)
+ value = Result(
+ self._related.new_collection(dictionary[key]), self, model
+ )
else:
value = Result(self._related.new_collection(), self, model)
@@ -151,7 +163,7 @@ def get(self, columns=None):
:rtype: orator.Collection
"""
if columns is None:
- columns = ['*']
+ columns = ["*"]
select = self._get_select_columns(columns)
@@ -171,19 +183,15 @@ def _get_select_columns(self, columns=None):
:rtype: list
"""
- if columns == ['*'] or columns is None:
- columns = ['%s.*' % self._related.get_table()]
+ if columns == ["*"] or columns is None:
+ columns = ["%s.*" % self._related.get_table()]
- return columns + ['%s.%s' % (self._parent.get_table(), self._first_key)]
+ return columns + ["%s.%s" % (self._parent.get_table(), self._first_key)]
def get_has_compare_key(self):
return self._far_parent.get_qualified_key_name()
def _new_instance(self, model):
return HasManyThrough(
- self.new_query(),
- model,
- self._parent,
- self._first_key,
- self._second_key
+ self.new_query(), model, self._parent, self._first_key, self._second_key
)
diff --git a/orator/orm/relations/has_one.py b/orator/orm/relations/has_one.py
index 1830040f..a5bf69f7 100644
--- a/orator/orm/relations/has_one.py
+++ b/orator/orm/relations/has_one.py
@@ -5,7 +5,6 @@
class HasOne(HasOneOrMany):
-
def get_results(self):
"""
Get the results of the relationship.
@@ -35,9 +34,4 @@ def match(self, models, results, relation):
return self.match_one(models, results, relation)
def _new_instance(self, model):
- return HasOne(
- self.new_query(),
- model,
- self._foreign_key,
- self._local_key
- )
+ return HasOne(self.new_query(), model, self._foreign_key, self._local_key)
diff --git a/orator/orm/relations/has_one_or_many.py b/orator/orm/relations/has_one_or_many.py
index 716fe7fe..7e356e37 100644
--- a/orator/orm/relations/has_one_or_many.py
+++ b/orator/orm/relations/has_one_or_many.py
@@ -6,7 +6,6 @@
class HasOneOrMany(Relation):
-
def __init__(self, query, parent, foreign_key, local_key):
"""
:type query: orator.orm.Builder
@@ -30,7 +29,7 @@ def add_constraints(self):
Set the base constraints of the relation query
"""
if self._constraints:
- self._query.where(self._foreign_key, '=', self.get_parent_key())
+ self._query.where(self._foreign_key, "=", self.get_parent_key())
def add_eager_constraints(self, models):
"""
@@ -38,7 +37,9 @@ def add_eager_constraints(self, models):
:type models: list
"""
- return self._query.where_in(self._foreign_key, self.get_keys(models, self._local_key))
+ return self._query.where_in(
+ self._foreign_key, self.get_keys(models, self._local_key)
+ )
def match_one(self, models, results, relation):
"""
@@ -55,7 +56,7 @@ def match_one(self, models, results, relation):
:rtype: list
"""
- return self._match_one_or_many(models, results, relation, 'one')
+ return self._match_one_or_many(models, results, relation, "one")
def match_many(self, models, results, relation):
"""
@@ -72,7 +73,7 @@ def match_many(self, models, results, relation):
:rtype: list
"""
- return self._match_one_or_many(models, results, relation, 'many')
+ return self._match_one_or_many(models, results, relation, "many")
def _match_one_or_many(self, models, results, relation, type_):
"""
@@ -98,9 +99,11 @@ def _match_one_or_many(self, models, results, relation, type_):
key = model.get_attribute(self._local_key)
if key in dictionary:
- value = Result(self._get_relation_value(dictionary, key, type_), self, model)
+ value = Result(
+ self._get_relation_value(dictionary, key, type_), self, model
+ )
else:
- if type_ == 'one':
+ if type_ == "one":
value = Result(None, self, model)
else:
value = Result(self._related.new_collection(), self, model)
@@ -119,7 +122,7 @@ def _get_relation_value(self, dictionary, key, type):
"""
value = dictionary[key]
- if type == 'one':
+ if type == "one":
return value[0]
return self._related.new_collection(value)
@@ -171,7 +174,7 @@ def save_many(self, models):
:rtype: list
"""
- return map(self.save, models)
+ return list(map(self.save, models))
def find_or_new(self, id, columns=None):
"""
@@ -186,7 +189,7 @@ def find_or_new(self, id, columns=None):
:rtype: Collection or Model
"""
if columns is None:
- columns = ['*']
+ columns = ["*"]
instance = self._query.find(id, columns)
@@ -315,7 +318,7 @@ def get_foreign_key(self):
return self._foreign_key
def get_plain_foreign_key(self):
- segments = self.get_foreign_key().split('.')
+ segments = self.get_foreign_key().split(".")
return segments[-1]
@@ -323,4 +326,4 @@ def get_parent_key(self):
return self._parent.get_attribute(self._local_key)
def get_qualified_parent_key_name(self):
- return '%s.%s' % (self._parent.get_table(), self._local_key)
+ return "%s.%s" % (self._parent.get_table(), self._local_key)
diff --git a/orator/orm/relations/morph_many.py b/orator/orm/relations/morph_many.py
index bd5fdc70..1450ea6e 100644
--- a/orator/orm/relations/morph_many.py
+++ b/orator/orm/relations/morph_many.py
@@ -5,7 +5,6 @@
class MorphMany(MorphOneOrMany):
-
def get_results(self):
"""
Get the results of the relationship.
@@ -20,7 +19,9 @@ def init_relation(self, models, relation):
:type relation: str
"""
for model in models:
- model.set_relation(relation, Result(self._related.new_collection(), self, model))
+ model.set_relation(
+ relation, Result(self._related.new_collection(), self, model)
+ )
return models
diff --git a/orator/orm/relations/morph_one.py b/orator/orm/relations/morph_one.py
index 96c9e73d..fac1814b 100644
--- a/orator/orm/relations/morph_one.py
+++ b/orator/orm/relations/morph_one.py
@@ -5,7 +5,6 @@
class MorphOne(MorphOneOrMany):
-
def get_results(self):
"""
Get the results of the relationship.
diff --git a/orator/orm/relations/morph_one_or_many.py b/orator/orm/relations/morph_one_or_many.py
index acaee756..dcfba3f8 100644
--- a/orator/orm/relations/morph_one_or_many.py
+++ b/orator/orm/relations/morph_one_or_many.py
@@ -4,7 +4,6 @@
class MorphOneOrMany(HasOneOrMany):
-
def __init__(self, query, parent, morph_type, foreign_key, local_key):
"""
:type query: orator.orm.Builder
@@ -84,7 +83,7 @@ def find_or_new(self, id, columns=None):
:rtype: Collection or Model
"""
if columns is None:
- columns = ['*']
+ columns = ["*"]
instance = self._query.find(id, columns)
@@ -185,7 +184,7 @@ def get_morph_type(self):
return self._morph_type
def get_plain_morph_type(self):
- return self._morph_type.split('.')[-1]
+ return self._morph_type.split(".")[-1]
def get_morph_name(self):
return self._morph_name
@@ -196,5 +195,5 @@ def _new_instance(self, parent):
parent,
self._morph_type,
self._foreign_key,
- self._local_key
+ self._local_key,
)
diff --git a/orator/orm/relations/morph_to.py b/orator/orm/relations/morph_to.py
index 46f4b430..323f1e74 100644
--- a/orator/orm/relations/morph_to.py
+++ b/orator/orm/relations/morph_to.py
@@ -7,7 +7,6 @@
class MorphTo(BelongsTo):
-
def __init__(self, query, parent, foreign_key, other_key, type, relation):
"""
:type query: orator.orm.Builder
@@ -87,7 +86,9 @@ def associate(self, model):
self._parent.set_attribute(self._foreign_key, model.get_key())
self._parent.set_attribute(self._morph_type, model.get_morph_name())
- return self._parent.set_relation(self._relation, Result(model, self, self._parent))
+ return self._parent.set_relation(
+ self._relation, Result(model, self, self._parent)
+ )
def get_eager(self):
"""
@@ -114,11 +115,7 @@ def _match_to_morph_parents(self, type, results):
if result.get_key() in self._dictionary.get(type, []):
for model in self._dictionary[type][result.get_key()]:
model.set_relation(
- self._relation,
- Result(
- result, self, model,
- related=result
- )
+ self._relation, Result(result, self, model, related=result)
)
def _get_results_by_type(self, type):
@@ -151,9 +148,11 @@ def _gather_keys_by_type(self, type):
"""
foreign = self._foreign_key
- keys = BaseCollection.make(list(self._dictionary[type].values()))\
- .map(lambda models: getattr(models[0], foreign))\
+ keys = (
+ BaseCollection.make(list(self._dictionary[type].values()))
+ .map(lambda models: getattr(models[0], foreign))
.unique()
+ )
return keys
@@ -193,5 +192,5 @@ def _new_instance(self, model, related=None):
self._foreign_key,
self._other_key if not related else related.get_key_name(),
self._morph_type,
- self._relation
+ self._relation,
)
diff --git a/orator/orm/relations/morph_to_many.py b/orator/orm/relations/morph_to_many.py
index 599644f0..85f8fd8f 100644
--- a/orator/orm/relations/morph_to_many.py
+++ b/orator/orm/relations/morph_to_many.py
@@ -4,9 +4,17 @@
class MorphToMany(BelongsToMany):
-
- def __init__(self, query, parent, name, table,
- foreign_key, other_key, relation_name=None, inverse=False):
+ def __init__(
+ self,
+ query,
+ parent,
+ name,
+ table,
+ foreign_key,
+ other_key,
+ relation_name=None,
+ inverse=False,
+ ):
"""
:param query: A Builder instance
:type query: orator.orm.Builder
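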
@@ -30,12 +38,13 @@ def __init__(self, query, parent, name, table,
"""
self._name = name
self._inverse = inverse
- self._morph_type = name + '_type'
- self._morph_name = query.get_model().get_morph_name() if inverse else parent.get_morph_name()
+ self._morph_type = name + "_type"
+ self._morph_name = (
+ query.get_model().get_morph_name() if inverse else parent.get_morph_name()
+ )
super(MorphToMany, self).__init__(
- query, parent, table,
- foreign_key, other_key, relation_name
+ query, parent, table, foreign_key, other_key, relation_name
)
def _set_where(self):
@@ -47,7 +56,7 @@ def _set_where(self):
"""
super(MorphToMany, self)._set_where()
- self._query.where('%s.%s' % (self._table, self._morph_type), self._morph_name)
+ self._query.where("%s.%s" % (self._table, self._morph_type), self._morph_name)
def get_relation_count_query(self, query, parent):
"""
@@ -60,7 +69,7 @@ def get_relation_count_query(self, query, parent):
"""
query = super(MorphToMany, self).get_relation_count_query(query, parent)
- return query.where('%s.%s' % (self._table, self._morph_type), self._morph_name)
+ return query.where("%s.%s" % (self._table, self._morph_type), self._morph_name)
def add_eager_constraints(self, models):
"""
@@ -70,7 +79,7 @@ def add_eager_constraints(self, models):
"""
super(MorphToMany, self).add_eager_constraints(models)
- self._query.where('%s.%s' % (self._table, self._morph_type), self._morph_name)
+ self._query.where("%s.%s" % (self._table, self._morph_type), self._morph_name)
def _create_attach_record(self, id, timed):
"""
@@ -100,9 +109,9 @@ def new_pivot(self, attributes=None, exists=False):
pivot = MorphPivot(self._parent, attributes, self._table, exists)
- pivot.set_pivot_keys(self._foreign_key, self._other_key)\
- .set_morph_type(self._morph_type)\
- .set_morph_name(self._morph_name)
+ pivot.set_pivot_keys(self._foreign_key, self._other_key).set_morph_type(
+ self._morph_type
+ ).set_morph_name(self._morph_name)
return pivot
@@ -121,5 +130,5 @@ def _new_instance(self, model):
self._foreign_key,
self._other_key,
self._relation_name,
- self._inverse
+ self._inverse,
)
diff --git a/orator/orm/relations/relation.py b/orator/orm/relations/relation.py
index 4cfc5f96..7a311b73 100644
--- a/orator/orm/relations/relation.py
+++ b/orator/orm/relations/relation.py
@@ -93,7 +93,8 @@ def raw_update(self, attributes=None):
if attributes is None:
attributes = {}
- return self._query.update(attributes)
+ if self._query is not None:
+ return self._query.update(attributes)
def get_relation_count_query(self, query, parent):
"""
@@ -104,11 +105,11 @@ def get_relation_count_query(self, query, parent):
:rtype: Builder
"""
- query.select(QueryExpression('COUNT(*)'))
+ query.select(QueryExpression("COUNT(*)"))
key = self.wrap(self.get_qualified_parent_key_name())
- return query.where(self.get_has_compare_key(), '=', QueryExpression(key))
+ return query.where(self.get_has_compare_key(), "=", QueryExpression(key))
@classmethod
@contextmanager
@@ -141,7 +142,14 @@ def get_keys(self, models, key=None):
:rtype: list
"""
- return list(set(map(lambda value: value.get_attribute(key) if key else value.get_key(), models)))
+ return list(
+ set(
+ map(
+ lambda value: value.get_attribute(key) if key else value.get_key(),
+ models,
+ )
+ )
+ )
def get_query(self):
return self._query
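For illustration, get_keys() is what eager-loading constraints feed into where_in(): it collects either a named attribute or the primary key from every parent model and deduplicates the result. A minimal standalone sketch of the same logic:

    def collect_keys(models, key=None):
        # mirrors Relation.get_keys(): one attribute (or primary key) per model,
        # deduplicated through a set, returned as a plain list
        return list(
            set(
                model.get_attribute(key) if key else model.get_key()
                for model in models
            )
        )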
diff --git a/orator/orm/relations/wrapper.py b/orator/orm/relations/wrapper.py
index ae237446..2949c00e 100644
--- a/orator/orm/relations/wrapper.py
+++ b/orator/orm/relations/wrapper.py
@@ -41,7 +41,6 @@ def __repr__(self):
class BelongsToManyWrapper(Wrapper):
-
def with_timestamps(self):
self._relation.with_timestamps()
diff --git a/orator/orm/scopes/scope.py b/orator/orm/scopes/scope.py
index 9293285a..fae90f84 100644
--- a/orator/orm/scopes/scope.py
+++ b/orator/orm/scopes/scope.py
@@ -2,7 +2,6 @@
class Scope(object):
-
def apply(self, builder, model):
"""
Apply the scope to a given query builder.
diff --git a/orator/orm/scopes/soft_deleting.py b/orator/orm/scopes/soft_deleting.py
index edfca0ca..0e8efb2e 100644
--- a/orator/orm/scopes/soft_deleting.py
+++ b/orator/orm/scopes/soft_deleting.py
@@ -5,7 +5,7 @@
class SoftDeletingScope(Scope):
- _extensions = ['force_delete', 'restore', 'with_trashed', 'only_trashed']
+ _extensions = ["force_delete", "restore", "with_trashed", "only_trashed"]
def apply(self, builder, model):
"""
@@ -29,7 +29,7 @@ def extend(self, builder):
:type builder: orator.orm.builder.Builder
"""
for extension in self._extensions:
- getattr(self, '_add_%s' % extension)(builder)
+ getattr(self, "_add_%s" % extension)(builder)
builder.on_delete(self._on_delete)
@@ -42,9 +42,7 @@ def _on_delete(self, builder):
"""
column = self._get_deleted_at_column(builder)
- return builder.update({
- column: builder.get_model().fresh_timestamp()
- })
+ return builder.update({column: builder.get_model().fresh_timestamp()})
def _get_deleted_at_column(self, builder):
"""
@@ -67,7 +65,7 @@ def _add_force_delete(self, builder):
:param builder: The query builder
:type builder: orator.orm.builder.Builder
"""
- builder.macro('force_delete', self._force_delete)
+ builder.macro("force_delete", self._force_delete)
def _force_delete(self, builder):
"""
@@ -85,7 +83,7 @@ def _add_restore(self, builder):
:param builder: The query builder
:type builder: orator.orm.builder.Builder
"""
- builder.macro('restore', self._restore)
+ builder.macro("restore", self._restore)
def _restore(self, builder):
"""
@@ -96,9 +94,7 @@ def _restore(self, builder):
"""
builder.with_trashed()
- return builder.update({
- builder.get_model().get_deleted_at_column(): None
- })
+ return builder.update({builder.get_model().get_deleted_at_column(): None})
def _add_with_trashed(self, builder):
"""
@@ -107,7 +103,7 @@ def _add_with_trashed(self, builder):
:param builder: The query builder
:type builder: orator.orm.builder.Builder
"""
- builder.macro('with_trashed', self._with_trashed)
+ builder.macro("with_trashed", self._with_trashed)
def _with_trashed(self, builder):
"""
@@ -127,7 +123,7 @@ def _add_only_trashed(self, builder):
:param builder: The query builder
:type builder: orator.orm.builder.Builder
"""
- builder.macro('only_trashed', self._only_trashed)
+ builder.macro("only_trashed", self._only_trashed)
def _only_trashed(self, builder):
"""
diff --git a/orator/orm/utils.py b/orator/orm/utils.py
index 449b4146..2eae9ac2 100644
--- a/orator/orm/utils.py
+++ b/orator/orm/utils.py
@@ -6,14 +6,19 @@
from .builder import Builder
from ..query import QueryBuilder
from .relations import (
- HasOne, HasMany, HasManyThrough,
- BelongsTo, BelongsToMany,
- MorphOne, MorphMany, MorphTo, MorphToMany
+ HasOne,
+ HasMany,
+ HasManyThrough,
+ BelongsTo,
+ BelongsToMany,
+ MorphOne,
+ MorphMany,
+ MorphTo,
+ MorphToMany,
)
class accessor(object):
-
def __init__(self, accessor_, attribute=None):
self.accessor = accessor_
self.mutator_ = None
@@ -48,7 +53,6 @@ def mutator(self, f):
class mutator(object):
-
def __init__(self, mutator_, attribute=None):
self.mutator = mutator_
self.accessor_ = None
@@ -73,7 +77,6 @@ def accessor(self, f):
class column(object):
-
def __init__(self, property_, attribute=None):
self.property = property_
self.mutator_ = None
@@ -231,7 +234,7 @@ def _get(self, instance):
self._foreign_key,
self._local_key,
self._relation,
- _wrapped=False
+ _wrapped=False,
)
@@ -242,9 +245,11 @@ class morph_one(relation):
relation_class = MorphOne
- def __init__(self, name, type_column=None, id_column=None, local_key=None, relation=None):
+ def __init__(
+ self, name, type_column=None, id_column=None, local_key=None, relation=None
+ ):
if isinstance(name, (types.FunctionType, types.MethodType)):
- raise RuntimeError('morph_one relation requires a name')
+ raise RuntimeError("morph_one relation requires a name")
self._name = name
self._type_column = type_column
@@ -255,10 +260,13 @@ def __init__(self, name, type_column=None, id_column=None, local_key=None, relat
def _get(self, instance):
return instance.morph_one(
- self._related, self._name,
- self._type_column, self._id_column,
- self._local_key, self._relation,
- _wrapped=False
+ self._related,
+ self._name,
+ self._type_column,
+ self._id_column,
+ self._local_key,
+ self._relation,
+ _wrapped=False,
)
@@ -287,7 +295,7 @@ def _get(self, instance):
self._foreign_key,
self._other_key,
self._relation,
- _wrapped=False
+ _wrapped=False,
)
def _set(self, relation):
@@ -318,9 +326,7 @@ def __init__(self, name=None, type_column=None, id_column=None):
def _get(self, instance):
return instance.morph_to(
- self._relation,
- self._type_column, self._id_column,
- _wrapped=False
+ self._relation, self._type_column, self._id_column, _wrapped=False
)
@@ -349,7 +355,7 @@ def _get(self, instance):
self._foreign_key,
self._local_key,
self._relation,
- _wrapped=False
+ _wrapped=False,
)
@@ -362,7 +368,9 @@ class has_many_through(relation):
def __init__(self, through, first_key=None, second_key=None, relation=None):
if isinstance(through, (types.FunctionType, types.MethodType)):
- raise RuntimeError('has_many_through relation requires the through parameter')
+ raise RuntimeError(
+ "has_many_through relation requires the through parameter"
+ )
self._through = through
self._first_key = first_key
@@ -377,7 +385,7 @@ def _get(self, instance):
self._first_key,
self._second_key,
self._relation,
- _wrapped=False
+ _wrapped=False,
)
@@ -388,9 +396,11 @@ class morph_many(relation):
relation_class = MorphMany
- def __init__(self, name, type_column=None, id_column=None, local_key=None, relation=None):
+ def __init__(
+ self, name, type_column=None, id_column=None, local_key=None, relation=None
+ ):
if isinstance(name, (types.FunctionType, types.MethodType)):
- raise RuntimeError('morph_many relation requires a name')
+ raise RuntimeError("morph_many relation requires a name")
self._name = name
self._type_column = type_column
@@ -401,10 +411,13 @@ def __init__(self, name, type_column=None, id_column=None, local_key=None, relat
def _get(self, instance):
return instance.morph_many(
- self._related, self._name,
- self._type_column, self._id_column,
- self._local_key, self._relation,
- _wrapped=False
+ self._related,
+ self._name,
+ self._type_column,
+ self._id_column,
+ self._local_key,
+ self._relation,
+ _wrapped=False,
)
@@ -415,8 +428,15 @@ class belongs_to_many(relation):
relation_class = BelongsToMany
- def __init__(self, table=None, foreign_key=None, other_key=None,
- relation=None, with_timestamps=False, with_pivot=None):
+ def __init__(
+ self,
+ table=None,
+ foreign_key=None,
+ other_key=None,
+ relation=None,
+ with_timestamps=False,
+ with_pivot=None,
+ ):
if isinstance(table, (types.FunctionType, types.MethodType)):
func = table
table = None
@@ -439,7 +459,7 @@ def _get(self, instance):
self._foreign_key,
self._other_key,
self._relation,
- _wrapped=False
+ _wrapped=False,
)
if self._timestamps:
@@ -458,9 +478,11 @@ class morph_to_many(relation):
relation_class = MorphToMany
- def __init__(self, name, table=None, foreign_key=None, other_key=None, relation=None):
+ def __init__(
+ self, name, table=None, foreign_key=None, other_key=None, relation=None
+ ):
if isinstance(name, (types.FunctionType, types.MethodType)):
- raise RuntimeError('morph_to_many relation required a name')
+ raise RuntimeError("morph_to_many relation required a name")
self._name = name
self._table = table
@@ -477,7 +499,7 @@ def _get(self, instance):
self._foreign_key,
self._other_key,
relation=self._relation,
- _wrapped=False
+ _wrapped=False,
)
@@ -488,9 +510,11 @@ class morphed_by_many(relation):
relation_class = MorphToMany
- def __init__(self, name, table=None, foreign_key=None, other_key=None, relation=None):
+ def __init__(
+ self, name, table=None, foreign_key=None, other_key=None, relation=None
+ ):
if isinstance(foreign_key, (types.FunctionType, types.MethodType)):
- raise RuntimeError('morphed_by_many relation requires a name')
+ raise RuntimeError("morphed_by_many relation requires a name")
self._name = name
self._table = table
@@ -507,5 +531,5 @@ def _get(self, instance):
self._foreign_key,
self._other_key,
self._relation,
- _wrapped=False
+ _wrapped=False,
)
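For illustration, the decorators reformatted above are the declarative way of wiring relationships, and the constructor checks (e.g. morph_one/morph_many refusing to run without a name) guard their first positional argument. A hedged usage sketch with hypothetical models:

    from orator import Model
    from orator.orm import has_many, morph_many

    class Photo(Model):
        pass

    class Post(Model):

        @morph_many("imageable")   # the name argument is mandatory, per the RuntimeError above
        def photos(self):
            return Photo

    class User(Model):

        @has_many
        def posts(self):
            return Post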
diff --git a/orator/pagination/length_aware_paginator.py b/orator/pagination/length_aware_paginator.py
index 128154b4..c9a910ed 100644
--- a/orator/pagination/length_aware_paginator.py
+++ b/orator/pagination/length_aware_paginator.py
@@ -9,7 +9,6 @@
class LengthAwarePaginator(BasePaginator):
-
def __init__(self, items, total, per_page, current_page=None, options=None):
"""
Constructor
diff --git a/orator/pagination/paginator.py b/orator/pagination/paginator.py
index f3041c98..3b0f04f7 100644
--- a/orator/pagination/paginator.py
+++ b/orator/pagination/paginator.py
@@ -6,7 +6,6 @@
class Paginator(BasePaginator):
-
def __init__(self, items, per_page, current_page=None, options=None):
"""
Constructor
@@ -59,7 +58,7 @@ def _check_for_more_pages(self):
"""
self._has_more = len(self._items) > self.per_page
- self._items = self._items[0:self.per_page]
+ self._items = self._items[0 : self.per_page]
def has_more_pages(self):
"""
diff --git a/orator/query/builder.py b/orator/query/builder.py
index 60f59c10..6e8aaa6c 100644
--- a/orator/query/builder.py
+++ b/orator/query/builder.py
@@ -2,8 +2,11 @@
import re
import copy
+import datetime
+
from itertools import chain
from collections import OrderedDict
+
from .expression import QueryExpression
from .join_clause import JoinClause
from ..pagination import Paginator, LengthAwarePaginator
@@ -15,12 +18,32 @@
class QueryBuilder(object):
_operators = [
- '=', '<', '>', '<=', '>=', '<>', '!=',
- 'like', 'like binary', 'not like', 'between', 'ilike',
- '&', '|', '^', '<<', '>>',
- 'rlike', 'regexp', 'not regexp',
- '~', '~*', '!~', '!~*', 'similar to',
- 'not similar to',
+ "=",
+ "<",
+ ">",
+ "<=",
+ ">=",
+ "<>",
+ "!=",
+ "like",
+ "like binary",
+ "not like",
+ "between",
+ "ilike",
+ "&",
+ "|",
+ "^",
+ "<<",
+ ">>",
+ "rlike",
+ "regexp",
+ "not regexp",
+ "~",
+ "~*",
+ "!~",
+ "!~*",
+ "similar to",
+ "not similar to",
]
def __init__(self, connection, grammar, processor):
@@ -40,13 +63,13 @@ def __init__(self, connection, grammar, processor):
self._processor = processor
self._connection = connection
self._bindings = OrderedDict()
- for type in ['select', 'join', 'where', 'having', 'order']:
+ for type in ["select", "join", "where", "having", "order"]:
self._bindings[type] = []
self.aggregate_ = None
self.columns = []
self.distinct_ = False
- self.from__ = ''
+ self.from__ = ""
self.joins = []
self.wheres = []
self.groups = []
@@ -75,7 +98,7 @@ def select(self, *columns):
:rtype: QueryBuilder
"""
if not columns:
- columns = ['*']
+ columns = ["*"]
self.columns = list(columns)
@@ -97,7 +120,7 @@ def select_raw(self, expression, bindings=None):
self.add_select(QueryExpression(expression))
if bindings:
- self.add_binding(bindings, 'select')
+ self.add_binding(bindings, "select")
return self
@@ -121,9 +144,11 @@ def select_sub(self, query, as_):
elif isinstance(query, basestring):
bindings = []
else:
- raise ArgumentError('Invalid subselect')
+ raise ArgumentError("Invalid subselect")
- return self.select_raw('(%s) AS %s' % (query, self._grammar.wrap(as_)), bindings)
+ return self.select_raw(
+ "(%s) AS %s" % (query, self._grammar.wrap(as_)), bindings
+ )
def add_select(self, *column):
"""
@@ -167,8 +192,7 @@ def from_(self, table):
return self
- def join(self, table, one=None,
- operator=None, two=None, type='inner', where=False):
+ def join(self, table, one=None, operator=None, two=None, type="inner", where=False):
"""
Add a join clause to the query
@@ -201,13 +225,11 @@ def join(self, table, one=None,
join = JoinClause(table, type)
- self.joins.append(join.on(
- one, operator, two, 'and', where
- ))
+ self.joins.append(join.on(one, operator, two, "and", where))
return self
- def join_where(self, table, one, operator, two, type='inner'):
+ def join_where(self, table, one, operator, two, type="inner"):
"""
Add a "join where" clause to the query
@@ -251,9 +273,9 @@ def left_join(self, table, one=None, operator=None, two=None):
:rtype: QueryBuilder
"""
if isinstance(table, JoinClause):
- table.type = 'left'
+ table.type = "left"
- return self.join(table, one, operator, two, 'left')
+ return self.join(table, one, operator, two, "left")
def left_join_where(self, table, one, operator, two):
"""
@@ -274,7 +296,7 @@ def left_join_where(self, table, one, operator, two):
:return: The current QueryBuilder instance
:rtype: QueryBuilder
"""
- return self.join_where(table, one, operator, two, 'left')
+ return self.join_where(table, one, operator, two, "left")
def right_join(self, table, one=None, operator=None, two=None):
"""
@@ -296,9 +318,9 @@ def right_join(self, table, one=None, operator=None, two=None):
:rtype: QueryBuilder
"""
if isinstance(table, JoinClause):
- table.type = 'right'
+ table.type = "right"
- return self.join(table, one, operator, two, 'right')
+ return self.join(table, one, operator, two, "right")
def right_join_where(self, table, one, operator, two):
"""
@@ -319,9 +341,9 @@ def right_join_where(self, table, one, operator, two):
:return: The current QueryBuilder instance
:rtype: QueryBuilder
"""
- return self.join_where(table, one, operator, two, 'right')
+ return self.join_where(table, one, operator, two, "right")
- def where(self, column, operator=Null(), value=None, boolean='and'):
+ def where(self, column, operator=Null(), value=None, boolean="and"):
"""
Add a where clause to the query
@@ -346,7 +368,7 @@ def where(self, column, operator=Null(), value=None, boolean='and'):
if isinstance(column, dict):
nested = self.new_query()
for key, value in column.items():
- nested.where(key, '=', value)
+ nested.where(key, "=", value)
return self.where_nested(nested, boolean)
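For illustration, the dict branch above turns a mapping into a nested group of equality clauses. A hedged sketch, where db is a configured DatabaseManager and the table/column names are hypothetical:

    db.table("users").where({"name": "john", "age": 30}).get()
    # builds roughly: SELECT * FROM users WHERE (name = 'john' AND age = 30)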
@@ -359,89 +381,84 @@ def where(self, column, operator=Null(), value=None, boolean='and'):
if isinstance(condition, list) and len(condition) == 3:
nested.where(condition[0], condition[1], condition[2])
else:
- raise ArgumentError('Invalid conditions in where() clause')
+ raise ArgumentError("Invalid conditions in where() clause")
return self.where_nested(nested, boolean)
if value is None:
if not isinstance(operator, Null):
value = operator
- operator = '='
+ operator = "="
else:
- raise ArgumentError('Value must be provided')
+ raise ArgumentError("Value must be provided")
if operator not in self._operators:
value = operator
- operator = '='
+ operator = "="
if isinstance(value, QueryBuilder):
return self._where_sub(column, operator, value, boolean)
if value is None:
- return self.where_null(column, boolean, operator != '=')
-
- type = 'basic'
-
- self.wheres.append({
- 'type': type,
- 'column': column,
- 'operator': operator,
- 'value': value,
- 'boolean': boolean
- })
+ return self.where_null(column, boolean, operator != "=")
+
+ type = "basic"
+
+ self.wheres.append(
+ {
+ "type": type,
+ "column": column,
+ "operator": operator,
+ "value": value,
+ "boolean": boolean,
+ }
+ )
if not isinstance(value, QueryExpression):
- self.add_binding(value, 'where')
+ self.add_binding(value, "where")
return self
def or_where(self, column, operator=None, value=None):
- return self.where(column, operator, value, 'or')
+ return self.where(column, operator, value, "or")
def _invalid_operator_and_value(self, operator, value):
is_operator = operator in self._operators
- return is_operator and operator != '=' and value is None
+ return is_operator and operator != "=" and value is None
- def where_raw(self, sql, bindings=None, boolean='and'):
- type = 'raw'
+ def where_raw(self, sql, bindings=None, boolean="and"):
+ type = "raw"
- self.wheres.append({
- 'type': type,
- 'sql': sql,
- 'boolean': boolean
- })
+ self.wheres.append({"type": type, "sql": sql, "boolean": boolean})
- self.add_binding(bindings, 'where')
+ self.add_binding(bindings, "where")
return self
def or_where_raw(self, sql, bindings=None):
- return self.where_raw(sql, bindings, 'or')
+ return self.where_raw(sql, bindings, "or")
- def where_between(self, column, values, boolean='and', negate=False):
- type = 'between'
+ def where_between(self, column, values, boolean="and", negate=False):
+ type = "between"
- self.wheres.append({
- 'column': column,
- 'type': type,
- 'boolean': boolean,
- 'not': negate
- })
+ self.wheres.append(
+ {"column": column, "type": type, "boolean": boolean, "not": negate}
+ )
- self.add_binding(values, 'where')
+ self.add_binding(values, "where")
return self
def or_where_between(self, column, values):
- return self.where_between(column, values, 'or')
+ return self.where_between(column, values, "or")
- def where_not_between(self, column, values, boolean='and'):
+ def where_not_between(self, column, values, boolean="and"):
return self.where_between(column, values, boolean, True)
def or_where_not_between(self, column, values):
- return self.where_not_between(column, values, 'or')
+ return self.where_not_between(column, values, "or")
- def where_nested(self, query, boolean='and'):
+ def where_nested(self, query, boolean="and"):
query.from_(self.from__)
return self.add_nested_where_query(query, boolean)
@@ -456,36 +473,34 @@ def for_nested_where(self):
return query.from_(self.from__)
- def add_nested_where_query(self, query, boolean='and'):
+ def add_nested_where_query(self, query, boolean="and"):
if len(query.wheres):
- type = 'nested'
+ type = "nested"
- self.wheres.append({
- 'type': type,
- 'query': query,
- 'boolean': boolean
- })
+ self.wheres.append({"type": type, "query": query, "boolean": boolean})
self.merge_bindings(query)
return self
def _where_sub(self, column, operator, query, boolean):
- type = 'sub'
-
- self.wheres.append({
- 'type': type,
- 'column': column,
- 'operator': operator,
- 'query': query,
- 'boolean': boolean
- })
+ type = "sub"
+
+ self.wheres.append(
+ {
+ "type": type,
+ "column": column,
+ "operator": operator,
+ "query": query,
+ "boolean": boolean,
+ }
+ )
self.merge_bindings(query)
return self
- def where_exists(self, query, boolean='and', negate=False):
+ def where_exists(self, query, boolean="and", negate=False):
"""
Add an exists clause to the query.
@@ -499,15 +514,11 @@ def where_exists(self, query, boolean='and', negate=False):
:rtype: QueryBuilder
"""
if negate:
- type = 'not_exists'
+ type = "not_exists"
else:
- type = 'exists'
+ type = "exists"
- self.wheres.append({
- 'type': type,
- 'query': query,
- 'boolean': boolean
- })
+ self.wheres.append({"type": type, "query": query, "boolean": boolean})
self.merge_bindings(query)
@@ -524,9 +535,9 @@ def or_where_exists(self, query, negate=False):
:rtype: QueryBuilder
"""
- return self.where_exists(query, 'or', negate)
+ return self.where_exists(query, "or", negate)
- def where_not_exists(self, query, boolean='and'):
+ def where_not_exists(self, query, boolean="and"):
"""
Add a where not exists clause to the query.
@@ -550,11 +561,11 @@ def or_where_not_exists(self, query):
"""
return self.or_where_exists(query, True)
- def where_in(self, column, values, boolean='and', negate=False):
+ def where_in(self, column, values, boolean="and", negate=False):
if negate:
- type = 'not_in'
+ type = "not_in"
else:
- type = 'in'
+ type = "in"
if isinstance(values, QueryBuilder):
return self._where_in_sub(column, values, boolean, negate)
@@ -562,25 +573,22 @@ def where_in(self, column, values, boolean='and', negate=False):
if isinstance(values, Collection):
values = values.all()
- self.wheres.append({
- 'type': type,
- 'column': column,
- 'values': values,
- 'boolean': boolean
- })
+ self.wheres.append(
+ {"type": type, "column": column, "values": values, "boolean": boolean}
+ )
- self.add_binding(values, 'where')
+ self.add_binding(values, "where")
return self
def or_where_in(self, column, values):
- return self.where_in(column, values, 'or')
+ return self.where_in(column, values, "or")
- def where_not_in(self, column, values, boolean='and'):
+ def where_not_in(self, column, values, boolean="and"):
return self.where_in(column, values, boolean, True)
def or_where_not_in(self, column, values):
- return self.where_not_in(column, values, 'or')
+ return self.where_not_in(column, values, "or")
def _where_in_sub(self, column, query, boolean, negate=False):
"""
@@ -602,79 +610,74 @@ def _where_in_sub(self, column, query, boolean, negate=False):
:rtype: QueryBuilder
"""
if negate:
- type = 'not_in_sub'
+ type = "not_in_sub"
else:
- type = 'in_sub'
+ type = "in_sub"
- self.wheres.append({
- 'type': type,
- 'column': column,
- 'query': query,
- 'boolean': boolean
- })
+ self.wheres.append(
+ {"type": type, "column": column, "query": query, "boolean": boolean}
+ )
self.merge_bindings(query)
return self
- def where_null(self, column, boolean='and', negate=False):
+ def where_null(self, column, boolean="and", negate=False):
if negate:
- type = 'not_null'
+ type = "not_null"
else:
- type = 'null'
+ type = "null"
- self.wheres.append({
- 'type': type,
- 'column': column,
- 'boolean': boolean
- })
+ self.wheres.append({"type": type, "column": column, "boolean": boolean})
return self
def or_where_null(self, column):
- return self.where_null(column, 'or')
+ return self.where_null(column, "or")
- def where_not_null(self, column, boolean='and'):
+ def where_not_null(self, column, boolean="and"):
return self.where_null(column, boolean, True)
def or_where_not_null(self, column):
- return self.where_not_null(column, 'or')
+ return self.where_not_null(column, "or")
- def where_date(self, column, operator, value, boolean='and'):
- return self._add_date_based_where('date', column, operator, value, boolean)
+ def where_date(self, column, operator, value, boolean="and"):
+ return self._add_date_based_where("date", column, operator, value, boolean)
- def where_day(self, column, operator, value, boolean='and'):
- return self._add_date_based_where('day', column, operator, value, boolean)
+ def where_day(self, column, operator, value, boolean="and"):
+ return self._add_date_based_where("day", column, operator, value, boolean)
- def where_month(self, column, operator, value, boolean='and'):
- return self._add_date_based_where('month', column, operator, value, boolean)
+ def where_month(self, column, operator, value, boolean="and"):
+ return self._add_date_based_where("month", column, operator, value, boolean)
- def where_year(self, column, operator, value, boolean='and'):
- return self._add_date_based_where('year', column, operator, value, boolean)
+ def where_year(self, column, operator, value, boolean="and"):
+ return self._add_date_based_where("year", column, operator, value, boolean)
- def _add_date_based_where(self, type, column, operator, value, boolean='and'):
- self.wheres.append({
- 'type': type,
- 'column': column,
- 'boolean': boolean,
- 'operator': operator,
- 'value': value
- })
+ def _add_date_based_where(self, type, column, operator, value, boolean="and"):
+ self.wheres.append(
+ {
+ "type": type,
+ "column": column,
+ "boolean": boolean,
+ "operator": operator,
+ "value": value,
+ }
+ )
- self.add_binding(value, 'where')
+ self.add_binding(value, "where")
def dynamic_where(self, method):
finder = method[6:]
def dynamic_where(*parameters):
- segments = re.split('_(and|or)_(?=[a-z])', finder, 0, re.I)
+ segments = re.split("_(and|or)_(?=[a-z])", finder, 0, re.I)
- connector = 'and'
+ connector = "and"
index = 0
for segment in segments:
- if segment.lower() != 'and' and segment.lower() != 'or':
+ if segment.lower() != "and" and segment.lower() != "or":
self._add_dynamic(segment, connector, parameters, index)
index += 1
@@ -686,7 +689,7 @@ def dynamic_where(*parameters):
return dynamic_where
def _add_dynamic(self, segment, connector, parameters, index):
- self.where(segment, '=', parameters[index], connector)
+ self.where(segment, "=", parameters[index], connector)
def group_by(self, *columns):
"""
@@ -703,7 +706,7 @@ def group_by(self, *columns):
return self
- def having(self, column, operator=None, value=None, boolean='and'):
+ def having(self, column, operator=None, value=None, boolean="and"):
"""
Add a "having" clause to the query
@@ -722,18 +725,20 @@ def having(self, column, operator=None, value=None, boolean='and'):
:return: The current QueryBuilder instance
:rtype: QueryBuilder
"""
- type = 'basic'
+ type = "basic"
- self.havings.append({
- 'type': type,
- 'column': column,
- 'operator': operator,
- 'value': value,
- 'boolean': boolean
- })
+ self.havings.append(
+ {
+ "type": type,
+ "column": column,
+ "operator": operator,
+ "value": value,
+ "boolean": boolean,
+ }
+ )
if not isinstance(value, QueryExpression):
- self.add_binding(value, 'having')
+ self.add_binding(value, "having")
return self
@@ -753,9 +758,9 @@ def or_having(self, column, operator=None, value=None):
:return: The current QueryBuilder instance
:rtype: QueryBuilder
"""
- return self.having(column, operator, value, 'or')
+ return self.having(column, operator, value, "or")
- def having_raw(self, sql, bindings=None, boolean='and'):
+ def having_raw(self, sql, bindings=None, boolean="and"):
"""
Add a raw having clause to the query
@@ -771,15 +776,11 @@ def having_raw(self, sql, bindings=None, boolean='and'):
:return: The current QueryBuilder instance
:rtype: QueryBuilder
"""
- type = 'raw'
+ type = "raw"
- self.havings.append({
- 'type': type,
- 'sql': sql,
- 'boolean': boolean
- })
+ self.havings.append({"type": type, "sql": sql, "boolean": boolean})
- self.add_binding(bindings, 'having')
+ self.add_binding(bindings, "having")
return self
@@ -796,9 +797,9 @@ def or_having_raw(self, sql, bindings=None):
:return: The current QueryBuilder instance
:rtype: QueryBuilder
"""
- return self.having_raw(sql, bindings, 'or')
+ return self.having_raw(sql, bindings, "or")
- def order_by(self, column, direction='asc'):
+ def order_by(self, column, direction="asc"):
"""
Add a "order by" clause to the query
@@ -812,23 +813,20 @@ def order_by(self, column, direction='asc'):
:rtype: QueryBuilder
"""
if self.unions:
- prop = 'union_orders'
+ prop = "union_orders"
else:
- prop = 'orders'
+ prop = "orders"
- if direction.lower() == 'asc':
- direction = 'asc'
+ if direction.lower() == "asc":
+ direction = "asc"
else:
- direction = 'desc'
+ direction = "desc"
- getattr(self, prop).append({
- 'column': column,
- 'direction': direction
- })
+ getattr(self, prop).append({"column": column, "direction": direction})
return self
- def latest(self, column='created_at'):
+ def latest(self, column="created_at"):
"""
Add an "order by" clause for a timestamp to the query
in descending order
@@ -839,9 +837,9 @@ def latest(self, column='created_at'):
:return: The current QueryBuilder instance
:rtype: QueryBuilder
"""
- return self.order_by(column, 'desc')
+ return self.order_by(column, "desc")
- def oldest(self, column='created_at'):
+ def oldest(self, column="created_at"):
"""
Add an "order by" clause for a timestamp to the query
in ascending order
@@ -852,7 +850,7 @@ def oldest(self, column='created_at'):
:return: The current QueryBuilder instance
:rtype: QueryBuilder
"""
- return self.order_by(column, 'asc')
+ return self.order_by(column, "asc")
def order_by_raw(self, sql, bindings=None):
"""
@@ -870,22 +868,19 @@ def order_by_raw(self, sql, bindings=None):
if bindings is None:
bindings = []
- type = 'raw'
+ type = "raw"
- self.orders.append({
- 'type': type,
- 'sql': sql
- })
+ self.orders.append({"type": type, "sql": sql})
- self.add_binding(bindings, 'order')
+ self.add_binding(bindings, "order")
return self
def offset(self, value):
if self.unions:
- prop = 'union_offset'
+ prop = "union_offset"
else:
- prop = 'offset_'
+ prop = "offset_"
setattr(self, prop, max(0, value))
@@ -896,9 +891,9 @@ def skip(self, value):
def limit(self, value):
if self.unions:
- prop = 'union_limit'
+ prop = "union_limit"
else:
- prop = 'limit_'
+ prop = "limit_"
if value is None or value > 0:
setattr(self, prop, value)
@@ -924,10 +919,7 @@ def union(self, query, all=False):
:return: The query
:rtype: QueryBuilder
"""
- self.unions.append({
- 'query': query,
- 'all': all
- })
+ self.unions.append({"query": query, "all": all})
return self.merge_bindings(query)
@@ -998,9 +990,9 @@ def find(self, id, columns=None):
:rtype: mixed
"""
if not columns:
- columns = ['*']
+ columns = ["*"]
- return self.where('id', '=', id).first(1, columns)
+ return self.where("id", "=", id).first(1, columns)
def pluck(self, column):
"""
@@ -1033,7 +1025,7 @@ def first(self, limit=1, columns=None):
:rtype: mixed
"""
if not columns:
- columns = ['*']
+ columns = ["*"]
return self.take(limit).get(columns).first()
@@ -1048,7 +1040,7 @@ def get(self, columns=None):
:rtype: Collection
"""
if not columns:
- columns = ['*']
+ columns = ["*"]
original = self.columns
@@ -1069,9 +1061,7 @@ def _run_select(self):
:rtype: list
"""
return self._connection.select(
- self.to_sql(),
- self.get_bindings(),
- not self._use_write_connection
+ self.to_sql(), self.get_bindings(), not self._use_write_connection
)
def paginate(self, per_page=15, current_page=None, columns=None):
@@ -1091,7 +1081,7 @@ def paginate(self, per_page=15, current_page=None, columns=None):
:rtype: LengthAwarePaginator
"""
if columns is None:
- columns = ['*']
+ columns = ["*"]
page = current_page or Paginator.resolve_current_page()
@@ -1118,7 +1108,7 @@ def simple_paginate(self, per_page=15, current_page=None, columns=None):
:rtype: Paginator
"""
if columns is None:
- columns = ['*']
+ columns = ["*"]
page = current_page or Paginator.resolve_current_page()
@@ -1136,14 +1126,20 @@ def get_count_for_pagination(self):
return total
def _backup_fields_for_count(self):
- for field in ['orders', 'limit', 'offset']:
- self._backups[field] = getattr(self, field)
+ for field, binding in [("orders", "order"), ("limit", None), ("offset", None)]:
+ self._backups[field] = {}
+ self._backups[field]["query"] = getattr(self, field)
+ if binding is not None:
+ self._backups[field]["binding"] = self.get_raw_bindings()[binding]
+ self.set_bindings([], binding)
setattr(self, field, None)
def _restore_fields_for_count(self):
- for field in ['orders', 'limit', 'offset']:
- setattr(self, field, self._backups[field])
+ for field, binding in [("orders", "order"), ("limit", None), ("offset", None)]:
+ setattr(self, field, self._backups[field]["query"])
+ if binding is not None and self._backups[field]["binding"] is not None:
+ self.add_binding(self._backups[field]["binding"], binding)
self._backups = {}
@@ -1157,15 +1153,10 @@ def chunk(self, count):
:return: The current chunk
:rtype: list
"""
- page = 1
- results = self.for_page(page, count).get()
-
- while not results.is_empty():
- yield results
-
- page += 1
-
- results = self.for_page(page, count).get()
+ for chunk in self._connection.select_many(
+ count, self.to_sql(), self.get_bindings(), not self._use_write_connection
+ ):
+ yield chunk
def lists(self, column, key=None):
"""
@@ -1211,16 +1202,16 @@ def _get_list_select(self, column, key=None):
select = []
for elem in elements:
- dot = elem.find('.')
+ dot = elem.find(".")
if dot >= 0:
- select.append(column[dot + 1:])
+ select.append(column[dot + 1 :])
else:
select.append(elem)
return select
- def implode(self, column, glue=''):
+ def implode(self, column, glue=""):
"""
Concatenate values of a given column as a string.
@@ -1260,10 +1251,13 @@ def count(self, *columns):
:return: The count
:rtype: int
"""
+ if not columns and self.distinct_:
+ columns = self.columns
+
if not columns:
- columns = ['*']
+ columns = ["*"]
- return int(self.aggregate('count', *columns))
+ return int(self.aggregate("count", *columns))
def min(self, column):
"""
@@ -1275,7 +1269,7 @@ def min(self, column):
:return: The min
:rtype: int
"""
- return self.aggregate('min', *[column])
+ return self.aggregate("min", *[column])
def max(self, column):
"""
@@ -1288,9 +1282,9 @@ def max(self, column):
:rtype: int
"""
if not column:
- columns = ['*']
+ columns = ["*"]
- return self.aggregate('max', *[column])
+ return self.aggregate("max", *[column])
def sum(self, column):
"""
@@ -1302,7 +1296,7 @@ def sum(self, column):
:return: The sum
:rtype: int
"""
- return self.aggregate('sum', *[column])
+ return self.aggregate("sum", *[column])
def avg(self, column):
"""
@@ -1315,7 +1309,7 @@ def avg(self, column):
:rtype: int
"""
- return self.aggregate('avg', *[column])
+ return self.aggregate("avg", *[column])
def aggregate(self, func, *columns):
"""
@@ -1331,12 +1325,9 @@ def aggregate(self, func, *columns):
:rtype: mixed
"""
if not columns:
- columns = ['*']
+ columns = ["*"]
- self.aggregate_ = {
- 'function': func,
- 'columns': columns
- }
+ self.aggregate_ = {"function": func, "columns": columns}
previous_columns = self.columns
@@ -1347,7 +1338,7 @@ def aggregate(self, func, *columns):
self.columns = previous_columns
if len(results) > 0:
- return dict((k.lower(), v) for k, v in results[0].items())['aggregate']
+ return dict((k.lower(), v) for k, v in results[0].items())["aggregate"]
def insert(self, _values=None, **values):
"""
@@ -1450,9 +1441,7 @@ def increment(self, column, amount=1, extras=None):
if extras is None:
extras = {}
- columns = {
- column: self.raw('%s + %s' % (wrapped, amount))
- }
+ columns = {column: self.raw("%s + %s" % (wrapped, amount))}
columns.update(extras)
return self.update(**columns)
@@ -1478,9 +1467,7 @@ def decrement(self, column, amount=1, extras=None):
if extras is None:
extras = {}
- columns = {
- column: self.raw('%s - %s' % (wrapped, amount))
- }
+ columns = {column: self.raw("%s - %s" % (wrapped, amount))}
columns.update(extras)
return self.update(**columns)
@@ -1496,7 +1483,7 @@ def delete(self, id=None):
:rtype: int
"""
if id is not None:
- self.where('id', '=', id)
+ self.where("id", "=", id)
sql = self._grammar.compile_delete(self)
@@ -1533,7 +1520,7 @@ def merge_wheres(self, wheres, bindings):
:rtype: None
"""
self.wheres = self.wheres + wheres
- self._bindings['where'] = self._bindings['where'] + bindings
+ self._bindings["where"] = self._bindings["where"] + bindings
def _clean_bindings(self, bindings):
"""
@@ -1560,25 +1547,32 @@ def raw(self, value):
return self._connection.raw(value)
def get_bindings(self):
- return list(chain(*self._bindings.values()))
+ bindings = []
+ for value in chain(*self._bindings.values()):
+ if isinstance(value, datetime.date):
+ value = value.strftime(self._grammar.get_date_format())
+
+ bindings.append(value)
+
+ return bindings
def get_raw_bindings(self):
return self._bindings
- def set_bindings(self, bindings, type='where'):
+ def set_bindings(self, bindings, type="where"):
if type not in self._bindings:
- raise ArgumentError('Invalid binding type: %s' % type)
+ raise ArgumentError("Invalid binding type: %s" % type)
self._bindings[type] = bindings
return self
- def add_binding(self, value, type='where'):
+ def add_binding(self, value, type="where"):
if value is None:
return self
if type not in self._bindings:
- raise ArgumentError('Invalid binding type: %s' % type)
+ raise ArgumentError("Invalid binding type: %s" % type)
if isinstance(value, (list, tuple)):
self._bindings[type] += value
@@ -1606,6 +1600,7 @@ def merge(self, query):
self.groups += query.groups
self.havings += query.havings
self.orders += query.orders
+ self.distinct_ = query.distinct_
if self.columns:
self.columns = Collection(self.columns).unique().all()
@@ -1661,7 +1656,7 @@ def use_write_connection(self):
return self
def __getattr__(self, item):
- if item.startswith('where_'):
+ if item.startswith("where_"):
return self.dynamic_where(item)
raise AttributeError(item)
@@ -1669,8 +1664,15 @@ def __getattr__(self, item):
def __copy__(self):
new = self.__class__(self._connection, self._grammar, self._processor)
- new.__dict__.update(dict((k, copy.deepcopy(v)) for k, v
- in self.__dict__.items()
- if k != '_connection'))
+ new.__dict__.update(
+ dict(
+ (k, copy.deepcopy(v))
+ for k, v in self.__dict__.items()
+ if k != "_connection"
+ )
+ )
return new
+
+ def __deepcopy__(self, memo):
+ return self.__copy__()
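
Beyond the quote and layout churn, the builder hunks above also change behaviour: get_bindings() now renders datetime.date/datetime.datetime bindings with the grammar's date format, count() falls back to the selected columns when the query is distinct, pagination backups now preserve order bindings, and deep copies reuse __copy__() so the connection is shared. A minimal standalone sketch of the binding coercion only; the "%Y-%m-%d %H:%M:%S" format is an assumed stand-in for self._grammar.get_date_format(), not necessarily what every grammar returns:

import datetime

# Assumed default date format; real grammars may return something else.
DATE_FORMAT = "%Y-%m-%d %H:%M:%S"

def coerce_bindings(raw_bindings):
    # Mirrors the new get_bindings() loop: date-like values are stringified,
    # everything else passes through untouched.
    bindings = []
    for value in raw_bindings:
        if isinstance(value, datetime.date):  # datetime.datetime is a subclass
            value = value.strftime(DATE_FORMAT)
        bindings.append(value)
    return bindings

print(coerce_bindings([42, datetime.date(2019, 7, 15)]))
# [42, '2019-07-15 00:00:00']
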
diff --git a/orator/query/expression.py b/orator/query/expression.py
index 911ba33f..711de337 100644
--- a/orator/query/expression.py
+++ b/orator/query/expression.py
@@ -2,7 +2,6 @@
class QueryExpression(object):
-
def __init__(self, value):
self._value = value
diff --git a/orator/query/grammars/grammar.py b/orator/query/grammars/grammar.py
index 58e26db9..ffd1d74d 100644
--- a/orator/query/grammars/grammar.py
+++ b/orator/query/grammars/grammar.py
@@ -9,23 +9,23 @@
class QueryGrammar(Grammar):
_select_components = [
- 'aggregate_',
- 'columns',
- 'from__',
- 'joins',
- 'wheres',
- 'groups',
- 'havings',
- 'orders',
- 'limit_',
- 'offset_',
- 'unions',
- 'lock_'
+ "aggregate_",
+ "columns",
+ "from__",
+ "joins",
+ "wheres",
+ "groups",
+ "havings",
+ "orders",
+ "limit_",
+ "offset_",
+ "unions",
+ "lock_",
]
def compile_select(self, query):
if not query.columns:
- query.columns = ['*']
+ query.columns = ["*"]
return self._concatenate(self._compile_components(query)).strip()
@@ -38,20 +38,19 @@ def _compile_components(self, query):
# function for the component which is responsible for making the SQL.
component_value = getattr(query, component)
if component_value is not None:
- method = '_compile_%s' % component.replace('_', '')
+ method = "_compile_%s" % component.replace("_", "")
sql[component] = getattr(self, method)(query, component_value)
return sql
def _compile_aggregate(self, query, aggregate):
- column = self.columnize(aggregate['columns'])
+ column = self.columnize(aggregate["columns"])
- if query.distinct_ and column != '*':
- column = 'DISTINCT %s' % column
+ if query.distinct_ and column != "*":
+ column = "DISTINCT %s" % column
- return 'SELECT %s(%s) AS aggregate' % (aggregate['function'].upper(),
- column)
+ return "SELECT %s(%s) AS aggregate" % (aggregate["function"].upper(), column)
def _compile_columns(self, query, columns):
# If the query is actually performing an aggregating select, we will let that
@@ -61,19 +60,19 @@ def _compile_columns(self, query, columns):
return
if query.distinct_:
- select = 'SELECT DISTINCT '
+ select = "SELECT DISTINCT "
else:
- select = 'SELECT '
+ select = "SELECT "
- return '%s%s' % (select, self.columnize(columns))
+ return "%s%s" % (select, self.columnize(columns))
def _compile_from(self, query, table):
- return 'FROM %s' % self.wrap_table(table)
+ return "FROM %s" % self.wrap_table(table)
def _compile_joins(self, query, joins):
sql = []
- query.set_bindings([], 'join')
+ query.set_bindings([], "join")
for join in joins:
table = self.wrap_table(join.table)
@@ -87,14 +86,14 @@ def _compile_joins(self, query, joins):
clauses.append(self._compile_join_constraints(clause))
for binding in join.bindings:
- query.add_binding(binding, 'join')
+ query.add_binding(binding, "join")
# Once we have constructed the clauses, we'll need to take the boolean connector
# off of the first clause as it obviously will not be required on that clause
# because it leads the rest of the clauses, thus not requiring any boolean.
clauses[0] = self._remove_leading_boolean(clauses[0])
- clauses = ' '.join(clauses)
+ clauses = " ".join(clauses)
type = join.type
@@ -102,212 +101,230 @@ def _compile_joins(self, query, joins):
# build the final join statement SQL for the query and we can then return the
# final clause back to the callers as a single, stringified join statement.
- sql.append('%s JOIN %s ON %s' % (type.upper(), table, clauses))
+ sql.append("%s JOIN %s ON %s" % (type.upper(), table, clauses))
- return ' '.join(sql)
+ return " ".join(sql)
def _compile_join_constraints(self, clause):
- first = self.wrap(clause['first'])
+ first = self.wrap(clause["first"])
- if clause['where']:
+ if clause["where"]:
second = self.get_marker()
else:
- second = self.wrap(clause['second'])
+ second = self.wrap(clause["second"])
- return '%s %s %s %s' % (clause['boolean'].upper(), first,
- clause['operator'], second)
+ return "%s %s %s %s" % (
+ clause["boolean"].upper(),
+ first,
+ clause["operator"],
+ second,
+ )
def _compile_wheres(self, query, _=None):
sql = []
if query.wheres is None:
- return ''
+ return ""
# Each type of where clauses has its own compiler function which is responsible
# for actually creating the where clauses SQL. This helps keep the code nice
# and maintainable since each clause has a very small method that it uses.
for where in query.wheres:
- method = '_where_%s' % where['type']
+ method = "_where_%s" % where["type"]
- sql.append('%s %s' % (where['boolean'].upper(),
- getattr(self, method)(query, where)))
+ sql.append(
+ "%s %s"
+ % (where["boolean"].upper(), getattr(self, method)(query, where))
+ )
# If we actually have some where clauses, we will strip off the first boolean
# operator, which is added by the query builders for convenience so we can
# avoid checking for the first clauses in each of the compilers methods.
if len(sql) > 0:
- sql = ' '.join(sql)
+ sql = " ".join(sql)
- return 'WHERE %s' % re.sub('AND |OR ', '', sql, 1, re.I)
+ return "WHERE %s" % re.sub("AND |OR ", "", sql, 1, re.I)
- return ''
+ return ""
def _where_nested(self, query, where):
- nested = where['query']
+ nested = where["query"]
- return '(%s)' % (self._compile_wheres(nested)[6:])
+ return "(%s)" % (self._compile_wheres(nested)[6:])
def _where_sub(self, query, where):
- select = self.compile_select(where['query'])
+ select = self.compile_select(where["query"])
- return '%s %s (%s)' % (self.wrap(where['column']),
- where['operator'], select)
+ return "%s %s (%s)" % (self.wrap(where["column"]), where["operator"], select)
def _where_basic(self, query, where):
- value = self.parameter(where['value'])
+ value = self.parameter(where["value"])
- return '%s %s %s' % (self.wrap(where['column']),
- where['operator'], value)
+ return "%s %s %s" % (self.wrap(where["column"]), where["operator"], value)
def _where_between(self, query, where):
- if where['not']:
- between = 'NOT BETWEEN'
+ if where["not"]:
+ between = "NOT BETWEEN"
else:
- between = 'BETWEEN'
+ between = "BETWEEN"
- return '%s %s %s AND %s' % (self.wrap(where['column']), between,
- self.get_marker(), self.get_marker())
+ return "%s %s %s AND %s" % (
+ self.wrap(where["column"]),
+ between,
+ self.get_marker(),
+ self.get_marker(),
+ )
def _where_exists(self, query, where):
- return 'EXISTS (%s)' % self.compile_select(where['query'])
+ return "EXISTS (%s)" % self.compile_select(where["query"])
def _where_not_exists(self, query, where):
- return 'NOT EXISTS (%s)' % self.compile_select(where['query'])
+ return "NOT EXISTS (%s)" % self.compile_select(where["query"])
def _where_in(self, query, where):
- if not where['values']:
- return '0 = 1'
+ if not where["values"]:
+ return "0 = 1"
- values = self.parameterize(where['values'])
+ values = self.parameterize(where["values"])
- return '%s IN (%s)' % (self.wrap(where['column']), values)
+ return "%s IN (%s)" % (self.wrap(where["column"]), values)
def _where_not_in(self, query, where):
- if not where['values']:
- return '1 = 1'
+ if not where["values"]:
+ return "1 = 1"
- values = self.parameterize(where['values'])
+ values = self.parameterize(where["values"])
- return '%s NOT IN (%s)' % (self.wrap(where['column']), values)
+ return "%s NOT IN (%s)" % (self.wrap(where["column"]), values)
def _where_in_sub(self, query, where):
- select = self.compile_select(where['query'])
+ select = self.compile_select(where["query"])
- return '%s IN (%s)' % (self.wrap(where['column']), select)
+ return "%s IN (%s)" % (self.wrap(where["column"]), select)
def _where_not_in_sub(self, query, where):
- select = self.compile_select(where['query'])
+ select = self.compile_select(where["query"])
- return '%s NOT IN (%s)' % (self.wrap(where['column']), select)
+ return "%s NOT IN (%s)" % (self.wrap(where["column"]), select)
def _where_null(self, query, where):
- return '%s IS NULL' % self.wrap(where['column'])
+ return "%s IS NULL" % self.wrap(where["column"])
def _where_not_null(self, query, where):
- return '%s IS NOT NULL' % self.wrap(where['column'])
+ return "%s IS NOT NULL" % self.wrap(where["column"])
def _where_date(self, query, where):
- return self._date_based_where('date', query, where)
+ return self._date_based_where("date", query, where)
def _where_day(self, query, where):
- return self._date_based_where('day', query, where)
+ return self._date_based_where("day", query, where)
def _where_month(self, query, where):
- return self._date_based_where('month', query, where)
+ return self._date_based_where("month", query, where)
def _where_year(self, query, where):
- return self._date_based_where('year', query, where)
+ return self._date_based_where("year", query, where)
def _date_based_where(self, type, query, where):
- value = self.parameter(where['value'])
+ value = self.parameter(where["value"])
- return '%s(%s) %s %s' % (type.upper(), self.wrap(where['column']),
- where['operator'], value)
+ return "%s(%s) %s %s" % (
+ type.upper(),
+ self.wrap(where["column"]),
+ where["operator"],
+ value,
+ )
def _where_raw(self, query, where):
- return re.sub('( and | or )',
- lambda m: m.group(1).upper(),
- where['sql'],
- re.I)
+ return re.sub("( and | or )", lambda m: m.group(1).upper(), where["sql"], re.I)
def _compile_groups(self, query, groups):
if not groups:
- return ''
+ return ""
- return 'GROUP BY %s' % self.columnize(groups)
+ return "GROUP BY %s" % self.columnize(groups)
def _compile_havings(self, query, havings):
if not havings:
- return ''
+ return ""
- sql = ' '.join(map(self._compile_having, havings))
+ sql = " ".join(map(self._compile_having, havings))
- return 'HAVING %s' % re.sub('and |or ', '', sql, 1, re.I)
+ return "HAVING %s" % re.sub("and |or ", "", sql, 1, re.I)
def _compile_having(self, having):
# If the having clause is "raw", we can just return the clause straight away
# without doing any more processing on it. Otherwise, we will compile the
# clause into SQL based on the components that make it up from builder.
- if having['type'] == 'raw':
- return '%s %s' % (having['boolean'].upper(), having['sql'])
+ if having["type"] == "raw":
+ return "%s %s" % (having["boolean"].upper(), having["sql"])
return self._compile_basic_having(having)
def _compile_basic_having(self, having):
- column = self.wrap(having['column'])
+ column = self.wrap(having["column"])
- parameter = self.parameter(having['value'])
+ parameter = self.parameter(having["value"])
- return '%s %s %s %s' % (having['boolean'].upper(), column,
- having['operator'], parameter)
+ return "%s %s %s %s" % (
+ having["boolean"].upper(),
+ column,
+ having["operator"],
+ parameter,
+ )
def _compile_orders(self, query, orders):
if not orders:
- return ''
+ return ""
compiled = []
for order in orders:
- if order.get('sql'):
- compiled.append(re.sub('( desc| asc)( |$)',
- lambda m: '%s%s' % (m.group(1).upper(), m.group(2)),
- order['sql'],
- re.I))
+ if order.get("sql"):
+ compiled.append(
+ re.sub(
+ "( desc| asc)( |$)",
+ lambda m: "%s%s" % (m.group(1).upper(), m.group(2)),
+ order["sql"],
+ re.I,
+ )
+ )
else:
- compiled.append('%s %s' % (self.wrap(order['column']),
- order['direction'].upper()))
+ compiled.append(
+ "%s %s" % (self.wrap(order["column"]), order["direction"].upper())
+ )
- return 'ORDER BY %s' % ', '.join(compiled)
+ return "ORDER BY %s" % ", ".join(compiled)
def _compile_limit(self, query, limit):
- return 'LIMIT %s' % int(limit)
+ return "LIMIT %s" % int(limit)
def _compile_offset(self, query, offset):
- return 'OFFSET %s' % int(offset)
+ return "OFFSET %s" % int(offset)
def _compile_unions(self, query, _=None):
- sql = ''
+ sql = ""
for union in query.unions:
sql += self._compile_union(union)
if query.union_orders:
- sql += ' %s' % self._compile_orders(query, query.union_orders)
+ sql += " %s" % self._compile_orders(query, query.union_orders)
if query.union_limit:
- sql += ' %s' % self._compile_limit(query, query.union_limit)
+ sql += " %s" % self._compile_limit(query, query.union_limit)
if query.union_offset:
- sql += ' %s' % self._compile_offset(query, query.union_offset)
+ sql += " %s" % self._compile_offset(query, query.union_offset)
return sql.lstrip()
def _compile_union(self, union):
- if union['all']:
- joiner = ' UNION ALL '
+ if union["all"]:
+ joiner = " UNION ALL "
else:
- joiner = ' UNION '
+ joiner = " UNION "
- return '%s%s' % (joiner, union['query'].to_sql())
+ return "%s%s" % (joiner, union["query"].to_sql())
def compile_insert(self, query, values):
"""
@@ -337,11 +354,11 @@ def compile_insert(self, query, values):
# bindings so we can just go off the first list of values in this array.
parameters = self.parameterize(values[0].values())
- value = ['(%s)' % parameters] * len(values)
+ value = ["(%s)" % parameters] * len(values)
- parameters = ', '.join(value)
+ parameters = ", ".join(value)
- return 'INSERT INTO %s (%s) VALUES %s' % (table, columns, parameters)
+ return "INSERT INTO %s (%s) VALUES %s" % (table, columns, parameters)
def compile_insert_get_id(self, query, values, sequence):
return self.compile_insert(query, values)
@@ -355,24 +372,24 @@ def compile_update(self, query, values):
columns = []
for key, value in values.items():
- columns.append('%s = %s' % (self.wrap(key), self.parameter(value)))
+ columns.append("%s = %s" % (self.wrap(key), self.parameter(value)))
- columns = ', '.join(columns)
+ columns = ", ".join(columns)
# If the query has any "join" clauses, we will setup the joins on the builder
# and compile them so we can attach them to this update, as update queries
# can get join statements to attach to other tables when they're needed.
if query.joins:
- joins = ' %s' % self._compile_joins(query, query.joins)
+ joins = " %s" % self._compile_joins(query, query.joins)
else:
- joins = ''
+ joins = ""
# Of course, update queries may also be constrained by where clauses so we'll
# need to compile the where clauses and attach it to the query so only the
# intended records are updated by the SQL statements we generate to run.
where = self._compile_wheres(query)
- return ('UPDATE %s%s SET %s %s' % (table, joins, columns, where)).strip()
+ return ("UPDATE %s%s SET %s %s" % (table, joins, columns, where)).strip()
def compile_delete(self, query):
table = self.wrap_table(query.from__)
@@ -380,20 +397,18 @@ def compile_delete(self, query):
if isinstance(query.wheres, list):
where = self._compile_wheres(query)
else:
- where = ''
+ where = ""
- return ('DELETE FROM %s %s' % (table, where)).strip()
+ return ("DELETE FROM %s %s" % (table, where)).strip()
def compile_truncate(self, query):
- return {
- 'TRUNCATE %s' % self.wrap_table(query.from__): []
- }
+ return {"TRUNCATE %s" % self.wrap_table(query.from__): []}
def _compile_lock(self, query, value):
if isinstance(value, basestring):
return value
else:
- return ''
+ return ""
def _concatenate(self, segments):
parts = []
@@ -403,9 +418,7 @@ def _concatenate(self, segments):
if value:
parts.append(value)
- return ' '.join(parts)
+ return " ".join(parts)
def _remove_leading_boolean(self, value):
- return re.sub('and | or ', '', value, 1, re.I)
-
-
+ return re.sub("and | or ", "", value, 1, re.I)
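
The where compilation reformatted above is driven by name-based dispatch: every entry in query.wheres carries a "type" key that selects a _where_<type> method, and the leading boolean connector is stripped once at the end. A toy, self-contained version of that pattern, where MiniGrammar and the sample dicts are illustrative rather than the real QueryGrammar:

import re

class MiniGrammar:
    def _where_basic(self, where):
        # Placeholder "%s" stands in for self.parameter(...) in the real grammar.
        return "%s %s %s" % (where["column"], where["operator"], "%s")

    def _where_null(self, where):
        return "%s IS NULL" % where["column"]

    def compile_wheres(self, wheres):
        sql = []
        for where in wheres:
            method = "_where_%s" % where["type"]
            sql.append(
                "%s %s" % (where["boolean"].upper(), getattr(self, method)(where))
            )
        if not sql:
            return ""
        # Strip the first AND/OR, exactly like the compiler above.
        return "WHERE %s" % re.sub("AND |OR ", "", " ".join(sql), 1, re.I)

wheres = [
    {"type": "basic", "column": "age", "operator": ">", "boolean": "and"},
    {"type": "null", "column": "deleted_at", "boolean": "or"},
]
print(MiniGrammar().compile_wheres(wheres))
# WHERE age > %s OR deleted_at IS NULL
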
diff --git a/orator/query/grammars/mysql_grammar.py b/orator/query/grammars/mysql_grammar.py
index 39139039..909c5993 100644
--- a/orator/query/grammars/mysql_grammar.py
+++ b/orator/query/grammars/mysql_grammar.py
@@ -7,20 +7,20 @@
class MySQLQueryGrammar(QueryGrammar):
_select_components = [
- 'aggregate_',
- 'columns',
- 'from__',
- 'joins',
- 'wheres',
- 'groups',
- 'havings',
- 'orders',
- 'limit_',
- 'offset_',
- 'lock_'
+ "aggregate_",
+ "columns",
+ "from__",
+ "joins",
+ "wheres",
+ "groups",
+ "havings",
+ "orders",
+ "limit_",
+ "offset_",
+ "lock_",
]
- marker = '%s'
+ marker = "%s"
def compile_select(self, query):
"""
@@ -35,7 +35,7 @@ def compile_select(self, query):
sql = super(MySQLQueryGrammar, self).compile_select(query)
if query.unions:
- sql = '(%s) %s' % (sql, self._compile_unions(query))
+ sql = "(%s) %s" % (sql, self._compile_unions(query))
return sql
@@ -49,12 +49,12 @@ def _compile_union(self, union):
:return: The compiled union statement
:rtype: str
"""
- if union['all']:
- joiner = ' UNION ALL '
+ if union["all"]:
+ joiner = " UNION ALL "
else:
- joiner = ' UNION '
+ joiner = " UNION "
- return '%s(%s)' % (joiner, union['query'].to_sql())
+ return "%s(%s)" % (joiner, union["query"].to_sql())
def _compile_lock(self, query, value):
"""
@@ -73,9 +73,9 @@ def _compile_lock(self, query, value):
return value
if value is True:
- return 'FOR UPDATE'
+ return "FOR UPDATE"
elif value is False:
- return 'LOCK IN SHARE MODE'
+ return "LOCK IN SHARE MODE"
def compile_update(self, query, values):
"""
@@ -93,10 +93,10 @@ def compile_update(self, query, values):
sql = super(MySQLQueryGrammar, self).compile_update(query, values)
if query.orders:
- sql += ' %s' % self._compile_orders(query, query.orders)
+ sql += " %s" % self._compile_orders(query, query.orders)
if query.limit_:
- sql += ' %s' % self._compile_limit(query, query.limit_)
+ sql += " %s" % self._compile_limit(query, query.limit_)
return sql.rstrip()
@@ -115,22 +115,22 @@ def compile_delete(self, query):
if isinstance(query.wheres, list):
wheres = self._compile_wheres(query)
else:
- wheres = ''
+ wheres = ""
if query.joins:
- joins = ' %s' % self._compile_joins(query, query.joins)
+ joins = " %s" % self._compile_joins(query, query.joins)
- sql = 'DELETE %s FROM %s%s %s' % (table, table, joins, wheres)
+ sql = "DELETE %s FROM %s%s %s" % (table, table, joins, wheres)
else:
- sql = 'DELETE FROM %s %s' % (table, wheres)
+ sql = "DELETE FROM %s %s" % (table, wheres)
sql = sql.strip()
if query.orders:
- sql += ' %s' % self._compile_orders(query, query.orders)
+ sql += " %s" % self._compile_orders(query, query.orders)
if query.limit_:
- sql += ' %s' % self._compile_limit(query, query.limit_)
+ sql += " %s" % self._compile_limit(query, query.limit_)
return sql
@@ -144,7 +144,7 @@ def _wrap_value(self, value):
:return: The wrapped value
:rtype: str
"""
- if value == '*':
+ if value == "*":
return value
- return '`%s`' % value.replace('`', '``')
+ return "`%s`" % value.replace("`", "``")
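
A quick standalone illustration of the backtick quoting in _wrap_value() above, written as a free function rather than the grammar method:

def wrap_value(value):
    # "*" passes through unquoted; embedded backticks are doubled.
    if value == "*":
        return value
    return "`%s`" % value.replace("`", "``")

print(wrap_value("users"))       # `users`
print(wrap_value("weird`name"))  # `weird``name`
print(wrap_value("*"))           # *
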
diff --git a/orator/query/grammars/postgres_grammar.py b/orator/query/grammars/postgres_grammar.py
index cd33f9d4..90dc8d2a 100644
--- a/orator/query/grammars/postgres_grammar.py
+++ b/orator/query/grammars/postgres_grammar.py
@@ -7,12 +7,25 @@
class PostgresQueryGrammar(QueryGrammar):
_operators = [
- '=', '<', '>', '<=', '>=', '<>', '!=',
- 'like', 'not like', 'between', 'ilike',
- '&', '|', '#', '<<', '>>'
+ "=",
+ "<",
+ ">",
+ "<=",
+ ">=",
+ "<>",
+ "!=",
+ "like",
+ "not like",
+ "between",
+ "ilike",
+ "&",
+ "|",
+ "#",
+ "<<",
+ ">>",
]
- marker = '%s'
+ marker = "%s"
def _compile_lock(self, query, value):
"""
@@ -31,9 +44,9 @@ def _compile_lock(self, query, value):
return value
if value:
- return 'FOR UPDATE'
+ return "FOR UPDATE"
- return 'FOR SHARE'
+ return "FOR SHARE"
def compile_update(self, query, values):
"""
@@ -56,7 +69,7 @@ def compile_update(self, query, values):
where = self._compile_update_wheres(query)
- return ('UPDATE %s SET %s%s %s' % (table, columns, from_, where)).strip()
+ return ("UPDATE %s SET %s%s %s" % (table, columns, from_, where)).strip()
def _compile_update_columns(self, values):
"""
@@ -71,9 +84,9 @@ def _compile_update_columns(self, values):
columns = []
for key, value in values.items():
- columns.append('%s = %s' % (self.wrap(key), self.parameter(value)))
+ columns.append("%s = %s" % (self.wrap(key), self.parameter(value)))
- return ', '.join(columns)
+ return ", ".join(columns)
def _compile_update_from(self, query):
"""
@@ -86,7 +99,7 @@ def _compile_update_from(self, query):
:rtype: str
"""
if not query.joins:
- return ''
+ return ""
froms = []
@@ -94,9 +107,9 @@ def _compile_update_from(self, query):
froms.append(self.wrap_table(join.table))
if len(froms):
- return ' FROM %s' % ', '.join(froms)
+ return " FROM %s" % ", ".join(froms)
- return ''
+ return ""
def _compile_update_wheres(self, query):
"""
@@ -116,9 +129,9 @@ def _compile_update_wheres(self, query):
join_where = self._compile_update_join_wheres(query)
if not base_where.strip():
- return 'WHERE %s' % self._remove_leading_boolean(join_where)
+ return "WHERE %s" % self._remove_leading_boolean(join_where)
- return '%s %s' % (base_where, join_where)
+ return "%s %s" % (base_where, join_where)
def _compile_update_join_wheres(self, query):
"""
@@ -136,7 +149,7 @@ def _compile_update_join_wheres(self, query):
for clause in join.clauses:
join_wheres.append(self._compile_join_constraints(clause))
- return ' '.join(join_wheres)
+ return " ".join(join_wheres)
def compile_insert_get_id(self, query, values, sequence=None):
"""
@@ -155,10 +168,12 @@ def compile_insert_get_id(self, query, values, sequence=None):
:rtype: str
"""
if sequence is None:
- sequence = 'id'
+ sequence = "id"
- return '%s RETURNING %s'\
- % (self.compile_insert(query, values), self.wrap(sequence))
+ return "%s RETURNING %s" % (
+ self.compile_insert(query, values),
+ self.wrap(sequence),
+ )
def compile_truncate(self, query):
"""
@@ -170,6 +185,4 @@ def compile_truncate(self, query):
:return: The compiled statement
:rtype: str
"""
- return {
- 'TRUNCATE %s RESTART IDENTITY' % self.wrap_table(query.from__): {}
- }
+ return {"TRUNCATE %s RESTART IDENTITY" % self.wrap_table(query.from__): {}}
diff --git a/orator/query/grammars/sqlite_grammar.py b/orator/query/grammars/sqlite_grammar.py
index fb9399be..f2bd623d 100644
--- a/orator/query/grammars/sqlite_grammar.py
+++ b/orator/query/grammars/sqlite_grammar.py
@@ -6,9 +6,21 @@
class SQLiteQueryGrammar(QueryGrammar):
_operators = [
- '=', '<', '>', '<=', '>=', '<>', '!=',
- 'like', 'not like', 'between', 'ilike',
- '&', '|', '<<', '>>',
+ "=",
+ "<",
+ ">",
+ "<=",
+ ">=",
+ "<>",
+ "!=",
+ "like",
+ "not like",
+ "between",
+ "ilike",
+ "&",
+ "|",
+ "<<",
+ ">>",
]
def compile_insert(self, query, values):
@@ -41,12 +53,15 @@ def compile_insert(self, query, values):
# unions joining them together. So we'll build out this list of columns and
# then join them all together with select unions to complete the queries.
for column in values[0].keys():
- columns.append('%s AS %s' % (self.get_marker(), self.wrap(column)))
+ columns.append("%s AS %s" % (self.get_marker(), self.wrap(column)))
- columns = [', '.join(columns)] * len(values)
+ columns = [", ".join(columns)] * len(values)
- return 'INSERT INTO %s (%s) SELECT %s'\
- % (table, names, ' UNION ALL SELECT '.join(columns))
+ return "INSERT INTO %s (%s) SELECT %s" % (
+ table,
+ names,
+ " UNION ALL SELECT ".join(columns),
+ )
def compile_truncate(self, query):
"""
@@ -59,13 +74,29 @@ def compile_truncate(self, query):
:rtype: str
"""
sql = {
- 'DELETE FROM sqlite_sequence WHERE name = %s' % self.get_marker(): [query.from__]
+ "DELETE FROM sqlite_sequence WHERE name = %s"
+ % self.get_marker(): [query.from__]
}
- sql['DELETE FROM %s' % self.wrap_table(query.from__)] = []
+ sql["DELETE FROM %s" % self.wrap_table(query.from__)] = []
return sql
+ def _where_date(self, query, where):
+ """
+ Compile a "where date" clause
+
+ :param query: A QueryBuilder instance
+ :type query: QueryBuilder
+
+ :param where: The condition
+ :type where: dict
+
+ :return: The compiled clause
+ :rtype: str
+ """
+ return self._date_based_where("%Y-%m-%d", query, where)
+
def _where_day(self, query, where):
"""
Compile a "where day" clause
@@ -79,7 +110,7 @@ def _where_day(self, query, where):
:return: The compiled clause
:rtype: str
"""
- return self._date_based_where('%d', query, where)
+ return self._date_based_where("%d", query, where)
def _where_month(self, query, where):
"""
@@ -94,7 +125,7 @@ def _where_month(self, query, where):
:return: The compiled clause
:rtype: str
"""
- return self._date_based_where('%m', query, where)
+ return self._date_based_where("%m", query, where)
def _where_year(self, query, where):
"""
@@ -109,7 +140,7 @@ def _where_year(self, query, where):
:return: The compiled clause
:rtype: str
"""
- return self._date_based_where('%Y', query, where)
+ return self._date_based_where("%Y", query, where)
def _date_based_where(self, type, query, where):
"""
@@ -127,9 +158,12 @@ def _date_based_where(self, type, query, where):
:return: The compiled clause
:rtype: str
"""
- value = str(where['value']).zfill(2)
+ value = str(where["value"]).zfill(2)
value = self.parameter(value)
- return 'strftime(\'%s\', %s) %s %s'\
- % (type, self.wrap(where['column']),
- where['operator'], value)
+ return "strftime('%s', %s) %s %s" % (
+ type,
+ self.wrap(where["column"]),
+ where["operator"],
+ value,
+ )
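
The newly added _where_date() above reuses the same strftime-based compiler as the day/month/year clauses, just with a "%Y-%m-%d" format. A standalone sketch of what that compiles to; the "?" placeholder and the double-quoted identifier are assumptions standing in for the grammar's marker and wrap():

def date_based_where(fmt, column, operator, marker="?"):
    # strftime('<fmt>', "<column>") <op> <placeholder>
    return "strftime('%s', %s) %s %s" % (fmt, '"%s"' % column, operator, marker)

print(date_based_where("%Y-%m-%d", "created_at", "="))
# strftime('%Y-%m-%d', "created_at") = ?
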
diff --git a/orator/query/join_clause.py b/orator/query/join_clause.py
index adccc305..f3c2be0a 100644
--- a/orator/query/join_clause.py
+++ b/orator/query/join_clause.py
@@ -4,22 +4,23 @@
class JoinClause(object):
-
- def __init__(self, table, type='inner'):
+ def __init__(self, table, type="inner"):
self.type = type
self.table = table
self.clauses = []
self.bindings = []
- def on(self, first, operator, second, boolean='and', where=False):
- self.clauses.append({
- 'first': first,
- 'operator': operator,
- 'second': second,
- 'boolean': boolean,
- 'where': where
- })
+ def on(self, first, operator, second, boolean="and", where=False):
+ self.clauses.append(
+ {
+ "first": first,
+ "operator": operator,
+ "second": second,
+ "boolean": boolean,
+ "where": where,
+ }
+ )
if where:
self.bindings.append(second)
@@ -27,22 +28,22 @@ def on(self, first, operator, second, boolean='and', where=False):
return self
def or_on(self, first, operator, second):
- return self.on(first, operator, second, 'or')
+ return self.on(first, operator, second, "or")
- def where(self, first, operator, second, boolean='and'):
+ def where(self, first, operator, second, boolean="and"):
return self.on(first, operator, second, boolean, True)
def or_where(self, first, operator, second):
- return self.where(first, operator, second, 'or')
+ return self.where(first, operator, second, "or")
- def where_null(self, column, boolean='and'):
- return self.on(column, 'IS', QueryExpression('NULL'), boolean, False)
+ def where_null(self, column, boolean="and"):
+ return self.on(column, "IS", QueryExpression("NULL"), boolean, False)
def or_where_null(self, column):
- return self.where_null(column, 'or')
+ return self.where_null(column, "or")
- def where_not_null(self, column, boolean='and'):
- return self.on(column, 'IS', QueryExpression('NOT NULL'), boolean, False)
+ def where_not_null(self, column, boolean="and"):
+ return self.on(column, "IS", QueryExpression("NOT NULL"), boolean, False)
def or_where_not_null(self, column):
- return self.where_not_null(column, 'or')
+ return self.where_not_null(column, "or")
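
A hedged usage sketch of JoinClause as reformatted above; the import path is taken from the diff header and should be checked against the installed version:

from orator.query.join_clause import JoinClause

join = JoinClause("contacts", "left")
join.on("users.id", "=", "contacts.user_id")
join.where("contacts.type", "=", "primary")  # where-style clause also records a binding

print(len(join.clauses))  # 2 clause dicts, the second with "where": True
print(join.bindings)      # ['primary']
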
diff --git a/orator/query/processors/mysql_processor.py b/orator/query/processors/mysql_processor.py
index 289df386..3a7d5596 100644
--- a/orator/query/processors/mysql_processor.py
+++ b/orator/query/processors/mysql_processor.py
@@ -4,7 +4,6 @@
class MySQLQueryProcessor(QueryProcessor):
-
def process_insert_get_id(self, query, sql, values, sequence=None):
"""
Process an "insert get ID" query.
@@ -29,18 +28,18 @@ def process_insert_get_id(self, query, sql, values, sequence=None):
query.get_connection().insert(sql, values)
cursor = query.get_connection().get_cursor()
- if hasattr(cursor, 'lastrowid'):
+ if hasattr(cursor, "lastrowid"):
id = cursor.lastrowid
else:
- id = query.get_connection().statement('SELECT LAST_INSERT_ID()')
+ id = query.get_connection().statement("SELECT LAST_INSERT_ID()")
else:
query.get_connection().insert(sql, values)
cursor = query.get_connection().get_cursor()
- if hasattr(cursor, 'lastrowid'):
+ if hasattr(cursor, "lastrowid"):
id = cursor.lastrowid
else:
- id = query.get_connection().statement('SELECT LAST_INSERT_ID()')
+ id = query.get_connection().statement("SELECT LAST_INSERT_ID()")
if isinstance(id, int):
return id
@@ -60,4 +59,4 @@ def process_column_listing(self, results):
:return: The processed results
:return: list
"""
- return list(map(lambda x: x['column_name'], results))
+ return list(map(lambda x: x["column_name"], results))
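
The MySQL processor above prefers the driver cursor's lastrowid and only falls back to a SELECT LAST_INSERT_ID() round trip when the attribute is missing. A standalone sketch of that fallback, with a fake cursor in place of a real connection:

def last_insert_id(cursor, run_statement):
    # Mirrors process_insert_get_id(): use lastrowid when the driver exposes it,
    # otherwise ask the server directly.
    if hasattr(cursor, "lastrowid"):
        return cursor.lastrowid
    return run_statement("SELECT LAST_INSERT_ID()")

class FakeCursor(object):
    lastrowid = 42

print(last_insert_id(FakeCursor(), lambda sql: None))  # 42
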
diff --git a/orator/query/processors/postgres_processor.py b/orator/query/processors/postgres_processor.py
index 2b60d340..a5b88bb1 100644
--- a/orator/query/processors/postgres_processor.py
+++ b/orator/query/processors/postgres_processor.py
@@ -4,7 +4,6 @@
class PostgresQueryProcessor(QueryProcessor):
-
def process_insert_get_id(self, query, sql, values, sequence=None):
"""
Process an "insert get ID" query.
@@ -46,4 +45,4 @@ def process_column_listing(self, results):
:return: The processed results
:return: list
"""
- return list(map(lambda x: x['column_name'], results))
+ return list(map(lambda x: x["column_name"], results))
diff --git a/orator/query/processors/processor.py b/orator/query/processors/processor.py
index 08812dd8..61909456 100644
--- a/orator/query/processors/processor.py
+++ b/orator/query/processors/processor.py
@@ -2,7 +2,6 @@
class QueryProcessor(object):
-
def process_select(self, query, results):
"""
Process the results of a "select" query
diff --git a/orator/query/processors/sqlite_processor.py b/orator/query/processors/sqlite_processor.py
index 77cbf2a8..55e6b945 100644
--- a/orator/query/processors/sqlite_processor.py
+++ b/orator/query/processors/sqlite_processor.py
@@ -4,7 +4,6 @@
class SQLiteQueryProcessor(QueryProcessor):
-
def process_column_listing(self, results):
"""
Process the results of a column listing query
@@ -15,4 +14,4 @@ def process_column_listing(self, results):
:return: The processed results
:return: list
"""
- return list(map(lambda x: x['name'], results))
+ return list(map(lambda x: x["name"], results))
diff --git a/orator/schema/blueprint.py b/orator/schema/blueprint.py
index ea69a704..d32aadd6 100644
--- a/orator/schema/blueprint.py
+++ b/orator/schema/blueprint.py
@@ -4,7 +4,6 @@
class Blueprint(object):
-
def __init__(self, table):
"""
:param table: The table to operate on
@@ -47,7 +46,7 @@ def to_sql(self, connection, grammar):
statements = []
for command in self._commands:
- method = 'compile_%s' % command.name
+ method = "compile_%s" % command.name
if hasattr(grammar, method):
sql = getattr(grammar, method)(self, command, connection)
@@ -64,10 +63,10 @@ def _add_implied_commands(self):
Add the commands that are implied by the blueprint.
"""
if len(self.get_added_columns()) and not self._creating():
- self._commands.insert(0, self._create_command('add'))
+ self._commands.insert(0, self._create_command("add"))
if len(self.get_changed_columns()) and not self._creating():
- self._commands.insert(0, self._create_command('change'))
+ self._commands.insert(0, self._create_command("change"))
return self._add_fluent_indexes()
@@ -76,7 +75,7 @@ def _add_fluent_indexes(self):
Add the index commands fluently specified on columns:
"""
for column in self._columns:
- for index in ['primary', 'unique', 'index']:
+ for index in ["primary", "unique", "index"]:
column_index = column.get(index)
if column_index is True:
@@ -95,7 +94,7 @@ def _creating(self):
:rtype: bool
"""
for command in self._commands:
- if command.name == 'create':
+ if command.name == "create":
return True
return False
@@ -106,7 +105,7 @@ def create(self):
:rtype: Fluent
"""
- return self._add_command('create')
+ return self._add_command("create")
def drop(self):
"""
@@ -114,7 +113,7 @@ def drop(self):
:rtype: Fluent
"""
- self._add_command('drop')
+ self._add_command("drop")
return self
@@ -124,7 +123,7 @@ def drop_if_exists(self):
:rtype: Fluent
"""
- return self._add_command('drop_if_exists')
+ return self._add_command("drop_if_exists")
def drop_column(self, *columns):
"""
@@ -137,7 +136,7 @@ def drop_column(self, *columns):
"""
columns = list(columns)
- return self._add_command('drop_column', columns=columns)
+ return self._add_command("drop_column", columns=columns)
def rename_column(self, from_, to):
"""
@@ -150,7 +149,7 @@ def rename_column(self, from_, to):
:rtype: Fluent
"""
- return self._add_command('rename_column', **{'from_': from_, 'to': to})
+ return self._add_command("rename_column", **{"from_": from_, "to": to})
def drop_primary(self, index=None):
"""
@@ -161,7 +160,7 @@ def drop_primary(self, index=None):
:rtype: dict
"""
- return self._drop_index_command('drop_primary', 'primary', index)
+ return self._drop_index_command("drop_primary", "primary", index)
def drop_unique(self, index):
"""
@@ -172,7 +171,7 @@ def drop_unique(self, index):
:rtype: Fluent
"""
- return self._drop_index_command('drop_unique', 'unique', index)
+ return self._drop_index_command("drop_unique", "unique", index)
def drop_index(self, index):
"""
@@ -183,7 +182,7 @@ def drop_index(self, index):
:rtype: Fluent
"""
- return self._drop_index_command('drop_index', 'index', index)
+ return self._drop_index_command("drop_index", "index", index)
def drop_foreign(self, index):
"""
@@ -194,7 +193,7 @@ def drop_foreign(self, index):
:rtype: dict
"""
- return self._drop_index_command('drop_foreign', 'foreign', index)
+ return self._drop_index_command("drop_foreign", "foreign", index)
def drop_timestamps(self):
"""
@@ -202,7 +201,7 @@ def drop_timestamps(self):
:rtype: Fluent
"""
- return self.drop_column('created_at', 'updated_at')
+ return self.drop_column("created_at", "updated_at")
def drop_soft_deletes(self):
"""
@@ -210,7 +209,7 @@ def drop_soft_deletes(self):
:rtype: Fluent
"""
- return self.drop_column('deleted_at')
+ return self.drop_column("deleted_at")
def rename(self, to):
"""
@@ -221,7 +220,7 @@ def rename(self, to):
:rtype: Fluent
"""
- return self._add_command('rename', to=to)
+ return self._add_command("rename", to=to)
def primary(self, columns, name=None):
"""
@@ -235,7 +234,7 @@ def primary(self, columns, name=None):
:rtype: Fluent
"""
- return self._index_command('primary', columns, name)
+ return self._index_command("primary", columns, name)
def unique(self, columns, name=None):
"""
@@ -249,7 +248,7 @@ def unique(self, columns, name=None):
:rtype: Fluent
"""
- return self._index_command('unique', columns, name)
+ return self._index_command("unique", columns, name)
def index(self, columns, name=None):
"""
@@ -263,7 +262,7 @@ def index(self, columns, name=None):
:rtype: Fluent
"""
- return self._index_command('index', columns, name)
+ return self._index_command("index", columns, name)
def foreign(self, columns, name=None):
"""
@@ -277,7 +276,7 @@ def foreign(self, columns, name=None):
:rtype: Fluent
"""
- return self._index_command('foreign', columns, name)
+ return self._index_command("foreign", columns, name)
def increments(self, column):
"""
@@ -310,7 +309,7 @@ def char(self, column, length=255):
:rtype: Fluent
"""
- return self._add_column('char', column, length=length)
+ return self._add_column("char", column, length=length)
def string(self, column, length=255):
"""
@@ -321,7 +320,7 @@ def string(self, column, length=255):
:rtype: Fluent
"""
- return self._add_column('string', column, length=length)
+ return self._add_column("string", column, length=length)
def text(self, column):
"""
@@ -332,7 +331,7 @@ def text(self, column):
:rtype: Fluent
"""
- return self._add_column('text', column)
+ return self._add_column("text", column)
def medium_text(self, column):
"""
@@ -343,7 +342,7 @@ def medium_text(self, column):
:rtype: Fluent
"""
- return self._add_column('medium_text', column)
+ return self._add_column("medium_text", column)
def long_text(self, column):
"""
@@ -354,7 +353,7 @@ def long_text(self, column):
:rtype: Fluent
"""
- return self._add_column('long_text', column)
+ return self._add_column("long_text", column)
def integer(self, column, auto_increment=False, unsigned=False):
"""
@@ -369,9 +368,9 @@ def integer(self, column, auto_increment=False, unsigned=False):
:rtype: Fluent
"""
- return self._add_column('integer', column,
- auto_increment=auto_increment,
- unsigned=unsigned)
+ return self._add_column(
+ "integer", column, auto_increment=auto_increment, unsigned=unsigned
+ )
def big_integer(self, column, auto_increment=False, unsigned=False):
"""
@@ -386,9 +385,9 @@ def big_integer(self, column, auto_increment=False, unsigned=False):
:rtype: Fluent
"""
- return self._add_column('big_integer', column,
- auto_increment=auto_increment,
- unsigned=unsigned)
+ return self._add_column(
+ "big_integer", column, auto_increment=auto_increment, unsigned=unsigned
+ )
def medium_integer(self, column, auto_increment=False, unsigned=False):
"""
@@ -403,9 +402,9 @@ def medium_integer(self, column, auto_increment=False, unsigned=False):
:rtype: Fluent
"""
- return self._add_column('medium_integer', column,
- auto_increment=auto_increment,
- unsigned=unsigned)
+ return self._add_column(
+ "medium_integer", column, auto_increment=auto_increment, unsigned=unsigned
+ )
def tiny_integer(self, column, auto_increment=False, unsigned=False):
"""
@@ -420,9 +419,9 @@ def tiny_integer(self, column, auto_increment=False, unsigned=False):
:rtype: Fluent
"""
- return self._add_column('tiny_integer', column,
- auto_increment=auto_increment,
- unsigned=unsigned)
+ return self._add_column(
+ "tiny_integer", column, auto_increment=auto_increment, unsigned=unsigned
+ )
def small_integer(self, column, auto_increment=False, unsigned=False):
"""
@@ -437,13 +436,13 @@ def small_integer(self, column, auto_increment=False, unsigned=False):
:rtype: Fluent
"""
- return self._add_column('small_integer', column,
- auto_increment=auto_increment,
- unsigned=unsigned)
+ return self._add_column(
+ "small_integer", column, auto_increment=auto_increment, unsigned=unsigned
+ )
def unsigned_integer(self, column, auto_increment=False):
"""
- Create a new unisgned integer column on the table.
+ Create a new unsigned integer column on the table.
:param column: The column
:type column: str
@@ -480,7 +479,7 @@ def float(self, column, total=8, places=2):
:rtype: Fluent
"""
- return self._add_column('float', column, total=total, places=places)
+ return self._add_column("float", column, total=total, places=places)
def double(self, column, total=None, places=None):
"""
@@ -495,7 +494,7 @@ def double(self, column, total=None, places=None):
:rtype: Fluent
"""
- return self._add_column('double', column, total=total, places=places)
+ return self._add_column("double", column, total=total, places=places)
def decimal(self, column, total=8, places=2):
"""
@@ -510,7 +509,7 @@ def decimal(self, column, total=8, places=2):
:rtype: Fluent
"""
- return self._add_column('decimal', column, total=total, places=places)
+ return self._add_column("decimal", column, total=total, places=places)
def boolean(self, column):
"""
@@ -521,8 +520,8 @@ def boolean(self, column):
:rtype: Fluent
"""
- return self._add_column('boolean', column)
-
+ return self._add_column("boolean", column)
+
def enum(self, column, allowed):
"""
Create a new enum column on the table.
@@ -534,7 +533,7 @@ def enum(self, column, allowed):
:rtype: Fluent
"""
- return self._add_column('enum', column, allowed=allowed)
+ return self._add_column("enum", column, allowed=allowed)
def json(self, column):
"""
@@ -545,7 +544,7 @@ def json(self, column):
:rtype: Fluent
"""
- return self._add_column('json', column)
+ return self._add_column("json", column)
def date(self, column):
"""
@@ -556,7 +555,7 @@ def date(self, column):
:rtype: Fluent
"""
- return self._add_column('date', column)
+ return self._add_column("date", column)
def datetime(self, column):
"""
@@ -567,7 +566,7 @@ def datetime(self, column):
:rtype: Fluent
"""
- return self._add_column('datetime', column)
+ return self._add_column("datetime", column)
def time(self, column):
"""
@@ -578,7 +577,7 @@ def time(self, column):
:rtype: Fluent
"""
- return self._add_column('time', column)
+ return self._add_column("time", column)
def timestamp(self, column):
"""
@@ -589,7 +588,7 @@ def timestamp(self, column):
:rtype: Fluent
"""
- return self._add_column('timestamp', column)
+ return self._add_column("timestamp", column)
def nullable_timestamps(self):
"""
@@ -597,8 +596,8 @@ def nullable_timestamps(self):
:rtype: Fluent
"""
- self.timestamp('created_at').nullable()
- self.timestamp('updated_at').nullable()
+ self.timestamp("created_at").nullable()
+ self.timestamp("updated_at").nullable()
def timestamps(self, use_current=True):
"""
@@ -607,11 +606,11 @@ def timestamps(self, use_current=True):
:rtype: Fluent
"""
if use_current:
- self.timestamp('created_at').use_current()
- self.timestamp('updated_at').use_current()
+ self.timestamp("created_at").use_current()
+ self.timestamp("updated_at").use_current()
else:
- self.timestamp('created_at')
- self.timestamp('updated_at')
+ self.timestamp("created_at")
+ self.timestamp("updated_at")
def soft_deletes(self):
"""
@@ -619,7 +618,7 @@ def soft_deletes(self):
:rtype: Fluent
"""
- return self.timestamp('deleted_at').nullable()
+ return self.timestamp("deleted_at").nullable()
def binary(self, column):
"""
@@ -630,7 +629,7 @@ def binary(self, column):
:rtype: Fluent
"""
- return self._add_column('binary', column)
+ return self._add_column("binary", column)
def morphs(self, name, index_name=None):
"""
@@ -640,9 +639,9 @@ def morphs(self, name, index_name=None):
:type index_name: str
"""
- self.unsigned_integer('%s_id' % name)
- self.string('%s_type' % name)
- self.index(['%s_id' % name, '%s_type' % name], index_name)
+ self.unsigned_integer("%s_id" % name)
+ self.string("%s_type" % name)
+ self.index(["%s_id" % name, "%s_type" % name], index_name)
def _drop_index_command(self, command, type, index):
"""
@@ -695,9 +694,13 @@ def _create_index_name(self, type, columns):
if not isinstance(columns, list):
columns = [columns]
- index = '%s_%s_%s' % (self._table, '_'.join([str(column) for column in columns]), type)
+ index = "%s_%s_%s" % (
+ self._table,
+ "_".join([str(column) for column in columns]),
+ type,
+ )
- return index.lower().replace('-', '_').replace('.', '_')
+ return index.lower().replace("-", "_").replace(".", "_")
def _add_column(self, type, name, **parameters):
"""
@@ -714,10 +717,7 @@ def _add_column(self, type, name, **parameters):
:rtype: Fluent
"""
- parameters.update({
- 'type': type,
- 'name': name
- })
+ parameters.update({"type": type, "name": name})
column = Fluent(**parameters)
self._columns.append(column)
@@ -766,7 +766,7 @@ def _create_command(self, name, **parameters):
:rtype: Fluent
"""
- parameters.update({'name': name})
+ parameters.update({"name": name})
return Fluent(**parameters)
@@ -780,7 +780,7 @@ def get_commands(self):
return self._commands
def get_added_columns(self):
- return list(filter(lambda column: not column.get('change'), self._columns))
+ return list(filter(lambda column: not column.get("change"), self._columns))
def get_changed_columns(self):
- return list(filter(lambda column: column.get('change'), self._columns))
+ return list(filter(lambda column: column.get("change"), self._columns))
diff --git a/orator/schema/builder.py b/orator/schema/builder.py
index 4935b3ef..9e4fb466 100644
--- a/orator/schema/builder.py
+++ b/orator/schema/builder.py
@@ -5,7 +5,6 @@
class SchemaBuilder(object):
-
def __init__(self, connection):
"""
:param connection: The schema connection
diff --git a/orator/schema/grammars/grammar.py b/orator/schema/grammars/grammar.py
index d9d13ca5..71c759fd 100644
--- a/orator/schema/grammars/grammar.py
+++ b/orator/schema/grammars/grammar.py
@@ -10,7 +10,6 @@
class SchemaGrammar(Grammar):
-
def __init__(self, connection):
super(SchemaGrammar, self).__init__(marker=connection.get_marker())
@@ -93,19 +92,21 @@ def compile_foreign(self, blueprint, command, _):
columns = self.columnize(command.columns)
- on_columns = self.columnize(command.references
- if isinstance(command.references, list)
- else [command.references])
+ on_columns = self.columnize(
+ command.references
+ if isinstance(command.references, list)
+ else [command.references]
+ )
- sql = 'ALTER TABLE %s ADD CONSTRAINT %s ' % (table, command.index)
+ sql = "ALTER TABLE %s ADD CONSTRAINT %s " % (table, command.index)
- sql += 'FOREIGN KEY (%s) REFERENCES %s (%s)' % (columns, on, on_columns)
+ sql += "FOREIGN KEY (%s) REFERENCES %s (%s)" % (columns, on, on_columns)
- if command.get('on_delete'):
- sql += ' ON DELETE %s' % command.on_delete
+ if command.get("on_delete"):
+ sql += " ON DELETE %s" % command.on_delete
- if command.get('on_update'):
- sql += ' ON UPDATE %s' % command.on_update
+ if command.get("on_update"):
+ sql += " ON UPDATE %s" % command.on_update
return sql
@@ -121,7 +122,7 @@ def _get_columns(self, blueprint):
columns = []
for column in blueprint.get_added_columns():
- sql = self.wrap(column) + ' ' + self._get_type(column)
+ sql = self.wrap(column) + " " + self._get_type(column)
columns.append(self._add_modifiers(sql, blueprint, column))
@@ -132,7 +133,7 @@ def _add_modifiers(self, sql, blueprint, column):
Add the column modifiers to the definition
"""
for modifier in self._modifiers:
- method = '_modify_%s' % modifier
+ method = "_modify_%s" % modifier
if hasattr(self, method):
sql += getattr(self, method)(blueprint, column)
@@ -163,13 +164,13 @@ def _get_type(self, column):
:rtype sql
"""
- return getattr(self, '_type_%s' % column.type)(column)
+ return getattr(self, "_type_%s" % column.type)(column)
def prefix_list(self, prefix, values):
"""
Add a prefix to a list of values.
"""
- return list(map(lambda value: prefix + ' ' + value, values))
+ return list(map(lambda value: prefix + " " + value, values))
def wrap_table(self, table):
if isinstance(table, Blueprint):
@@ -245,9 +246,13 @@ def _get_changed_diff(self, blueprint, schema):
:rtype: orator.dbal.TableDiff
"""
- table = schema.list_table_details(self.get_table_prefix() + blueprint.get_table())
+ table = schema.list_table_details(
+ self.get_table_prefix() + blueprint.get_table()
+ )
- return Comparator().diff_table(table, self._get_table_with_column_changes(blueprint, table))
+ return Comparator().diff_table(
+ table, self._get_table_with_column_changes(blueprint, table)
+ )
def _get_table_with_column_changes(self, blueprint, table):
"""
@@ -269,7 +274,7 @@ def _get_table_with_column_changes(self, blueprint, table):
option = self._map_fluent_option(key)
if option is not None:
- method = 'set_%s' % option
+ method = "set_%s" % option
if hasattr(column, method):
getattr(column, method)(self._map_fluent_value(option, value))
@@ -293,13 +298,13 @@ def _get_column_change_options(self, fluent):
Get the column change options.
"""
options = {
- 'name': fluent.name,
- 'type': self._get_dbal_column_type(fluent.type),
- 'default': fluent.get('default')
+ "name": fluent.name,
+ "type": self._get_dbal_column_type(fluent.type),
+ "default": fluent.get("default"),
}
- if fluent.type in ['string']:
- options['length'] = fluent.length
+ if fluent.type in ["string"]:
+ options["length"] = fluent.length
return options
@@ -314,29 +319,29 @@ def _get_dbal_column_type(self, type_):
"""
type_ = type_.lower()
- if type_ == 'big_integer':
- type_ = 'bigint'
- elif type == 'small_integer':
- type_ = 'smallint'
- elif type_ in ['medium_text', 'long_text']:
- type_ = 'text'
+ if type_ == "big_integer":
+ type_ = "bigint"
+ elif type == "small_integer":
+ type_ = "smallint"
+ elif type_ in ["medium_text", "long_text"]:
+ type_ = "text"
return type_
def _map_fluent_option(self, attribute):
- if attribute in ['type', 'name']:
+ if attribute in ["type", "name"]:
return
- elif attribute == 'nullable':
- return 'notnull'
- elif attribute == 'total':
- return 'precision'
- elif attribute == 'places':
- return 'scale'
+ elif attribute == "nullable":
+ return "notnull"
+ elif attribute == "total":
+ return "precision"
+ elif attribute == "places":
+ return "scale"
else:
return
def _map_fluent_value(self, option, value):
- if option == 'notnull':
+ if option == "notnull":
return not value
return value
diff --git a/orator/schema/grammars/mysql_grammar.py b/orator/schema/grammars/mysql_grammar.py
index da21286a..af00ee38 100644
--- a/orator/schema/grammars/mysql_grammar.py
+++ b/orator/schema/grammars/mysql_grammar.py
@@ -9,14 +9,25 @@
class MySQLSchemaGrammar(SchemaGrammar):
_modifiers = [
- 'unsigned', 'charset', 'collate', 'nullable',
- 'default', 'increment', 'comment', 'after'
+ "unsigned",
+ "charset",
+ "collate",
+ "nullable",
+ "default",
+ "increment",
+ "comment",
+ "after",
]
- _serials = ['big_integer', 'integer',
- 'medium_integer', 'small_integer', 'tiny_integer']
+ _serials = [
+ "big_integer",
+ "integer",
+ "medium_integer",
+ "small_integer",
+ "tiny_integer",
+ ]
- marker = '%s'
+ marker = "%s"
def compile_table_exists(self):
"""
@@ -24,32 +35,36 @@ def compile_table_exists(self):
:rtype: str
"""
- return 'SELECT * ' \
- 'FROM information_schema.tables ' \
- 'WHERE table_schema = %(marker)s ' \
- 'AND table_name = %(marker)s' % {'marker': self.get_marker()}
+ return (
+ "SELECT * "
+ "FROM information_schema.tables "
+ "WHERE table_schema = %(marker)s "
+ "AND table_name = %(marker)s" % {"marker": self.get_marker()}
+ )
def compile_column_exists(self):
"""
Compile the query to determine the list of columns.
"""
- return 'SELECT column_name ' \
- 'FROM information_schema.columns ' \
- 'WHERE table_schema = %(marker)s AND table_name = %(marker)s' \
- % {'marker': self.get_marker()}
+ return (
+ "SELECT column_name "
+ "FROM information_schema.columns "
+ "WHERE table_schema = %(marker)s AND table_name = %(marker)s"
+ % {"marker": self.get_marker()}
+ )
def compile_create(self, blueprint, command, connection):
"""
Compile a create table command.
"""
- columns = ', '.join(self._get_columns(blueprint))
+ columns = ", ".join(self._get_columns(blueprint))
- sql = 'CREATE TABLE %s (%s)' % (self.wrap_table(blueprint), columns)
+ sql = "CREATE TABLE %s (%s)" % (self.wrap_table(blueprint), columns)
sql = self._compile_create_encoding(sql, connection, blueprint)
if blueprint.engine:
- sql += ' ENGINE = %s' % blueprint.engine
+ sql += " ENGINE = %s" % blueprint.engine
return sql
@@ -63,77 +78,76 @@ def _compile_create_encoding(self, sql, connection, blueprint):
:rtype: str
"""
- charset = blueprint.charset or connection.get_config('charset')
+ charset = blueprint.charset or connection.get_config("charset")
if charset:
- sql += ' DEFAULT CHARACTER SET %s' % charset
+ sql += " DEFAULT CHARACTER SET %s" % charset
- collation = blueprint.collation or connection.get_config('collation')
+ collation = blueprint.collation or connection.get_config("collation")
if collation:
- sql += ' COLLATE %s' % collation
+ sql += " COLLATE %s" % collation
return sql
def compile_add(self, blueprint, command, _):
table = self.wrap_table(blueprint)
- columns = self.prefix_list('ADD', self._get_columns(blueprint))
+ columns = self.prefix_list("ADD", self._get_columns(blueprint))
- return 'ALTER TABLE %s %s' % (table, ', '.join(columns))
+ return "ALTER TABLE %s %s" % (table, ", ".join(columns))
def compile_primary(self, blueprint, command, _):
command.name = None
- return self._compile_key(blueprint, command, 'PRIMARY KEY')
+ return self._compile_key(blueprint, command, "PRIMARY KEY")
def compile_unique(self, blueprint, command, _):
- return self._compile_key(blueprint, command, 'UNIQUE')
+ return self._compile_key(blueprint, command, "UNIQUE")
def compile_index(self, blueprint, command, _):
- return self._compile_key(blueprint, command, 'INDEX')
+ return self._compile_key(blueprint, command, "INDEX")
def _compile_key(self, blueprint, command, type):
columns = self.columnize(command.columns)
table = self.wrap_table(blueprint)
- return 'ALTER TABLE %s ADD %s %s(%s)' % (table, type, command.index, columns)
+ return "ALTER TABLE %s ADD %s %s(%s)" % (table, type, command.index, columns)
def compile_drop(self, blueprint, command, _):
- return 'DROP TABLE %s' % self.wrap_table(blueprint)
+ return "DROP TABLE %s" % self.wrap_table(blueprint)
def compile_drop_if_exists(self, blueprint, command, _):
- return 'DROP TABLE IF EXISTS %s' % self.wrap_table(blueprint)
+ return "DROP TABLE IF EXISTS %s" % self.wrap_table(blueprint)
def compile_drop_column(self, blueprint, command, connection):
- columns = self.prefix_list('DROP', self.wrap_list(command.columns))
+ columns = self.prefix_list("DROP", self.wrap_list(command.columns))
table = self.wrap_table(blueprint)
- return 'ALTER TABLE %s %s' % (table, ', '.join(columns))
+ return "ALTER TABLE %s %s" % (table, ", ".join(columns))
def compile_drop_primary(self, blueprint, command, _):
- return 'ALTER TABLE %s DROP PRIMARY KEY'\
- % self.wrap_table(blueprint)
+ return "ALTER TABLE %s DROP PRIMARY KEY" % self.wrap_table(blueprint)
def compile_drop_unique(self, blueprint, command, _):
table = self.wrap_table(blueprint)
- return 'ALTER TABLE %s DROP INDEX %s' % (table, command.index)
+ return "ALTER TABLE %s DROP INDEX %s" % (table, command.index)
def compile_drop_index(self, blueprint, command, _):
table = self.wrap_table(blueprint)
- return 'ALTER TABLE %s DROP INDEX %s' % (table, command.index)
+ return "ALTER TABLE %s DROP INDEX %s" % (table, command.index)
def compile_drop_foreign(self, blueprint, command, _):
table = self.wrap_table(blueprint)
- return 'ALTER TABLE %s DROP FOREIGN KEY %s' % (table, command.index)
+ return "ALTER TABLE %s DROP FOREIGN KEY %s" % (table, command.index)
def compile_rename(self, blueprint, command, _):
from_ = self.wrap_table(blueprint)
- return 'RENAME TABLE %s TO %s' % (from_, self.wrap_table(command.to))
+ return "RENAME TABLE %s TO %s" % (from_, self.wrap_table(command.to))
def _type_char(self, column):
return "CHAR(%s)" % column.length
@@ -142,121 +156,131 @@ def _type_string(self, column):
return "VARCHAR(%s)" % column.length
def _type_text(self, column):
- return 'TEXT'
+ return "TEXT"
def _type_medium_text(self, column):
- return 'MEDIUMTEXT'
+ return "MEDIUMTEXT"
def _type_long_text(self, column):
- return 'LONGTEXT'
+ return "LONGTEXT"
def _type_integer(self, column):
- return 'INT'
+ return "INT"
def _type_big_integer(self, column):
- return 'BIGINT'
+ return "BIGINT"
def _type_medium_integer(self, column):
- return 'MEDIUMINT'
+ return "MEDIUMINT"
def _type_tiny_integer(self, column):
- return 'TINYINT'
+ return "TINYINT"
def _type_small_integer(self, column):
- return 'SMALLINT'
+ return "SMALLINT"
def _type_float(self, column):
return self._type_double(column)
def _type_double(self, column):
if column.total and column.places:
- return 'DOUBLE(%s, %s)' % (column.total, column.places)
+ return "DOUBLE(%s, %s)" % (column.total, column.places)
- return 'DOUBLE'
+ return "DOUBLE"
def _type_decimal(self, column):
- return 'DECIMAL(%s, %s)' % (column.total, column.places)
+ return "DECIMAL(%s, %s)" % (column.total, column.places)
def _type_boolean(self, column):
- return 'TINYINT(1)'
+ return "TINYINT(1)"
def _type_enum(self, column):
- return 'ENUM(\'%s\')' % '\', \''.join(column.allowed)
+ return "ENUM('%s')" % "', '".join(column.allowed)
def _type_json(self, column):
if self.platform().has_native_json_type():
- return 'JSON'
+ return "JSON"
- return 'TEXT'
+ return "TEXT"
def _type_date(self, column):
- return 'DATE'
+ return "DATE"
def _type_datetime(self, column):
- return 'DATETIME'
+ return "DATETIME"
def _type_time(self, column):
- return 'TIME'
+ return "TIME"
def _type_timestamp(self, column):
- if column.use_current:
- if self.platform_version() >= (5, 6):
- return 'TIMESTAMP DEFAULT CURRENT_TIMESTAMP'
+ platform_version = self.platform_version(3)
+ column_type = "TIMESTAMP"
+
+ if platform_version >= (5, 6, 0):
+ if platform_version >= (5, 6, 4):
+ # Versions 5.6.4+ support fractional seconds
+ column_type = "TIMESTAMP(6)"
+ current = "CURRENT_TIMESTAMP(6)"
else:
- return 'TIMESTAMP DEFAULT 0'
+ current = "CURRENT_TIMESTAMP"
+ else:
+ current = "0"
+
+ if column.use_current:
+ return "{} DEFAULT {}".format(column_type, current)
- return 'TIMESTAMP'
+ return column_type
def _type_binary(self, column):
- return 'BLOB'
+ return "BLOB"
def _modify_unsigned(self, blueprint, column):
- if column.get('unsigned', False):
- return ' UNSIGNED'
+ if column.get("unsigned", False):
+ return " UNSIGNED"
- return ''
+ return ""
def _modify_charset(self, blueprint, column):
- if column.get('charset'):
- return ' CHARACTER SET ' + column.charset
+ if column.get("charset"):
+ return " CHARACTER SET " + column.charset
- return ''
+ return ""
def _modify_collate(self, blueprint, column):
- if column.get('collation'):
- return ' COLLATE ' + column.collation
+ if column.get("collation"):
+ return " COLLATE " + column.collation
- return ''
+ return ""
def _modify_nullable(self, blueprint, column):
- if column.get('nullable'):
- return ' NULL'
+ if column.get("nullable"):
+ return " NULL"
- return ' NOT NULL'
+ return " NOT NULL"
def _modify_default(self, blueprint, column):
- if column.get('default') is not None:
- return ' DEFAULT %s' % self._get_default_value(column.default)
+ if column.get("default") is not None:
+ return " DEFAULT %s" % self._get_default_value(column.default)
- return ''
+ return ""
def _modify_increment(self, blueprint, column):
if column.type in self._serials and column.auto_increment:
- return ' AUTO_INCREMENT PRIMARY KEY'
+ return " AUTO_INCREMENT PRIMARY KEY"
- return ''
+ return ""
def _modify_after(self, blueprint, column):
- if column.get('after') is not None:
- return ' AFTER ' + self.wrap(column.after)
+ if column.get("after") is not None:
+ return " AFTER " + self.wrap(column.after)
- return ''
+ return ""
def _modify_comment(self, blueprint, column):
- if column.get('comment') is not None:
+ if column.get("comment") is not None:
return ' COMMENT "%s"' % column.comment
- return ''
+ return ""
def _get_column_change_options(self, fluent):
"""
@@ -264,15 +288,15 @@ def _get_column_change_options(self, fluent):
"""
options = super(MySQLSchemaGrammar, self)._get_column_change_options(fluent)
- if fluent.type == 'enum':
- options['extra'] = {
- 'definition': '(\'{}\')'.format('\',\''.join(fluent.allowed))
+ if fluent.type == "enum":
+ options["extra"] = {
+ "definition": "('{}')".format("','".join(fluent.allowed))
}
return options
def _wrap_value(self, value):
- if value == '*':
+ if value == "*":
return value
- return '`%s`' % value.replace('`', '``')
+ return "`%s`" % value.replace("`", "``")
diff --git a/orator/schema/grammars/postgres_grammar.py b/orator/schema/grammars/postgres_grammar.py
index 1495eb1c..e1086f8b 100644
--- a/orator/schema/grammars/postgres_grammar.py
+++ b/orator/schema/grammars/postgres_grammar.py
@@ -8,12 +8,17 @@
class PostgresSchemaGrammar(SchemaGrammar):
- _modifiers = ['increment', 'nullable', 'default']
+ _modifiers = ["increment", "nullable", "default"]
- _serials = ['big_integer', 'integer',
- 'medium_integer', 'small_integer', 'tiny_integer']
+ _serials = [
+ "big_integer",
+ "integer",
+ "medium_integer",
+ "small_integer",
+ "tiny_integer",
+ ]
- marker = '%s'
+ marker = "%s"
def compile_rename_column(self, blueprint, command, connection):
"""
@@ -34,8 +39,11 @@ def compile_rename_column(self, blueprint, command, connection):
column = self.wrap(command.from_)
- return 'ALTER TABLE %s RENAME COLUMN %s TO %s'\
- % (table, column, self.wrap(command.to))
+ return "ALTER TABLE %s RENAME COLUMN %s TO %s" % (
+ table,
+ column,
+ self.wrap(command.to),
+ )
def compile_table_exists(self):
"""
@@ -43,91 +51,101 @@ def compile_table_exists(self):
:rtype: str
"""
- return 'SELECT * ' \
- 'FROM information_schema.tables ' \
- 'WHERE table_name = \'%(marker)s\'' \
- % {'marker': self.get_marker()}
+ return (
+ "SELECT * "
+ "FROM information_schema.tables "
+ "WHERE table_name = %(marker)s" % {"marker": self.get_marker()}
+ )
def compile_column_exists(self, table):
"""
Compile the query to determine the list of columns.
"""
- return 'SELECT column_name ' \
- 'FROM information_schema.columns ' \
- 'WHERE table_name = \'%s\'' % table
+ return (
+ "SELECT column_name "
+ "FROM information_schema.columns "
+ "WHERE table_name = '%s'" % table
+ )
def compile_create(self, blueprint, command, _):
"""
Compile a create table command.
"""
- columns = ', '.join(self._get_columns(blueprint))
+ columns = ", ".join(self._get_columns(blueprint))
- return 'CREATE TABLE %s (%s)' % (self.wrap_table(blueprint), columns)
+ return "CREATE TABLE %s (%s)" % (self.wrap_table(blueprint), columns)
def compile_add(self, blueprint, command, _):
table = self.wrap_table(blueprint)
- columns = self.prefix_list('ADD COLUMN', self._get_columns(blueprint))
+ columns = self.prefix_list("ADD COLUMN", self._get_columns(blueprint))
- return 'ALTER TABLE %s %s' % (table, ', '.join(columns))
+ return "ALTER TABLE %s %s" % (table, ", ".join(columns))
def compile_primary(self, blueprint, command, _):
columns = self.columnize(command.columns)
- return 'ALTER TABLE %s ADD PRIMARY KEY (%s)'\
- % (self.wrap_table(blueprint), columns)
+ return "ALTER TABLE %s ADD PRIMARY KEY (%s)" % (
+ self.wrap_table(blueprint),
+ columns,
+ )
def compile_unique(self, blueprint, command, _):
columns = self.columnize(command.columns)
table = self.wrap_table(blueprint)
- return 'ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s)'\
- % (table, command.index, columns)
+ return "ALTER TABLE %s ADD CONSTRAINT %s UNIQUE (%s)" % (
+ table,
+ command.index,
+ columns,
+ )
def compile_index(self, blueprint, command, _):
columns = self.columnize(command.columns)
table = self.wrap_table(blueprint)
- return 'CREATE INDEX %s ON %s (%s)' % (command.index, table, columns)
+ return "CREATE INDEX %s ON %s (%s)" % (command.index, table, columns)
def compile_drop(self, blueprint, command, _):
- return 'DROP TABLE %s' % self.wrap_table(blueprint)
+ return "DROP TABLE %s" % self.wrap_table(blueprint)
def compile_drop_if_exists(self, blueprint, command, _):
- return 'DROP TABLE IF EXISTS %s' % self.wrap_table(blueprint)
+ return "DROP TABLE IF EXISTS %s" % self.wrap_table(blueprint)
def compile_drop_column(self, blueprint, command, connection):
- columns = self.prefix_list('DROP COLUMN', self.wrap_list(command.columns))
+ columns = self.prefix_list("DROP COLUMN", self.wrap_list(command.columns))
table = self.wrap_table(blueprint)
- return 'ALTER TABLE %s %s' % (table, ', '.join(columns))
+ return "ALTER TABLE %s %s" % (table, ", ".join(columns))
def compile_drop_primary(self, blueprint, command, _):
table = blueprint.get_table()
- return 'ALTER TABLE %s DROP CONSTRAINT %s_pkey'\
- % (self.wrap_table(blueprint), table)
+ return "ALTER TABLE %s DROP CONSTRAINT %s_pkey" % (
+ self.wrap_table(blueprint),
+ table,
+ )
def compile_drop_unique(self, blueprint, command, _):
table = self.wrap_table(blueprint)
- return 'ALTER TABLE %s DROP CONSTRAINT %s' % (table, command.index)
+ return "ALTER TABLE %s DROP CONSTRAINT %s" % (table, command.index)
def compile_drop_index(self, blueprint, command, _):
- return 'DROP INDEX %s' % command.index
+ return "DROP INDEX %s" % command.index
def compile_drop_foreign(self, blueprint, command, _):
table = self.wrap_table(blueprint)
- return 'ALTER TABLE %s DROP CONSTRAINT %s' % (table, command.index)
+ return "ALTER TABLE %s DROP CONSTRAINT %s" % (table, command.index)
def compile_rename(self, blueprint, command, _):
from_ = self.wrap_table(blueprint)
- return 'ALTER TABLE %s RENAME TO %s' % (from_, self.wrap_table(command.to))
+ return "ALTER TABLE %s RENAME TO %s" % (from_, self.wrap_table(command.to))
def _type_char(self, column):
return "CHAR(%s)" % column.length
@@ -136,84 +154,84 @@ def _type_string(self, column):
return "VARCHAR(%s)" % column.length
def _type_text(self, column):
- return 'TEXT'
+ return "TEXT"
def _type_medium_text(self, column):
- return 'TEXT'
+ return "TEXT"
def _type_long_text(self, column):
- return 'TEXT'
+ return "TEXT"
def _type_integer(self, column):
- return 'SERIAL' if column.auto_increment else 'INTEGER'
+ return "SERIAL" if column.auto_increment else "INTEGER"
def _type_big_integer(self, column):
- return 'BIGSERIAL' if column.auto_increment else 'BIGINT'
+ return "BIGSERIAL" if column.auto_increment else "BIGINT"
def _type_medium_integer(self, column):
- return 'SERIAL' if column.auto_increment else 'INTEGER'
+ return "SERIAL" if column.auto_increment else "INTEGER"
def _type_tiny_integer(self, column):
- return 'SMALLSERIAL' if column.auto_increment else 'SMALLINT'
+ return "SMALLSERIAL" if column.auto_increment else "SMALLINT"
def _type_small_integer(self, column):
- return 'SMALLSERIAL' if column.auto_increment else 'SMALLINT'
+ return "SMALLSERIAL" if column.auto_increment else "SMALLINT"
def _type_float(self, column):
return self._type_double(column)
def _type_double(self, column):
- return 'DOUBLE PRECISION'
+ return "DOUBLE PRECISION"
def _type_decimal(self, column):
- return 'DECIMAL(%s, %s)' % (column.total, column.places)
+ return "DECIMAL(%s, %s)" % (column.total, column.places)
def _type_boolean(self, column):
- return 'BOOLEAN'
+ return "BOOLEAN"
def _type_enum(self, column):
allowed = list(map(lambda a: "'%s'" % a, column.allowed))
- return 'VARCHAR(255) CHECK ("%s" IN (%s))' % (column.name, ', '.join(allowed))
+ return 'VARCHAR(255) CHECK ("%s" IN (%s))' % (column.name, ", ".join(allowed))
def _type_json(self, column):
- return 'JSON'
+ return "JSON"
def _type_date(self, column):
- return 'DATE'
+ return "DATE"
def _type_datetime(self, column):
- return 'TIMESTAMP(0) WITHOUT TIME ZONE'
+ return "TIMESTAMP(6) WITHOUT TIME ZONE"
def _type_time(self, column):
- return 'TIME(0) WITHOUT TIME ZONE'
+ return "TIME(6) WITHOUT TIME ZONE"
def _type_timestamp(self, column):
if column.use_current:
- return 'TIMESTAMP(0) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(0)'
+ return "TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(6)"
- return 'TIMESTAMP(0) WITHOUT TIME ZONE'
+ return "TIMESTAMP(6) WITHOUT TIME ZONE"
def _type_binary(self, column):
- return 'BYTEA'
+ return "BYTEA"
def _modify_nullable(self, blueprint, column):
- if column.get('nullable'):
- return ' NULL'
+ if column.get("nullable"):
+ return " NULL"
- return ' NOT NULL'
+ return " NOT NULL"
def _modify_default(self, blueprint, column):
- if column.get('default') is not None:
- return ' DEFAULT %s' % self._get_default_value(column.default)
+ if column.get("default") is not None:
+ return " DEFAULT %s" % self._get_default_value(column.default)
- return ''
+ return ""
def _modify_increment(self, blueprint, column):
if column.type in self._serials and column.auto_increment:
- return ' PRIMARY KEY'
+ return " PRIMARY KEY"
- return ''
+ return ""
def _get_dbal_column_type(self, type_):
"""
@@ -226,7 +244,7 @@ def _get_dbal_column_type(self, type_):
"""
type_ = type_.lower()
- if type_ == 'enum':
- return 'string'
+ if type_ == "enum":
+ return "string"
return super(PostgresSchemaGrammar, self)._get_dbal_column_type(type_)
diff --git a/orator/schema/grammars/sqlite_grammar.py b/orator/schema/grammars/sqlite_grammar.py
index 8aeb4971..98780c31 100644
--- a/orator/schema/grammars/sqlite_grammar.py
+++ b/orator/schema/grammars/sqlite_grammar.py
@@ -8,9 +8,9 @@
class SQLiteSchemaGrammar(SchemaGrammar):
- _modifiers = ['nullable', 'default', 'increment']
+ _modifiers = ["nullable", "default", "increment"]
- _serials = ['big_integer', 'integer']
+ _serials = ["big_integer", "integer"]
def compile_rename_column(self, blueprint, command, connection):
"""
@@ -29,16 +29,18 @@ def compile_rename_column(self, blueprint, command, connection):
"""
sql = []
# If foreign keys are on, we disable them
- foreign_keys = self._connection.select('PRAGMA foreign_keys')
+ foreign_keys = self._connection.select("PRAGMA foreign_keys")
if foreign_keys:
foreign_keys = bool(foreign_keys[0])
if foreign_keys:
- sql.append('PRAGMA foreign_keys = OFF')
+ sql.append("PRAGMA foreign_keys = OFF")
- sql += super(SQLiteSchemaGrammar, self).compile_rename_column(blueprint, command, connection)
+ sql += super(SQLiteSchemaGrammar, self).compile_rename_column(
+ blueprint, command, connection
+ )
if foreign_keys:
- sql.append('PRAGMA foreign_keys = ON')
+ sql.append("PRAGMA foreign_keys = ON")
return sql
@@ -59,16 +61,18 @@ def compile_change(self, blueprint, command, connection):
"""
sql = []
# If foreign keys are on, we disable them
- foreign_keys = self._connection.select('PRAGMA foreign_keys')
+ foreign_keys = self._connection.select("PRAGMA foreign_keys")
if foreign_keys:
foreign_keys = bool(foreign_keys[0])
if foreign_keys:
- sql.append('PRAGMA foreign_keys = OFF')
+ sql.append("PRAGMA foreign_keys = OFF")
- sql += super(SQLiteSchemaGrammar, self).compile_change(blueprint, command, connection)
+ sql += super(SQLiteSchemaGrammar, self).compile_change(
+ blueprint, command, connection
+ )
if foreign_keys:
- sql.append('PRAGMA foreign_keys = ON')
+ sql.append("PRAGMA foreign_keys = ON")
return sql
@@ -78,41 +82,44 @@ def compile_table_exists(self):
:rtype: str
"""
- return "SELECT * FROM sqlite_master WHERE type = 'table' AND name = %(marker)s" % {'marker': self.get_marker()}
+ return (
+ "SELECT * FROM sqlite_master WHERE type = 'table' AND name = %(marker)s"
+ % {"marker": self.get_marker()}
+ )
def compile_column_exists(self, table):
"""
Compile the query to determine the list of columns.
"""
- return 'PRAGMA table_info(%s)' % table.replace('.', '__')
+ return "PRAGMA table_info(%s)" % table.replace(".", "__")
def compile_create(self, blueprint, command, _):
"""
Compile a create table command.
"""
- columns = ', '.join(self._get_columns(blueprint))
+ columns = ", ".join(self._get_columns(blueprint))
- sql = 'CREATE TABLE %s (%s' % (self.wrap_table(blueprint), columns)
+ sql = "CREATE TABLE %s (%s" % (self.wrap_table(blueprint), columns)
sql += self._add_foreign_keys(blueprint)
sql += self._add_primary_keys(blueprint)
- return sql + ')'
+ return sql + ")"
def _add_foreign_keys(self, blueprint):
- sql = ''
+ sql = ""
- foreigns = self._get_commands_by_name(blueprint, 'foreign')
+ foreigns = self._get_commands_by_name(blueprint, "foreign")
for foreign in foreigns:
sql += self._get_foreign_key(foreign)
- if foreign.get('on_delete'):
- sql += ' ON DELETE %s' % foreign.on_delete
+ if foreign.get("on_delete"):
+ sql += " ON DELETE %s" % foreign.on_delete
- if foreign.get('on_update'):
- sql += ' ON UPDATE %s' % foreign.on_delete
+ if foreign.get("on_update"):
+ sql += " ON UPDATE %s" % foreign.on_delete
return sql
@@ -127,27 +134,27 @@ def _get_foreign_key(self, foreign):
on_columns = self.columnize(references)
- return ', FOREIGN KEY(%s) REFERENCES %s(%s)' % (columns, on, on_columns)
+ return ", FOREIGN KEY(%s) REFERENCES %s(%s)" % (columns, on, on_columns)
def _add_primary_keys(self, blueprint):
- primary = self._get_command_by_name(blueprint, 'primary')
+ primary = self._get_command_by_name(blueprint, "primary")
if primary:
columns = self.columnize(primary.columns)
- return ', PRIMARY KEY (%s)' % columns
+ return ", PRIMARY KEY (%s)" % columns
- return ''
+ return ""
def compile_add(self, blueprint, command, _):
table = self.wrap_table(blueprint)
- columns = self.prefix_list('ADD COLUMN', self._get_columns(blueprint))
+ columns = self.prefix_list("ADD COLUMN", self._get_columns(blueprint))
statements = []
for column in columns:
- statements.append('ALTER TABLE %s %s' % (table, column))
+ statements.append("ALTER TABLE %s %s" % (table, column))
return statements
@@ -156,23 +163,23 @@ def compile_unique(self, blueprint, command, _):
table = self.wrap_table(blueprint)
- return 'CREATE UNIQUE INDEX %s ON %s (%s)' % (command.index, table, columns)
+ return "CREATE UNIQUE INDEX %s ON %s (%s)" % (command.index, table, columns)
def compile_index(self, blueprint, command, _):
columns = self.columnize(command.columns)
table = self.wrap_table(blueprint)
- return 'CREATE INDEX %s ON %s (%s)' % (command.index, table, columns)
+ return "CREATE INDEX %s ON %s (%s)" % (command.index, table, columns)
def compile_foreign(self, blueprint, command, _):
pass
def compile_drop(self, blueprint, command, _):
- return 'DROP TABLE %s' % self.wrap_table(blueprint)
+ return "DROP TABLE %s" % self.wrap_table(blueprint)
def compile_drop_if_exists(self, blueprint, command, _):
- return 'DROP TABLE IF EXISTS %s' % self.wrap_table(blueprint)
+ return "DROP TABLE IF EXISTS %s" % self.wrap_table(blueprint)
def compile_drop_column(self, blueprint, command, connection):
schema = connection.get_schema_manager()
@@ -187,99 +194,99 @@ def compile_drop_column(self, blueprint, command, connection):
return schema.get_database_platform().get_alter_table_sql(table_diff)
def compile_drop_unique(self, blueprint, command, _):
- return 'DROP INDEX %s' % command.index
+ return "DROP INDEX %s" % command.index
def compile_drop_index(self, blueprint, command, _):
- return 'DROP INDEX %s' % command.index
+ return "DROP INDEX %s" % command.index
def compile_rename(self, blueprint, command, _):
from_ = self.wrap_table(blueprint)
- return 'ALTER TABLE %s RENAME TO %s' % (from_, self.wrap_table(command.to))
+ return "ALTER TABLE %s RENAME TO %s" % (from_, self.wrap_table(command.to))
def _type_char(self, column):
- return 'VARCHAR'
+ return "VARCHAR"
def _type_string(self, column):
- return 'VARCHAR'
+ return "VARCHAR"
def _type_text(self, column):
- return 'TEXT'
+ return "TEXT"
def _type_medium_text(self, column):
- return 'TEXT'
+ return "TEXT"
def _type_long_text(self, column):
- return 'TEXT'
+ return "TEXT"
def _type_integer(self, column):
- return 'INTEGER'
+ return "INTEGER"
def _type_big_integer(self, column):
- return 'INTEGER'
+ return "INTEGER"
def _type_medium_integer(self, column):
- return 'INTEGER'
+ return "INTEGER"
def _type_tiny_integer(self, column):
- return 'TINYINT'
+ return "TINYINT"
def _type_small_integer(self, column):
- return 'INTEGER'
+ return "INTEGER"
def _type_float(self, column):
- return 'FLOAT'
+ return "FLOAT"
def _type_double(self, column):
- return 'FLOAT'
+ return "FLOAT"
def _type_decimal(self, column):
- return 'NUMERIC'
+ return "NUMERIC"
def _type_boolean(self, column):
- return 'TINYINT'
+ return "TINYINT"
def _type_enum(self, column):
- return 'VARCHAR'
+ return "VARCHAR"
def _type_json(self, column):
- return 'TEXT'
+ return "TEXT"
def _type_date(self, column):
- return 'DATE'
+ return "DATE"
def _type_datetime(self, column):
- return 'DATETIME'
+ return "DATETIME"
def _type_time(self, column):
- return 'TIME'
+ return "TIME"
def _type_timestamp(self, column):
if column.use_current:
- return 'DATETIME DEFAULT CURRENT_TIMESTAMP'
+ return "DATETIME DEFAULT CURRENT_TIMESTAMP"
- return 'DATETIME'
+ return "DATETIME"
def _type_binary(self, column):
- return 'BLOB'
+ return "BLOB"
def _modify_nullable(self, blueprint, column):
- if column.get('nullable'):
- return ' NULL'
+ if column.get("nullable"):
+ return " NULL"
- return ' NOT NULL'
+ return " NOT NULL"
def _modify_default(self, blueprint, column):
- if column.get('default') is not None:
- return ' DEFAULT %s' % self._get_default_value(column.default)
+ if column.get("default") is not None:
+ return " DEFAULT %s" % self._get_default_value(column.default)
- return ''
+ return ""
def _modify_increment(self, blueprint, column):
if column.type in self._serials and column.auto_increment:
- return ' PRIMARY KEY AUTOINCREMENT'
+ return " PRIMARY KEY AUTOINCREMENT"
- return ''
+ return ""
def _get_dbal_column_type(self, type_):
"""
@@ -292,7 +299,7 @@ def _get_dbal_column_type(self, type_):
"""
type_ = type_.lower()
- if type_ == 'enum':
- return 'string'
+ if type_ == "enum":
+ return "string"
return super(SQLiteSchemaGrammar, self)._get_dbal_column_type(type_)
diff --git a/orator/schema/mysql_builder.py b/orator/schema/mysql_builder.py
index 7d54d34f..62d95e5b 100644
--- a/orator/schema/mysql_builder.py
+++ b/orator/schema/mysql_builder.py
@@ -4,7 +4,6 @@
class MySQLSchemaBuilder(SchemaBuilder):
-
def has_table(self, table):
"""
Determine if the given table exists.
@@ -33,6 +32,12 @@ def get_column_listing(self, table):
database = self._connection.get_database_name()
table = self._connection.get_table_prefix() + table
- results = self._connection.select(sql, [database, table])
+ results = []
+ for result in self._connection.select(sql, [database, table]):
+ new_result = {}
+ for key, value in result.items():
+ new_result[key.lower()] = value
+
+ results.append(new_result)
return self._connection.get_post_processor().process_column_listing(results)
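The new loop in get_column_listing only lowercases the keys of each result row before post-processing, presumably to cope with servers that report the information_schema column names in a different case. A minimal sketch of the same transformation (the sample rows are made up):

    rows = [{"COLUMN_NAME": "id"}, {"COLUMN_NAME": "name"}]
    results = [{key.lower(): value for key, value in row.items()} for row in rows]
    # results == [{'column_name': 'id'}, {'column_name': 'name'}]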
diff --git a/orator/schema/schema.py b/orator/schema/schema.py
index 98b4b67f..e95c29ec 100644
--- a/orator/schema/schema.py
+++ b/orator/schema/schema.py
@@ -2,7 +2,6 @@
class Schema(object):
-
def __init__(self, manager):
"""
:param manager: The database manager
diff --git a/orator/seeds/seeder.py b/orator/seeds/seeder.py
index 978b4c67..911e6f64 100644
--- a/orator/seeds/seeder.py
+++ b/orator/seeds/seeder.py
@@ -32,7 +32,7 @@ def call(self, klass):
self._resolve(klass).run()
if self._command:
- self._command.line('Seeded: %s' % klass.__name__)
+ self._command.line("Seeded: %s" % klass.__name__)
def _resolve(self, klass):
"""
diff --git a/orator/support/fluent.py b/orator/support/fluent.py
index f01e4ba8..d0dc388a 100644
--- a/orator/support/fluent.py
+++ b/orator/support/fluent.py
@@ -29,7 +29,6 @@ def __set_value(self, value):
class Fluent(object):
-
def __init__(self, **attributes):
self._attributes = {}
@@ -78,7 +77,7 @@ def __getattr__(self, item):
return Dynamic(self._attributes.get(item), item, self)
def __setattr__(self, key, value):
- if key == '_attributes':
+ if key == "_attributes":
super(Fluent, self).__setattr__(key, value)
try:
diff --git a/orator/support/grammar.py b/orator/support/grammar.py
index 7d9dc9ee..93ac7bd9 100644
--- a/orator/support/grammar.py
+++ b/orator/support/grammar.py
@@ -5,10 +5,10 @@
class Grammar(object):
- marker = '?'
+ marker = "?"
def __init__(self, marker=None):
- self._table_prefix = ''
+ self._table_prefix = ""
if marker:
self.marker = marker
@@ -30,18 +30,17 @@ def wrap(self, value, prefix_alias=False):
# to separate out the pieces so we can wrap each of the segments
# of the expression on its own, and then join them
# both back together with the "as" connector.
- if value.lower().find(' as ') >= 0:
- segments = value.split(' ')
+ if value.lower().find(" as ") >= 0:
+ segments = value.split(" ")
if prefix_alias:
segments[2] = self._table_prefix + segments[2]
- return '%s AS %s' % (self.wrap(segments[0]),
- self._wrap_value(segments[2]))
+ return "%s AS %s" % (self.wrap(segments[0]), self._wrap_value(segments[2]))
wrapped = []
- segments = value.split('.')
+ segments = value.split(".")
# If the value is not an aliased table expression, we'll just wrap it like
# normal, so if there is more than one segment, we will wrap the first
@@ -52,19 +51,19 @@ def wrap(self, value, prefix_alias=False):
else:
wrapped.append(self._wrap_value(segment))
- return '.'.join(wrapped)
+ return ".".join(wrapped)
def _wrap_value(self, value):
- if value == '*':
+ if value == "*":
return value
return '"%s"' % value.replace('"', '""')
def columnize(self, columns):
- return ', '.join(map(self.wrap, columns))
+ return ", ".join(map(self.wrap, columns))
def parameterize(self, values):
- return ', '.join(map(self.parameter, values))
+ return ", ".join(map(self.parameter, values))
def parameter(self, value):
if self.is_expression(value):
@@ -79,7 +78,7 @@ def is_expression(self, value):
return isinstance(value, QueryExpression)
def get_date_format(self):
- return 'Y-m-d H:i:s'
+ return "%Y-%m-%d %H:%M:%S.%f"
def get_table_prefix(self):
return self._table_prefix
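get_date_format() now returns a Python strftime pattern with microseconds instead of the old PHP-style string. A quick check of what the new pattern yields (illustrative values):

    import datetime

    fmt = "%Y-%m-%d %H:%M:%S.%f"
    datetime.datetime(2019, 7, 15, 12, 30, 45, 123456).strftime(fmt)
    # '2019-07-15 12:30:45.123456'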
diff --git a/orator/utils/__init__.py b/orator/utils/__init__.py
index ac9d385e..e5a70c3f 100644
--- a/orator/utils/__init__.py
+++ b/orator/utils/__init__.py
@@ -21,10 +21,12 @@
from urlparse import parse_qsl
def load_module(module, path):
- with open(path, 'rb') as fh:
+ with open(path, "rb") as fh:
mod = imp.load_source(module, path, fh)
return mod
+
+
else:
long = int
unicode = str
@@ -32,21 +34,19 @@ def load_module(module, path):
from functools import reduce
- from urllib.parse import (quote_plus, unquote_plus,
- parse_qsl, quote, unquote)
+ from urllib.parse import quote_plus, unquote_plus, parse_qsl, quote, unquote
if PY33:
from importlib import machinery
def load_module(module, path):
- return machinery.SourceFileLoader(
- module, path
- ).load_module(module)
+ return machinery.SourceFileLoader(module, path).load_module(module)
+
else:
import imp
def load_module(module, path):
- with open(path, 'rb') as fh:
+ with open(path, "rb") as fh:
mod = imp.load_source(module, path, fh)
return mod
@@ -56,7 +56,6 @@ def load_module(module, path):
class Null(object):
-
def __bool__(self):
return False
@@ -65,9 +64,9 @@ def __eq__(self, other):
def deprecated(func):
- '''This is a decorator which can be used to mark functions
+ """This is a decorator which can be used to mark functions
as deprecated. It will result in a warning being emitted
- when the function is used.'''
+ when the function is used."""
@functools.wraps(func)
def new_func(*args, **kwargs):
@@ -80,7 +79,7 @@ def new_func(*args, **kwargs):
"Call to deprecated function {}.".format(func.__name__),
category=DeprecationWarning,
filename=func_code.co_filename,
- lineno=func_code.co_firstlineno + 1
+ lineno=func_code.co_firstlineno + 1,
)
return func(*args, **kwargs)
@@ -96,7 +95,7 @@ def decode(string, encodings=None):
return string
if encodings is None:
- encodings = ['utf-8', 'latin1', 'ascii']
+ encodings = ["utf-8", "latin1", "ascii"]
for encoding in encodings:
try:
@@ -104,7 +103,7 @@ def decode(string, encodings=None):
except (UnicodeEncodeError, UnicodeDecodeError):
pass
- return string.decode(encodings[0], errors='ignore')
+ return string.decode(encodings[0], errors="ignore")
def encode(string, encodings=None):
@@ -115,7 +114,7 @@ def encode(string, encodings=None):
return string
if encodings is None:
- encodings = ['utf-8', 'latin1', 'ascii']
+ encodings = ["utf-8", "latin1", "ascii"]
for encoding in encodings:
try:
@@ -123,4 +122,4 @@ def encode(string, encodings=None):
except (UnicodeEncodeError, UnicodeDecodeError):
pass
- return string.encode(encodings[0], errors='ignore')
+ return string.encode(encodings[0], errors="ignore")
diff --git a/orator/utils/command_formatter.py b/orator/utils/command_formatter.py
index f2b28b7c..24f7cb77 100644
--- a/orator/utils/command_formatter.py
+++ b/orator/utils/command_formatter.py
@@ -1,41 +1,48 @@
# -*- coding: utf-8 -*-
from pygments.formatter import Formatter
-from pygments.token import Keyword, Name, Comment, String, Error, \
- Number, Operator, Generic, Token, Whitespace
+from pygments.token import (
+ Keyword,
+ Name,
+ Comment,
+ String,
+ Error,
+ Number,
+ Operator,
+ Generic,
+ Token,
+ Whitespace,
+)
from pygments.util import get_choice_opt
COMMAND_COLORS = {
- Token: ('', ''),
-
- Whitespace: ('fg=white', 'fg=black;options=bold'),
- Comment: ('fg=white', 'fg=black;options=bold'),
- Comment.Preproc: ('fg=cyan', 'fg=cyan;options=bold'),
- Keyword: ('fg=blue', 'fg=blue;options=bold'),
- Keyword.Type: ('fg=cyan', 'fg=cyan;options=bold'),
- Operator.Word: ('fg=magenta', 'fg=magenta;options=bold'),
- Name.Builtin: ('fg=cyan', 'fg=cyan;options=bold'),
- Name.Function: ('fg=green', 'fg=green;option=bold'),
- Name.Namespace: ('fg=cyan;options=underline', 'fg=cyan;options=bold,underline'),
- Name.Class: ('fg=green;options=underline', 'fg=green;options=bold,underline'),
- Name.Exception: ('fg=cyan', 'fg=cyan;options=bold'),
- Name.Decorator: ('fg=black;options=bold', 'fg=white'),
- Name.Variable: ('fg=red', 'fg=red;options=bold'),
- Name.Constant: ('fg=red', 'fg=red;options=bold'),
- Name.Attribute: ('fg=cyan', 'fg=cyan;options=bold'),
- Name.Tag: ('fg=blue;options=bold', 'fg=blue;options=bold'),
- String: ('fg=yellow', 'fg=yellow'),
- Number: ('fg=blue', 'fg=blue;options=bold'),
-
- Generic.Deleted: ('fg=red;options=bold', 'fg=red;options=bold'),
- Generic.Inserted: ('fg=green', 'fg=green;options=bold'),
- Generic.Heading: ('options=bold', 'option=bold'),
- Generic.Subheading: ('fg=magenta;options=bold', 'fg=magenta;options=bold'),
- Generic.Prompt: ('options=bold', 'options=bold'),
- Generic.Error: ('fg=red;options=bold', 'fg=red;options=bold'),
-
- Error: ('fg=red;options=bold,underline', 'fg=red;options=bold,underline'),
+ Token: ("", ""),
+ Whitespace: ("fg=white", "fg=black;options=bold"),
+ Comment: ("fg=white", "fg=black;options=bold"),
+ Comment.Preproc: ("fg=cyan", "fg=cyan;options=bold"),
+ Keyword: ("fg=blue", "fg=blue;options=bold"),
+ Keyword.Type: ("fg=cyan", "fg=cyan;options=bold"),
+ Operator.Word: ("fg=magenta", "fg=magenta;options=bold"),
+ Name.Builtin: ("fg=cyan", "fg=cyan;options=bold"),
+ Name.Function: ("fg=green", "fg=green;options=bold"),
+ Name.Namespace: ("fg=cyan;options=underline", "fg=cyan;options=bold,underline"),
+ Name.Class: ("fg=green;options=underline", "fg=green;options=bold,underline"),
+ Name.Exception: ("fg=cyan", "fg=cyan;options=bold"),
+ Name.Decorator: ("fg=black;options=bold", "fg=white"),
+ Name.Variable: ("fg=red", "fg=red;options=bold"),
+ Name.Constant: ("fg=red", "fg=red;options=bold"),
+ Name.Attribute: ("fg=cyan", "fg=cyan;options=bold"),
+ Name.Tag: ("fg=blue;options=bold", "fg=blue;options=bold"),
+ String: ("fg=yellow", "fg=yellow"),
+ Number: ("fg=blue", "fg=blue;options=bold"),
+ Generic.Deleted: ("fg=red;options=bold", "fg=red;options=bold"),
+ Generic.Inserted: ("fg=green", "fg=green;options=bold"),
+ Generic.Heading: ("options=bold", "options=bold"),
+ Generic.Subheading: ("fg=magenta;options=bold", "fg=magenta;options=bold"),
+ Generic.Prompt: ("options=bold", "options=bold"),
+ Generic.Error: ("fg=red;options=bold", "fg=red;options=bold"),
+ Error: ("fg=red;options=bold,underline", "fg=red;options=bold,underline"),
}
@@ -62,16 +69,17 @@ class CommandFormatter(Formatter):
Set to ``True`` to have line numbers on the terminal output as well
(default: ``False`` = no line numbers).
"""
- name = 'Command'
- aliases = ['command']
+ name = "Command"
+ aliases = ["command"]
filenames = []
def __init__(self, **options):
Formatter.__init__(self, **options)
- self.darkbg = get_choice_opt(options, 'bg',
- ['light', 'dark'], 'light') == 'dark'
- self.colorscheme = options.get('colorscheme', None) or COMMAND_COLORS
- self.linenos = options.get('linenos', False)
+ self.darkbg = (
+ get_choice_opt(options, "bg", ["light", "dark"], "light") == "dark"
+ )
+ self.colorscheme = options.get("colorscheme", None) or COMMAND_COLORS
+ self.linenos = options.get("linenos", False)
self._lineno = 0
def format(self, tokensource, outfile):
@@ -79,7 +87,7 @@ def format(self, tokensource, outfile):
def _write_lineno(self, outfile):
self._lineno += 1
- outfile.write("%s%04d: " % (self._lineno != 1 and '\n' or '', self._lineno))
+ outfile.write("%s%04d: " % (self._lineno != 1 and "\n" or "", self._lineno))
def _get_color(self, ttype):
# self.colorscheme is a dict containing usually generic types, so we
@@ -100,14 +108,14 @@ def format_unencoded(self, tokensource, outfile):
for line in value.splitlines(True):
if color:
- outfile.write('<%s>%s</>' % (color, line.rstrip('\n')))
+ outfile.write("<%s>%s</>" % (color, line.rstrip("\n")))
else:
- outfile.write(line.rstrip('\n'))
- if line.endswith('\n'):
+ outfile.write(line.rstrip("\n"))
+ if line.endswith("\n"):
if self.linenos:
self._write_lineno(outfile)
else:
- outfile.write('\n')
+ outfile.write("\n")
if self.linenos:
outfile.write("\n")
diff --git a/orator/utils/helpers.py b/orator/utils/helpers.py
index 27b11b69..1b54f44a 100644
--- a/orator/utils/helpers.py
+++ b/orator/utils/helpers.py
@@ -2,6 +2,7 @@
import os
import errno
+import datetime
def value(val):
@@ -19,3 +20,18 @@ def mkdir_p(path, mode=0o777):
pass
else:
raise
+
+
+def serialize(value):
+ if isinstance(value, datetime.datetime):
+ if hasattr(value, "to_json"):
+ value = value.to_json()
+ else:
+ value = value.isoformat()
+ elif isinstance(value, list):
+ value = list(map(serialize, value))
+ elif isinstance(value, dict):
+ for k, v in value.items():
+ value[k] = serialize(v)
+
+ return value
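A short usage sketch for the new serialize() helper (the sample values are made up): datetimes fall back to isoformat() when they expose no to_json(), and lists and dicts are walked recursively:

    import datetime

    serialize(datetime.datetime(2019, 7, 15, 12, 0, 0))
    # '2019-07-15T12:00:00'
    serialize({"tags": ["a", "b"], "seen_at": datetime.datetime(2019, 7, 15)})
    # {'tags': ['a', 'b'], 'seen_at': '2019-07-15T00:00:00'}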
diff --git a/orator/utils/qmarker.py b/orator/utils/qmarker.py
index 5ee784b5..d6b2a19e 100644
--- a/orator/utils/qmarker.py
+++ b/orator/utils/qmarker.py
@@ -5,21 +5,22 @@
class Qmarker(object):
- RE_QMARK = re.compile(r'\?\?|\?|%')
+ RE_QMARK = re.compile(r"\?\?|\?|%")
@classmethod
def qmark(cls, query):
"""
Convert a "qmark" query into "format" style.
"""
+
def sub_sequence(m):
s = m.group(0)
- if s == '??':
- return '?'
- if s == '%':
- return '%%'
+ if s == "??":
+ return "?"
+ if s == "%":
+ return "%%"
else:
- return '%s'
+ return "%s"
return cls.RE_QMARK.sub(sub_sequence, query)
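Qmarker.qmark() is unchanged here apart from quoting; for reference, a small example of the conversion it performs (made-up query): ? becomes %s, ?? escapes to a literal ?, and % is doubled for format-style parameterization:

    Qmarker.qmark("SELECT * FROM users WHERE id = ? AND name LIKE '%foo%'")
    # "SELECT * FROM users WHERE id = %s AND name LIKE '%%foo%%'"
    Qmarker.qmark("is this ?? a question")
    # 'is this ? a question'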
diff --git a/orator/utils/url.py b/orator/utils/url.py
index dc04fb10..19bd7f79 100644
--- a/orator/utils/url.py
+++ b/orator/utils/url.py
@@ -44,8 +44,16 @@ class URL(object):
"""
- def __init__(self, drivername, username=None, password=None,
- host=None, port=None, database=None, query=None):
+ def __init__(
+ self,
+ drivername,
+ username=None,
+ password=None,
+ host=None,
+ port=None,
+ database=None,
+ query=None,
+ ):
self.drivername = drivername
self.username = username
self.password = password
@@ -62,22 +70,21 @@ def __to_string__(self, hide_password=True):
if self.username is not None:
s += _rfc_1738_quote(self.username)
if self.password is not None:
- s += ':' + ('***' if hide_password
- else _rfc_1738_quote(self.password))
+ s += ":" + ("***" if hide_password else _rfc_1738_quote(self.password))
s += "@"
if self.host is not None:
- if ':' in self.host:
+ if ":" in self.host:
s += "[%s]" % self.host
else:
s += self.host
if self.port is not None:
- s += ':' + str(self.port)
+ s += ":" + str(self.port)
if self.database is not None:
- s += '/' + self.database
+ s += "/" + self.database
if self.query:
keys = list(self.query)
keys.sort()
- s += '?' + "&".join("%s=%s" % (k, self.query[k]) for k in keys)
+ s += "?" + "&".join("%s=%s" % (k, self.query[k]) for k in keys)
return s
def __str__(self):
@@ -90,43 +97,46 @@ def __hash__(self):
return hash(str(self))
def __eq__(self, other):
- return \
- isinstance(other, URL) and \
- self.drivername == other.drivername and \
- self.username == other.username and \
- self.password == other.password and \
- self.host == other.host and \
- self.database == other.database and \
- self.query == other.query
+ return (
+ isinstance(other, URL)
+ and self.drivername == other.drivername
+ and self.username == other.username
+ and self.password == other.password
+ and self.host == other.host
+ and self.database == other.database
+ and self.query == other.query
+ )
def get_backend_name(self):
- if '+' not in self.drivername:
+ if "+" not in self.drivername:
return self.drivername
else:
- return self.drivername.split('+')[0]
+ return self.drivername.split("+")[0]
def get_driver_name(self):
- if '+' not in self.drivername:
+ if "+" not in self.drivername:
return self.get_dialect().driver
else:
- return self.drivername.split('+')[1]
+ return self.drivername.split("+")[1]
def get_dialect(self):
"""Return the SQLAlchemy database dialect class corresponding
to this URL's driver name.
"""
- if '+' not in self.drivername:
+ if "+" not in self.drivername:
name = self.drivername
else:
- name = self.drivername.replace('+', '.')
+ name = self.drivername.replace("+", ".")
cls = registry.load(name)
# check for legacy dialects that
# would return a module with 'dialect' as the
# actual class
- if hasattr(cls, 'dialect') and \
- isinstance(cls.dialect, type) and \
- issubclass(cls.dialect, Dialect):
+ if (
+ hasattr(cls, "dialect")
+ and isinstance(cls.dialect, type)
+ and issubclass(cls.dialect, Dialect)
+ ):
return cls.dialect
else:
return cls
@@ -146,7 +156,7 @@ def translate_connect_args(self, names=[], **kw):
"""
translated = {}
- attribute_names = ['host', 'database', 'username', 'password', 'port']
+ attribute_names = ["host", "database", "username", "password", "port"]
for sname in attribute_names:
if names:
name = names.pop(0)
@@ -173,7 +183,8 @@ def make_url(name_or_url):
def _parse_rfc1738_args(name):
- pattern = re.compile(r'''
+ pattern = re.compile(
+ r"""
(?P<name>[\w\+]+)://
(?:
(?P<username>[^:/]*)
@@ -187,40 +198,40 @@ def _parse_rfc1738_args(name):
(?::(?P<port>[^/]*))?
)?
(?:/(?P<database>.*))?
- ''', re.X)
+ """,
+ re.X,
+ )
m = pattern.match(name)
if m is not None:
components = m.groupdict()
- if components['database'] is not None:
- tokens = components['database'].split('?', 2)
- components['database'] = tokens[0]
- query = (
- len(tokens) > 1 and dict(parse_qsl(tokens[1]))) or None
+ if components["database"] is not None:
+ tokens = components["database"].split("?", 2)
+ components["database"] = tokens[0]
+ query = (len(tokens) > 1 and dict(parse_qsl(tokens[1]))) or None
if PY2 and query is not None:
- query = dict((k.encode('ascii'), query[k]) for k in query)
+ query = dict((k.encode("ascii"), query[k]) for k in query)
else:
query = None
- components['query'] = query
+ components["query"] = query
- if components['username'] is not None:
- components['username'] = _rfc_1738_unquote(components['username'])
+ if components["username"] is not None:
+ components["username"] = _rfc_1738_unquote(components["username"])
- if components['password'] is not None:
- components['password'] = _rfc_1738_unquote(components['password'])
+ if components["password"] is not None:
+ components["password"] = _rfc_1738_unquote(components["password"])
- ipv4host = components.pop('ipv4host')
- ipv6host = components.pop('ipv6host')
- components['host'] = ipv4host or ipv6host
- name = components.pop('name')
+ ipv4host = components.pop("ipv4host")
+ ipv6host = components.pop("ipv6host")
+ components["host"] = ipv4host or ipv6host
+ name = components.pop("name")
return URL(name, **components)
else:
- raise ArgumentError(
- "Could not parse rfc1738 URL from string '%s'" % name)
+ raise ArgumentError("Could not parse rfc1738 URL from string '%s'" % name)
def _rfc_1738_quote(text):
- return re.sub(r'[:@/]', lambda m: "%%%X" % ord(m.group(0)), text)
+ return re.sub(r"[:@/]", lambda m: "%%%X" % ord(m.group(0)), text)
def _rfc_1738_unquote(text):
@@ -228,12 +239,10 @@ def _rfc_1738_unquote(text):
def _parse_keyvalue_args(name):
- m = re.match(r'(\w+)://(.*)', name)
+ m = re.match(r"(\w+)://(.*)", name)
if m is not None:
(name, args) = m.group(1, 2)
opts = dict(parse_qsl(args))
return URL(name, *opts)
else:
return None
-
-
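A brief usage sketch for the reformatted URL helpers (the connection string is made up; behavior inferred from the code above, with the password masked by __str__):

    from orator.utils.url import make_url

    url = make_url("mysql+pymysql://orator:secret@localhost:3306/orator_test?charset=utf8")
    url.get_backend_name()   # 'mysql'
    url.database             # 'orator_test'
    str(url)                 # 'mysql+pymysql://orator:***@localhost:3306/orator_test?charset=utf8'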
diff --git a/orator/version.py b/orator/version.py
deleted file mode 100644
index 06e67340..00000000
--- a/orator/version.py
+++ /dev/null
@@ -1,3 +0,0 @@
-# -*- coding: utf-8 -*-
-
-VERSION = '0.9.2'
diff --git a/pyproject.toml b/pyproject.toml
new file mode 100644
index 00000000..c067e646
--- /dev/null
+++ b/pyproject.toml
@@ -0,0 +1,62 @@
+[tool.poetry]
+name = "orator"
+version = "0.9.9"
+description = "The Orator ORM provides a simple yet beautiful ActiveRecord implementation."
+
+license = "MIT"
+
+authors = [
+ "Sébastien Eustace "
+]
+
+readme = 'README.rst'
+
+repository = "https://github.com/sdispater/orator"
+homepage = "https://orator-orm.com/"
+
+keywords = ['database', 'orm']
+
+
+[tool.poetry.dependencies]
+python = "~2.7 || ^3.5"
+backpack = "^0.1"
+blinker = "^1.4"
+cleo = "^0.6"
+inflection = "^0.3"
+Faker = "^0.8"
+lazy-object-proxy = "^1.2"
+pendulum = "^1.4"
+pyaml = "^16.12"
+pyyaml = "^5.1"
+Pygments = "^2.2"
+simplejson = "^3.10"
+six = "^1.10"
+wrapt = "^1.10"
+
+# Extras
+psycopg2 = { version = "^2.7", optional = true }
+PyMySQL = { version = "^0.7", optional = true }
+mysqlclient = { version = "^1.3", optional = true }
+
+
+[tool.poetry.dev-dependencies]
+flexmock = "0.9.7"
+pytest = "^3.5"
+pytest-mock = "^1.6"
+tox = "^3.5"
+pre-commit = "^1.11"
+
+
+[tool.poetry.extras]
+mysql = ["mysqlclient"]
+mysql-python = ["PyMySQL"]
+pgsql = ["psycopg2"]
+
+
+[tool.poetry.scripts]
+orator = 'orator.commands.application:application.run'
+
+
+[build-system]
+requires = ["poetry>=0.12a3"]
+build-backend = "poetry.masonry.api"
diff --git a/requirements.txt b/requirements.txt
deleted file mode 100644
index dee82e7b..00000000
--- a/requirements.txt
+++ /dev/null
@@ -1,12 +0,0 @@
-simplejson
-pendulum
-backpack
-inflection
-six
-cleo>=0.4.1
-blinker
-lazy-object-proxy
-fake-factory
-wrapt
-pyaml
-pygments
diff --git a/setup.cfg b/setup.cfg
deleted file mode 100644
index 2a9acf13..00000000
--- a/setup.cfg
+++ /dev/null
@@ -1,2 +0,0 @@
-[bdist_wheel]
-universal = 1
diff --git a/setup.py b/setup.py
deleted file mode 100644
index 3f67d89c..00000000
--- a/setup.py
+++ /dev/null
@@ -1,51 +0,0 @@
-# -*- coding: utf-8 -*-
-
-import os
-from setuptools import find_packages
-from distutils.core import setup
-
-
-here = os.path.abspath(os.path.dirname(__file__))
-
-def get_version():
- with open(os.path.join(here, 'orator/version.py')) as f:
- variables = {}
- exec(f.read(), variables)
-
- version = variables.get('VERSION')
- if version:
- return version
-
- raise RuntimeError('No version info found.')
-
-__version__ = get_version()
-
-with open(os.path.join(here, 'requirements.txt')) as f:
- requirements = f.readlines()
-
-setup_kwargs = dict(
- name='orator',
- license='MIT',
- version=__version__,
- description='The Orator ORM provides a simple yet beautiful ActiveRecord implementation.',
- long_description=open('README.rst').read(),
- entry_points={
- 'console_scripts': ['orator=orator.commands.application:application.run'],
- },
- author='Sébastien Eustace',
- author_email='sebastien.eustace@gmail.com',
- url='https://github.com/sdispater/orator',
- download_url='https://github.com/sdispater/orator/archive/%s.tar.gz' % __version__,
- packages=find_packages(exclude=['tests']),
- install_requires=requirements,
- tests_require=['pytest', 'mock', 'flexmock==0.9.7', 'mysqlclient', 'psycopg2'],
- test_suite='nose.collector',
- classifiers=[
- 'Intended Audience :: Developers',
- 'Operating System :: OS Independent',
- 'Programming Language :: Python',
- 'Topic :: Software Development :: Libraries :: Python Modules',
- ],
-)
-
-setup(**setup_kwargs)
diff --git a/tests-requirements.txt b/tests-requirements.txt
deleted file mode 100644
index 843db42e..00000000
--- a/tests-requirements.txt
+++ /dev/null
@@ -1,5 +0,0 @@
--r requirements.txt
-pytest
-pytest-mock
-flexmock==0.9.7
-psycopg2
diff --git a/tests/__init__.py b/tests/__init__.py
index f59d6aae..fbdad4b4 100644
--- a/tests/__init__.py
+++ b/tests/__init__.py
@@ -15,34 +15,32 @@
class OratorTestCase(TestCase):
-
def tearDown(self):
- if hasattr(self, 'local_database'):
+ if hasattr(self, "local_database"):
os.remove(self.local_database)
def init_database(self):
- self.local_database = '/tmp/orator_test_database.db'
+ self.local_database = "/tmp/orator_test_database.db"
if os.path.exists(self.local_database):
os.remove(self.local_database)
- self.manager = DatabaseManager({
- 'default': 'sqlite',
- 'sqlite': {
- 'driver': 'sqlite',
- 'database': self.local_database
+ self.manager = DatabaseManager(
+ {
+ "default": "sqlite",
+ "sqlite": {"driver": "sqlite", "database": self.local_database},
}
- })
+ )
with self.manager.transaction():
try:
self.manager.statement(
- 'CREATE TABLE `users` ('
- 'id INTEGER PRIMARY KEY NOT NULL, '
- 'name CHAR(50) NOT NULL, '
- 'created_at DATETIME DEFAULT CURRENT_TIMESTAMP, '
- 'updated_at DATETIME DEFAULT CURRENT_TIMESTAMP'
- ')'
+ "CREATE TABLE `users` ("
+ "id INTEGER PRIMARY KEY NOT NULL, "
+ "name CHAR(50) NOT NULL, "
+ "created_at DATETIME DEFAULT CURRENT_TIMESTAMP, "
+ "updated_at DATETIME DEFAULT CURRENT_TIMESTAMP"
+ ")"
)
except Exception:
pass
diff --git a/tests/commands/__init__.py b/tests/commands/__init__.py
index d6cd4d88..18658674 100644
--- a/tests/commands/__init__.py
+++ b/tests/commands/__init__.py
@@ -6,11 +6,10 @@
class OratorCommandTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
- def run_command(self, command, options=None, input_stream=None):
+ def run_command(self, command, options=None):
"""
Run the command.
@@ -20,18 +19,12 @@ def run_command(self, command, options=None, input_stream=None):
if options is None:
options = []
- options = [('command', command.get_name())] + options
+ options = [("command", command.get_name())] + options
application = Application()
application.add(command)
- if input_stream:
- dialog = command.get_helper('question')
- dialog.__class__.input_stream = input_stream
-
command_tester = CommandTester(command)
command_tester.execute(options)
return command_tester
-
-
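A note on the harness change above: run_command() no longer takes an input_stream, so the command tests that follow stub the interactive confirmation directly on the command under test instead of feeding a fake byte stream to the question helper. A minimal sketch of the new pattern, using names that appear in the tests below:

    # Stub the prompt instead of simulating user input on stdin.
    command = flexmock(MigrateCommand())
    command.should_receive("_get_config").and_return({})
    command.should_receive("confirm").and_return(True)  # auto-answer "yes"

    self.run_command(command)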
diff --git a/tests/commands/migrations/__init__.py b/tests/commands/migrations/__init__.py
index 633f8661..40a96afc 100644
--- a/tests/commands/migrations/__init__.py
+++ b/tests/commands/migrations/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/tests/commands/migrations/test_install_command.py b/tests/commands/migrations/test_install_command.py
index fc3f4fc1..4a8a676f 100644
--- a/tests/commands/migrations/test_install_command.py
+++ b/tests/commands/migrations/test_install_command.py
@@ -7,13 +7,12 @@
class MigrateInstallCommandTestCase(OratorCommandTestCase):
-
def test_execute_calls_repository_to_install(self):
repo_mock = flexmock(DatabaseMigrationRepository)
- repo_mock.should_receive('set_source').once().with_args('foo')
- repo_mock.should_receive('create_repository').once()
+ repo_mock.should_receive("set_source").once().with_args("foo")
+ repo_mock.should_receive("create_repository").once()
command = flexmock(InstallCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
- self.run_command(command, [('--database', 'foo')])
+ self.run_command(command, [("--database", "foo")])
diff --git a/tests/commands/migrations/test_make_command.py b/tests/commands/migrations/test_make_command.py
index 4b3181bd..b72269eb 100644
--- a/tests/commands/migrations/test_make_command.py
+++ b/tests/commands/migrations/test_make_command.py
@@ -8,33 +8,39 @@
class MigrateMakeCommandTestCase(OratorCommandTestCase):
-
def test_basic_create_gives_creator_proper_arguments(self):
creator_mock = flexmock(MigrationCreator)
- creator_mock.should_receive('create').once()\
- .with_args('create_foo', os.path.join(os.getcwd(), 'migrations'), None, False).and_return('foo')
+ creator_mock.should_receive("create").once().with_args(
+ "create_foo", os.path.join(os.getcwd(), "migrations"), None, False
+ ).and_return("foo")
command = flexmock(MigrateMakeCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
- self.run_command(command, [('name', 'create_foo')])
+ self.run_command(command, [("name", "create_foo")])
def test_basic_create_gives_creator_proper_arguments_when_table_is_set(self):
creator_mock = flexmock(MigrationCreator)
- creator_mock.should_receive('create').once()\
- .with_args('create_foo', os.path.join(os.getcwd(), 'migrations'), 'users', False).and_return('foo')
+ creator_mock.should_receive("create").once().with_args(
+ "create_foo", os.path.join(os.getcwd(), "migrations"), "users", False
+ ).and_return("foo")
command = flexmock(MigrateMakeCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
- self.run_command(command, [('name', 'create_foo'), ('--table', 'users')])
+ self.run_command(command, [("name", "create_foo"), ("--table", "users")])
- def test_basic_create_gives_creator_proper_arguments_when_table_is_set_with_create(self):
+ def test_basic_create_gives_creator_proper_arguments_when_table_is_set_with_create(
+ self
+ ):
creator_mock = flexmock(MigrationCreator)
- creator_mock.should_receive('create').once()\
- .with_args('create_foo', os.path.join(os.getcwd(), 'migrations'), 'users', True).and_return('foo')
+ creator_mock.should_receive("create").once().with_args(
+ "create_foo", os.path.join(os.getcwd(), "migrations"), "users", True
+ ).and_return("foo")
command = flexmock(MigrateMakeCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
- self.run_command(command, [('name', 'create_foo'), ('--table', 'users'), '--create'])
+ self.run_command(
+ command, [("name", "create_foo"), ("--table", "users"), "--create"]
+ )
diff --git a/tests/commands/migrations/test_migrate_command.py b/tests/commands/migrations/test_migrate_command.py
index a88cc1a3..4a69058a 100644
--- a/tests/commands/migrations/test_migrate_command.py
+++ b/tests/commands/migrations/test_migrate_command.py
@@ -1,9 +1,7 @@
# -*- coding: utf-8 -*-
import os
-from io import BytesIO
from flexmock import flexmock
-from cleo import Output
from orator.migrations import Migrator
from orator.commands.migrations import MigrateCommand
from orator import DatabaseManager
@@ -11,74 +9,97 @@
class MigrateCommandTestCase(OratorCommandTestCase):
-
def test_basic_migrations_call_migrator_with_proper_arguments(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
+ resolver.should_receive("connection").and_return(None)
migrator_mock = flexmock(Migrator)
- migrator_mock.should_receive('set_connection').once().with_args(None)
- migrator_mock.should_receive('run').once().with_args(os.path.join(os.getcwd(), 'migrations'), False)
- migrator_mock.should_receive('get_notes').and_return([])
- migrator_mock.should_receive('repository_exists').once().and_return(True)
+ migrator_mock.should_receive("set_connection").once().with_args(None)
+ migrator_mock.should_receive("run").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), False
+ )
+ migrator_mock.should_receive("get_notes").and_return([])
+ migrator_mock.should_receive("repository_exists").once().and_return(True)
command = flexmock(MigrateCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
- self.run_command(command, input_stream=self.get_input_stream('y\n'))
+ self.run_command(command)
def test_migration_repository_create_when_necessary(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
+ resolver.should_receive("connection").and_return(None)
migrator_mock = flexmock(Migrator)
- migrator_mock.should_receive('set_connection').once().with_args(None)
- migrator_mock.should_receive('run').once().with_args(os.path.join(os.getcwd(), 'migrations'), False)
- migrator_mock.should_receive('get_notes').and_return([])
- migrator_mock.should_receive('repository_exists').once().and_return(False)
+ migrator_mock.should_receive("set_connection").once().with_args(None)
+ migrator_mock.should_receive("run").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), False
+ )
+ migrator_mock.should_receive("get_notes").and_return([])
+ migrator_mock.should_receive("repository_exists").once().and_return(False)
command = flexmock(MigrateCommand())
- command.should_receive('_get_config').and_return({})
- command.should_receive('call').once()\
- .with_args('migrate:install', [('--config', None)])
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
+ command.should_receive("call").once().with_args(
+ "migrate:install", [("--config", None)]
+ )
- self.run_command(command, input_stream=self.get_input_stream('y\n'))
+ self.run_command(command)
def test_migration_can_be_pretended(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
+ resolver.should_receive("connection").and_return(None)
migrator_mock = flexmock(Migrator)
- migrator_mock.should_receive('set_connection').once().with_args(None)
- migrator_mock.should_receive('run').once().with_args(os.path.join(os.getcwd(), 'migrations'), True)
- migrator_mock.should_receive('get_notes').and_return([])
- migrator_mock.should_receive('repository_exists').once().and_return(True)
+ migrator_mock.should_receive("set_connection").once().with_args(None)
+ migrator_mock.should_receive("run").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), True
+ )
+ migrator_mock.should_receive("get_notes").and_return([])
+ migrator_mock.should_receive("repository_exists").once().and_return(True)
command = flexmock(MigrateCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
- self.run_command(command, [('--pretend', True)], input_stream=self.get_input_stream('y\n'))
+ self.run_command(command, [("--pretend", True)])
def test_migration_database_can_be_set(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
+ resolver.should_receive("connection").and_return(None)
migrator_mock = flexmock(Migrator)
- migrator_mock.should_receive('set_connection').once().with_args('foo')
- migrator_mock.should_receive('run').once().with_args(os.path.join(os.getcwd(), 'migrations'), False)
- migrator_mock.should_receive('get_notes').and_return([])
- migrator_mock.should_receive('repository_exists').once().and_return(False)
+ migrator_mock.should_receive("set_connection").once().with_args("foo")
+ migrator_mock.should_receive("run").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), False
+ )
+ migrator_mock.should_receive("get_notes").and_return([])
+ migrator_mock.should_receive("repository_exists").once().and_return(False)
command = flexmock(MigrateCommand())
- command.should_receive('_get_config').and_return({})
- command.should_receive('call').once()\
- .with_args('migrate:install', [('--database', 'foo'), ('--config', None)])
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
+ command.should_receive("call").once().with_args(
+ "migrate:install", [("--database", "foo"), ("--config", None)]
+ )
+
+ self.run_command(command, [("--database", "foo")])
+
+ def test_migration_can_be_forced(self):
+ resolver = flexmock(DatabaseManager)
+ resolver.should_receive("connection").and_return(None)
- self.run_command(command, [('--database', 'foo')], input_stream=self.get_input_stream('y\n'))
+ migrator_mock = flexmock(Migrator)
+ migrator_mock.should_receive("set_connection").once().with_args(None)
+ migrator_mock.should_receive("run").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), False
+ )
+ migrator_mock.should_receive("get_notes").and_return([])
+ migrator_mock.should_receive("repository_exists").once().and_return(True)
- def get_input_stream(self, input_):
- stream = BytesIO()
- stream.write(input_.encode())
- stream.seek(0)
+ command = flexmock(MigrateCommand())
+ command.should_receive("_get_config").and_return({})
- return stream
+ self.run_command(command, [("--force", True)])
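The new forced-run test above deliberately does not stub confirm(): with --force the command is expected to skip the confirmation prompt altogether, so the test relies on no prompt being raised during execution. A hedged sketch of the guard this implies (hypothetical helper, not the actual command source):

    # Hypothetical guard for destructive commands (illustration only):
    # --force bypasses the prompt, otherwise the user must confirm.
    def should_proceed(command):
        if command.option("force"):
            return True

        return command.confirm("Are you sure you want to run this command?")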
diff --git a/tests/commands/migrations/test_refresh_command.py b/tests/commands/migrations/test_refresh_command.py
new file mode 100644
index 00000000..657141e8
--- /dev/null
+++ b/tests/commands/migrations/test_refresh_command.py
@@ -0,0 +1,39 @@
+# -*- coding: utf-8 -*-
+
+import os
+from flexmock import flexmock
+from orator.migrations import Migrator
+from orator.commands.migrations import RefreshCommand
+from orator import DatabaseManager
+from .. import OratorCommandTestCase
+
+
+class RefreshCommandTestCase(OratorCommandTestCase):
+ def test_refresh_runs_the_seeder_when_seed_option_set(self):
+ resolver = flexmock(DatabaseManager)
+ resolver.should_receive("connection").and_return(None)
+
+ command = flexmock(RefreshCommand())
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
+ command.should_receive("call").with_args("migrate:reset", object).and_return(
+ True
+ )
+ command.should_receive("call").with_args("migrate", object).and_return(True)
+ command.should_receive("_run_seeder")
+
+ self.run_command(command, [("--seed")])
+
+ def test_refresh_does_not_run_the_seeder_when_seed_option_absent(self):
+ resolver = flexmock(DatabaseManager)
+ resolver.should_receive("connection").and_return(None)
+
+ command = flexmock(RefreshCommand())
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
+ command.should_receive("call").with_args("migrate:reset", object).and_return(
+ True
+ )
+ command.should_receive("call").with_args("migrate", object).and_return(True)
+
+ self.run_command(command, [])
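One small detail in the seeding test above: [("--seed")] is not a one-element tuple, the parentheses are redundant, so the call effectively passes ["--seed"]. That matches the harness convention of passing value-less flags as bare strings (compare the --create flag in the make-command test earlier); a tuple would need a trailing comma:

    # ("--seed") == "--seed"; a one-element tuple would be ("--seed",)
    self.run_command(command, ["--seed"])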
diff --git a/tests/commands/migrations/test_reset_command.py b/tests/commands/migrations/test_reset_command.py
index 67cc5216..32581772 100644
--- a/tests/commands/migrations/test_reset_command.py
+++ b/tests/commands/migrations/test_reset_command.py
@@ -1,68 +1,77 @@
# -*- coding: utf-8 -*-
import os
-from io import BytesIO
from flexmock import flexmock
from orator.migrations import Migrator
from orator.commands.migrations import ResetCommand
from orator import DatabaseManager
-from orator.connections import Connection
from .. import OratorCommandTestCase
class ResetCommandTestCase(OratorCommandTestCase):
-
def test_basic_migrations_call_migrator_with_proper_arguments(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
+ resolver.should_receive("connection").and_return(None)
migrator_mock = flexmock(Migrator)
- migrator_mock.should_receive('set_connection').once().with_args(None)
- migrator_mock.should_receive('reset').once()\
- .with_args(os.path.join(os.getcwd(), 'migrations'), False)\
- .and_return(2)
- migrator_mock.should_receive('get_notes').and_return([])
+ migrator_mock.should_receive("set_connection").once().with_args(None)
+ migrator_mock.should_receive("reset").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), False
+ ).and_return(2)
+ migrator_mock.should_receive("get_notes").and_return([])
command = flexmock(ResetCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
- self.run_command(command, input_stream=self.get_input_stream('y\n'))
+ self.run_command(command)
def test_migration_can_be_pretended(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
+ resolver.should_receive("connection").and_return(None)
migrator_mock = flexmock(Migrator)
- migrator_mock.should_receive('set_connection').once().with_args(None)
- migrator_mock.should_receive('reset').once()\
- .with_args(os.path.join(os.getcwd(), 'migrations'), True)\
- .and_return(2)
- migrator_mock.should_receive('get_notes').and_return([])
+ migrator_mock.should_receive("set_connection").once().with_args(None)
+ migrator_mock.should_receive("reset").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), True
+ ).and_return(2)
+ migrator_mock.should_receive("get_notes").and_return([])
command = flexmock(ResetCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
- self.run_command(command, [('--pretend', True)], input_stream=self.get_input_stream('y\n'))
+ self.run_command(command, [("--pretend", True)])
def test_migration_database_can_be_set(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
+ resolver.should_receive("connection").and_return(None)
migrator_mock = flexmock(Migrator)
- migrator_mock.should_receive('set_connection').once().with_args('foo')
- migrator_mock.should_receive('reset').once()\
- .with_args(os.path.join(os.getcwd(), 'migrations'), False)\
- .and_return(2)
- migrator_mock.should_receive('get_notes').and_return([])
+ migrator_mock.should_receive("set_connection").once().with_args("foo")
+ migrator_mock.should_receive("reset").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), False
+ ).and_return(2)
+ migrator_mock.should_receive("get_notes").and_return([])
command = flexmock(ResetCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
+
+ self.run_command(command, [("--database", "foo")])
+
+ def test_migration_can_be_forced(self):
+ resolver = flexmock(DatabaseManager)
+ resolver.should_receive("connection").and_return(None)
- self.run_command(command, [('--database', 'foo')], input_stream=self.get_input_stream('y\n'))
+ migrator_mock = flexmock(Migrator)
+ migrator_mock.should_receive("set_connection").once().with_args(None)
+ migrator_mock.should_receive("reset").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), False
+ ).and_return(2)
+ migrator_mock.should_receive("get_notes").and_return([])
- def get_input_stream(self, input_):
- stream = BytesIO()
- stream.write(input_.encode())
- stream.seek(0)
+ command = flexmock(ResetCommand())
+ command.should_receive("_get_config").and_return({})
- return stream
+ self.run_command(command, [("--force", True)])
diff --git a/tests/commands/migrations/test_rollback_command.py b/tests/commands/migrations/test_rollback_command.py
index 39973f26..9de77bba 100644
--- a/tests/commands/migrations/test_rollback_command.py
+++ b/tests/commands/migrations/test_rollback_command.py
@@ -1,9 +1,7 @@
# -*- coding: utf-8 -*-
import os
-from io import BytesIO
from flexmock import flexmock
-from cleo import Output
from orator.migrations import Migrator
from orator.commands.migrations import RollbackCommand
from orator import DatabaseManager
@@ -11,52 +9,69 @@
class RollbackCommandTestCase(OratorCommandTestCase):
-
def test_basic_migrations_call_migrator_with_proper_arguments(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
+ resolver.should_receive("connection").and_return(None)
migrator_mock = flexmock(Migrator)
- migrator_mock.should_receive('set_connection').once().with_args(None)
- migrator_mock.should_receive('rollback').once().with_args(os.path.join(os.getcwd(), 'migrations'), False)
- migrator_mock.should_receive('get_notes').and_return([])
+ migrator_mock.should_receive("set_connection").once().with_args(None)
+ migrator_mock.should_receive("rollback").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), False
+ )
+ migrator_mock.should_receive("get_notes").and_return([])
command = flexmock(RollbackCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
- self.run_command(command, input_stream=self.get_input_stream('y\n'))
+ self.run_command(command)
def test_migration_can_be_pretended(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
+ resolver.should_receive("connection").and_return(None)
migrator_mock = flexmock(Migrator)
- migrator_mock.should_receive('set_connection').once().with_args(None)
- migrator_mock.should_receive('rollback').once().with_args(os.path.join(os.getcwd(), 'migrations'), True)
- migrator_mock.should_receive('get_notes').and_return([])
+ migrator_mock.should_receive("set_connection").once().with_args(None)
+ migrator_mock.should_receive("rollback").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), True
+ )
+ migrator_mock.should_receive("get_notes").and_return([])
command = flexmock(RollbackCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
- self.run_command(command, [('--pretend', True)], input_stream=self.get_input_stream('y\n'))
+ self.run_command(command, [("--pretend", True)])
def test_migration_database_can_be_set(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
+ resolver.should_receive("connection").and_return(None)
migrator_mock = flexmock(Migrator)
- migrator_mock.should_receive('set_connection').once().with_args('foo')
- migrator_mock.should_receive('rollback').once().with_args(os.path.join(os.getcwd(), 'migrations'), False)
- migrator_mock.should_receive('get_notes').and_return([])
+ migrator_mock.should_receive("set_connection").once().with_args("foo")
+ migrator_mock.should_receive("rollback").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), False
+ )
+ migrator_mock.should_receive("get_notes").and_return([])
command = flexmock(RollbackCommand())
- command.should_receive('_get_config').and_return({})
+ command.should_receive("_get_config").and_return({})
+ command.should_receive("confirm").and_return(True)
+
+ self.run_command(command, [("--database", "foo")])
+
+ def test_migration_can_be_forced(self):
+ resolver = flexmock(DatabaseManager)
+ resolver.should_receive("connection").and_return(None)
- self.run_command(command, [('--database', 'foo')], input_stream=self.get_input_stream('y\n'))
+ migrator_mock = flexmock(Migrator)
+ migrator_mock.should_receive("set_connection").once().with_args(None)
+ migrator_mock.should_receive("rollback").once().with_args(
+ os.path.join(os.getcwd(), "migrations"), False
+ )
+ migrator_mock.should_receive("get_notes").and_return([])
- def get_input_stream(self, input_):
- stream = BytesIO()
- stream.write(input_.encode())
- stream.seek(0)
+ command = flexmock(RollbackCommand())
+ command.should_receive("_get_config").and_return({})
- return stream
+ self.run_command(command, [("--force", True)])
diff --git a/tests/connections/__init__.py b/tests/connections/__init__.py
index 633f8661..40a96afc 100644
--- a/tests/connections/__init__.py
+++ b/tests/connections/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/tests/connections/test_connection.py b/tests/connections/test_connection.py
index 557a74fd..870ef8fe 100644
--- a/tests/connections/test_connection.py
+++ b/tests/connections/test_connection.py
@@ -1,6 +1,7 @@
# -*- coding: utf-8 -*-
import threading
+
try:
from Queue import Queue
except ImportError:
@@ -17,24 +18,23 @@
class ConnectionTestCase(OratorTestCase):
-
def test_table_returns_query_builder(self):
- connection = Connection(None, 'database')
- builder = connection.table('users')
+ connection = Connection(None, "database")
+ builder = connection.table("users")
self.assertIsInstance(builder, QueryBuilder)
- self.assertEqual('users', builder.from__)
+ self.assertEqual("users", builder.from__)
self.assertEqual(connection.get_query_grammar(), builder.get_grammar())
def test_transaction(self):
- connection = Connection(None, 'database')
+ connection = Connection(None, "database")
connection.begin_transaction = mock.MagicMock(unsafe=True)
connection.commit = mock.MagicMock(unsafe=True)
connection.rollback = mock.MagicMock(unsafe=True)
connection.insert = mock.MagicMock(return_value=1)
with connection.transaction():
- connection.table('users').insert({'name': 'foo'})
+ connection.table("users").insert({"name": "foo"})
connection.begin_transaction.assert_called_once()
connection.commit.assert_called_once()
@@ -46,29 +46,33 @@ def test_transaction(self):
try:
with connection.transaction():
- connection.table('users').insert({'name': 'foo'})
- raise Exception('foo')
+ connection.table("users").insert({"name": "foo"})
+ raise Exception("foo")
except Exception as e:
- self.assertEqual('foo', str(e))
+ self.assertEqual("foo", str(e))
connection.begin_transaction.assert_called_once()
connection.rollback.assert_called_once()
self.assertFalse(connection.commit.called)
def test_try_again_if_caused_by_lost_connection_is_called(self):
- connection = flexmock(Connection(None, 'database'))
+ connection = flexmock(Connection(None, "database"))
cursor = flexmock()
- connection.should_receive('_try_again_if_caused_by_lost_connection').once()
- connection.should_receive('_get_cursor_for_select').and_return(cursor)
- connection.should_receive('reconnect')
- cursor.should_receive('execute').and_raise(Exception('error'))
+ connection.should_receive("_try_again_if_caused_by_lost_connection").once()
+ connection.should_receive("_get_cursor_for_select").and_return(cursor)
+ connection.should_receive("reconnect")
+ cursor.should_receive("execute").and_raise(Exception("error"))
connection.select('SELECT * FROM "users"')
+ def test_lost_connection_returns_true_with_capitalized_error(self):
+ connection = Connection(None, "database")
+ self.assertTrue(connection._caused_by_lost_connection("Lost Connection"))
+
def test_prefix_set_to_none(self):
- connection = Connection(None, 'database', None)
+ connection = Connection(None, "database", None)
self.assertIsNotNone(connection.get_table_prefix())
- self.assertEqual('', connection.get_table_prefix())
+ self.assertEqual("", connection.get_table_prefix())
class ConnectionThreadLocalTest(OratorTestCase):
@@ -80,13 +84,15 @@ def test_create_thread_local(self):
def create_user_thread(low, hi):
for _ in range(low, hi):
- User.create(name='u%d' % i)
+ User.create(name="u%d" % i)
User.get_connection_resolver().disconnect()
threads = []
for i in range(self.threads):
- threads.append(threading.Thread(target=create_user_thread, args=(i*10, i * 10 + 10)))
+ threads.append(
+ threading.Thread(target=create_user_thread, args=(i * 10, i * 10 + 10))
+ )
[t.start() for t in threads]
[t.join() for t in threads]
@@ -105,7 +111,9 @@ def reader_thread(q, num):
threads = []
for i in range(self.threads):
- threads.append(threading.Thread(target=reader_thread, args=(data_queue, 20)))
+ threads.append(
+ threading.Thread(target=reader_thread, args=(data_queue, 20))
+ )
[t.start() for t in threads]
[t.join() for t in threads]
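The new test_lost_connection_returns_true_with_capitalized_error case added above asserts that the capitalized message "Lost Connection" is recognized, which suggests the lost-connection detection compares error messages case-insensitively against known fragments. A hedged sketch of such a check (the fragment list is an assumption, not the actual Connection source):

    # Hypothetical case-insensitive lost-connection check (illustration only).
    LOST_CONNECTION_FRAGMENTS = [
        "server has gone away",
        "no connection to the server",
        "lost connection",
    ]

    def caused_by_lost_connection(message):
        message = message.lower()

        return any(fragment in message for fragment in LOST_CONNECTION_FRAGMENTS)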
diff --git a/tests/connections/test_mysql_connection.py b/tests/connections/test_mysql_connection.py
index c17566e2..2c9a2f9f 100644
--- a/tests/connections/test_mysql_connection.py
+++ b/tests/connections/test_mysql_connection.py
@@ -6,18 +6,17 @@
class MySQLConnectionTestCase(OratorTestCase):
-
def test_marker_is_properly_set(self):
- connection = MySQLConnection(None, 'database', '', {'use_qmark': True})
+ connection = MySQLConnection(None, "database", "", {"use_qmark": True})
- self.assertEqual('?', connection.get_marker())
+ self.assertEqual("?", connection.get_marker())
def test_marker_default(self):
- connection = MySQLConnection(None, 'database', '', {})
+ connection = MySQLConnection(None, "database", "", {})
self.assertIsNone(connection.get_marker())
def test_marker_use_qmark_false(self):
- connection = MySQLConnection(None, 'database', '', {'use_qmark': False})
+ connection = MySQLConnection(None, "database", "", {"use_qmark": False})
self.assertIsNone(connection.get_marker())
diff --git a/tests/connections/test_postgres_connection.py b/tests/connections/test_postgres_connection.py
index 5225223a..48d62239 100644
--- a/tests/connections/test_postgres_connection.py
+++ b/tests/connections/test_postgres_connection.py
@@ -6,18 +6,17 @@
class PostgresConnectionTestCase(OratorTestCase):
-
def test_marker_is_properly_set(self):
- connection = PostgresConnection(None, 'database', '', {'use_qmark': True})
+ connection = PostgresConnection(None, "database", "", {"use_qmark": True})
- self.assertEqual('?', connection.get_marker())
+ self.assertEqual("?", connection.get_marker())
def test_marker_default(self):
- connection = PostgresConnection(None, 'database', '', {})
+ connection = PostgresConnection(None, "database", "", {})
self.assertIsNone(connection.get_marker())
def test_marker_use_qmark_false(self):
- connection = PostgresConnection(None, 'database', '', {'use_qmark': False})
+ connection = PostgresConnection(None, "database", "", {"use_qmark": False})
self.assertIsNone(connection.get_marker())
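Both connection test modules above exercise the use_qmark option: when it is truthy, get_marker() returns "?" so queries are rendered with qmark-style placeholders; when it is absent or false, get_marker() returns None and the driver's default marker (typically %s for the MySQL and PostgreSQL drivers) is used. Usage mirrors the tests:

    # Configuration key as shown in the tests above.
    connection = PostgresConnection(None, "database", "", {"use_qmark": True})
    assert connection.get_marker() == "?"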
diff --git a/tests/integrations/__init__.py b/tests/integrations/__init__.py
index 8a09c847..a59a7f33 100644
--- a/tests/integrations/__init__.py
+++ b/tests/integrations/__init__.py
@@ -1,18 +1,55 @@
# -*- coding: utf-8 -*-
+import os
+import json
+import logging
import pendulum
import simplejson as json
-from datetime import datetime, timedelta
+
+from datetime import datetime, timedelta, date
from backpack import collect
from orator import Model, Collection, DatabaseManager
-from orator.orm import morph_to, has_one, has_many, belongs_to_many, morph_many, belongs_to, scope, accessor
+from orator.orm import (
+ morph_to,
+ has_one,
+ has_many,
+ belongs_to_many,
+ morph_many,
+ belongs_to,
+ scope,
+ accessor,
+)
from orator.orm.relations import BelongsToMany
from orator.exceptions.orm import ModelNotFound
-from orator.exceptions.query import QueryException
-class IntegrationTestCase(object):
+logger = logging.getLogger("orator.connection.queries")
+logger.setLevel(logging.DEBUG)
+
+
+class LoggedQueriesFormatter(logging.Formatter):
+ def __init__(self, fmt=None, datefmt=None, style="%"):
+ super(LoggedQueriesFormatter, self).__init__()
+
+ self.logged_queries = []
+
+ def format(self, record):
+ self.logged_queries.append(record.query)
+
+ return super(LoggedQueriesFormatter, self).format(record)
+
+ def reset(self):
+ self.logged_queries = []
+
+
+formatter = LoggedQueriesFormatter()
+handler = logging.StreamHandler(open(os.devnull, "w"))
+
+handler.setFormatter(formatter)
+logger.addHandler(handler)
+
+class IntegrationTestCase(object):
@classmethod
def setUpClass(cls):
Model.set_connection_resolver(cls.get_connection_resolver())
@@ -26,12 +63,12 @@ def get_connection_resolver(cls):
# Adding another connection to test connection switching
config = cls.get_manager_config()
- config['test'] = {
- 'driver': 'sqlite',
- 'database': ':memory:'
- }
+ config["test"] = {"driver": "sqlite", "database": ":memory:"}
+
+ db = DatabaseManager(config)
+ db.connection().enable_query_log()
- return DatabaseManager(config)
+ return db
@classmethod
def tearDownClass(cls):
@@ -43,134 +80,174 @@ def marker(self):
def setUp(self):
self.migrate()
- self.migrate('test')
+ self.migrate("test")
+
+ formatter.reset()
def tearDown(self):
self.revert()
- self.revert('test')
+ self.revert("test")
def test_basic_model_retrieval(self):
- OratorTestUser.create(email='john@doe.com')
- model = OratorTestUser.where('email', 'john@doe.com').first()
- self.assertEqual('john@doe.com', model.email)
+ OratorTestUser.create(email="john@doe.com")
+ model = OratorTestUser.where("email", "john@doe.com").first()
+ self.assertEqual("john@doe.com", model.email)
def test_basic_model_collection_retrieval(self):
- OratorTestUser.create(id=1, email='john@doe.com')
- OratorTestUser.create(id=2, email='jane@doe.com')
+ OratorTestUser.create(id=1, email="john@doe.com")
+ OratorTestUser.create(id=2, email="jane@doe.com")
- models = OratorTestUser.oldest('id').get()
+ models = OratorTestUser.oldest("id").get()
self.assertEqual(2, len(models))
self.assertIsInstance(models, Collection)
self.assertIsInstance(models[0], OratorTestUser)
self.assertIsInstance(models[1], OratorTestUser)
- self.assertEqual('john@doe.com', models[0].email)
- self.assertEqual('jane@doe.com', models[1].email)
+ self.assertEqual("john@doe.com", models[0].email)
+ self.assertEqual("jane@doe.com", models[1].email)
def test_lists_retrieval(self):
- OratorTestUser.create(id=1, email='john@doe.com')
- OratorTestUser.create(id=2, email='jane@doe.com')
+ OratorTestUser.create(id=1, email="john@doe.com")
+ OratorTestUser.create(id=2, email="jane@doe.com")
- simple = OratorTestUser.oldest('id').lists('email')
- keyed = OratorTestUser.oldest('id').lists('email', 'id')
+ simple = OratorTestUser.oldest("id").lists("email")
+ keyed = OratorTestUser.oldest("id").lists("email", "id")
- self.assertEqual(['john@doe.com', 'jane@doe.com'], simple)
- self.assertEqual({1: 'john@doe.com', 2: 'jane@doe.com'}, keyed)
+ self.assertEqual(["john@doe.com", "jane@doe.com"], simple)
+ self.assertEqual({1: "john@doe.com", 2: "jane@doe.com"}, keyed)
def test_find_or_fail(self):
- OratorTestUser.create(id=1, email='john@doe.com')
- OratorTestUser.create(id=2, email='jane@doe.com')
+ OratorTestUser.create(id=1, email="john@doe.com")
+ OratorTestUser.create(id=2, email="jane@doe.com")
single = OratorTestUser.find_or_fail(1)
multiple = OratorTestUser.find_or_fail([1, 2])
self.assertIsInstance(single, OratorTestUser)
- self.assertEqual('john@doe.com', single.email)
+ self.assertEqual("john@doe.com", single.email)
self.assertIsInstance(multiple, Collection)
self.assertIsInstance(multiple[0], OratorTestUser)
self.assertIsInstance(multiple[1], OratorTestUser)
def test_find_or_fail_with_single_id_raises_model_not_found_exception(self):
- self.assertRaises(
- ModelNotFound,
- OratorTestUser.find_or_fail,
- 1
- )
+ self.assertRaises(ModelNotFound, OratorTestUser.find_or_fail, 1)
def test_find_or_fail_with_multiple_ids_raises_model_not_found_exception(self):
- self.assertRaises(
- ModelNotFound,
- OratorTestUser.find_or_fail,
- [1, 2]
- )
+ self.assertRaises(ModelNotFound, OratorTestUser.find_or_fail, [1, 2])
def test_one_to_one_relationship(self):
- user = OratorTestUser.create(email='john@doe.com')
- user.post().create(name='First Post')
+ user = OratorTestUser.create(email="john@doe.com")
+ user.post().create(name="First Post")
post = user.post
user = post.user
- self.assertEqual('john@doe.com', user.email)
- self.assertEqual('First Post', post.name)
+ self.assertEqual("john@doe.com", user.email)
+ self.assertEqual("First Post", post.name)
def test_one_to_many_relationship(self):
- user = OratorTestUser.create(email='john@doe.com')
- user.posts().create(name='First Post')
- user.posts().create(name='Second Post')
+ user = OratorTestUser.create(email="john@doe.com")
+ user.posts().create(name="First Post")
+ user.posts().create(name="Second Post")
posts = user.posts
- post2 = user.posts().where('name', 'Second Post').first()
+ post2 = user.posts().where("name", "Second Post").first()
self.assertEqual(2, len(posts))
self.assertIsInstance(posts[0], OratorTestPost)
self.assertIsInstance(posts[1], OratorTestPost)
self.assertIsInstance(post2, OratorTestPost)
- self.assertEqual('Second Post', post2.name)
+ self.assertEqual("Second Post", post2.name)
self.assertIsInstance(post2.user, OratorTestUser)
- self.assertEqual('john@doe.com', post2.user.email)
+ self.assertEqual("john@doe.com", post2.user.email)
def test_basic_model_hydrate(self):
- OratorTestUser.create(id=1, email='john@doe.com')
- OratorTestUser.create(id=2, email='jane@doe.com')
+ OratorTestUser.create(id=1, email="john@doe.com")
+ OratorTestUser.create(id=2, email="jane@doe.com")
models = OratorTestUser.hydrate_raw(
- 'SELECT * FROM test_users WHERE email = %s' % self.marker,
- ['jane@doe.com'],
- self.connection().get_name()
+ "SELECT * FROM test_users WHERE email = %s" % self.marker,
+ ["jane@doe.com"],
+ self.connection().get_name(),
)
self.assertIsInstance(models, Collection)
self.assertIsInstance(models[0], OratorTestUser)
- self.assertEqual('jane@doe.com', models[0].email)
+ self.assertEqual("jane@doe.com", models[0].email)
self.assertEqual(self.connection().get_name(), models[0].get_connection_name())
self.assertEqual(1, len(models))
def test_has_on_self_referencing_belongs_to_many_relationship(self):
- user = OratorTestUser.create(email='john@doe.com')
- friend = user.friends().create(email='jane@doe.com')
+ user = OratorTestUser.create(email="john@doe.com")
+ friend = user.friends().create(email="jane@doe.com")
- results = OratorTestUser.has('friends').get()
+ results = OratorTestUser.has("friends").get()
self.assertEqual(1, len(results))
- self.assertEqual('john@doe.com', results.first().email)
+ self.assertEqual("john@doe.com", results.first().email)
def test_basic_has_many_eager_loading(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- user.posts().create(name='First Post')
- user = OratorTestUser.with_('posts').where('email', 'john@doe.com').first()
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ post = user.posts().create(name="First Post")
+ comment = post.comments().create(body="Text")
+ comment2 = post.comments().create(body="Text 2")
+ comment.children().save(comment2)
+ user = (
+ OratorTestUser.with_("posts.comments.children.parent")
+ .where("email", "john@doe.com")
+ .first()
+ )
+
+ self.assertEqual("First Post", user.posts.first().name)
+ self.assertEqual("Text", user.posts.first().comments.first().body)
+ self.assertEqual(
+ "Text 2", user.posts.first().comments.first().children.first().body
+ )
+ self.assertEqual(
+ "Text", user.posts.first().comments.first().children.first().parent.body
+ )
- self.assertEqual('First Post', user.posts.first().name)
+ queries = formatter.logged_queries
+ self.assertEqual(10, len(queries))
- post = OratorTestPost.with_('user').where('name', 'First Post').get()
- self.assertEqual('john@doe.com', post.first().user.email)
+ formatter.reset()
+
+ post = OratorTestPost.with_("user").where("name", "First Post").get()
+ self.assertEqual("john@doe.com", post.first().user.email)
+
+ comment = (
+ OratorTestComment.with_("parent.post.user").where("body", "Text 2").first()
+ )
+ self.assertEqual("Text", comment.parent.body)
+ self.assertEqual("First Post", comment.parent.post.name)
+ self.assertEqual("john@doe.com", comment.parent.post.user.email)
+
+ queries = formatter.logged_queries
+ self.assertEqual(6, len(queries))
+
+ def test_all_eager_loaded_transitive_relations_must_be_present(self):
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ post = user.posts().create(name="First Post")
+ comment = post.comments().create(body="Text")
+ comment2 = post.comments().create(body="Text 2")
+ comment.children().save(comment2)
+ post = (
+ OratorTestPost.with_("user", "user.posts", "user.post")
+ .where("id", post.id)
+ .first()
+ )
+
+ data = post.serialize()
+ assert "user" in data
+ assert "posts" in data["user"]
+ assert "post" in data["user"]
def test_basic_morph_many_relationship(self):
- user = OratorTestUser.create(email='john@doe.com')
- user.photos().create(name='Avatar 1')
- user.photos().create(name='Avatar 2')
- post = user.posts().create(name='First Post')
- post.photos().create(name='Hero 1')
- post.photos().create(name='Hero 2')
+ user = OratorTestUser.create(email="john@doe.com")
+ user.photos().create(name="Avatar 1")
+ user.photos().create(name="Avatar 2")
+ post = user.posts().create(name="First Post")
+ post.photos().create(name="Hero 1")
+ post.photos().create(name="Hero 2")
self.assertIsInstance(user.photos, Collection)
self.assertIsInstance(user.photos[0], OratorTestPhoto)
@@ -179,157 +256,217 @@ def test_basic_morph_many_relationship(self):
self.assertIsInstance(post.photos[0], OratorTestPhoto)
self.assertEqual(2, len(user.photos))
self.assertEqual(2, len(post.photos))
- self.assertEqual('Avatar 1', user.photos[0].name)
- self.assertEqual('Avatar 2', user.photos[1].name)
- self.assertEqual('Hero 1', post.photos[0].name)
- self.assertEqual('Hero 2', post.photos[1].name)
+ self.assertEqual("Avatar 1", user.photos[0].name)
+ self.assertEqual("Avatar 2", user.photos[1].name)
+ self.assertEqual("Hero 1", post.photos[0].name)
+ self.assertEqual("Hero 2", post.photos[1].name)
- photos = OratorTestPhoto.order_by('name').get()
+ photos = OratorTestPhoto.order_by("name").get()
self.assertIsInstance(photos, Collection)
self.assertEqual(4, len(photos))
self.assertIsInstance(photos[0].imageable, OratorTestUser)
self.assertIsInstance(photos[2].imageable, OratorTestPost)
- self.assertEqual('john@doe.com', photos[1].imageable.email)
- self.assertEqual('First Post', photos[3].imageable.name)
+ self.assertEqual("john@doe.com", photos[1].imageable.email)
+ self.assertEqual("First Post", photos[3].imageable.name)
def test_multi_insert_with_different_values(self):
date = pendulum.utcnow()._datetime
- result = OratorTestPost.insert([
- {
- 'user_id': 1, 'name': 'Post', 'created_at': date, 'updated_at': date
- }, {
- 'user_id': 2, 'name': 'Post', 'created_at': date, 'updated_at': date
- }
- ])
+ user1 = OratorTestUser.create(email="john@doe.com")
+ user2 = OratorTestUser.create(email="jane@doe.com")
+ result = OratorTestPost.insert(
+ [
+ {
+ "user_id": user1.id,
+ "name": "Post",
+ "created_at": date,
+ "updated_at": date,
+ },
+ {
+ "user_id": user2.id,
+ "name": "Post",
+ "created_at": date,
+ "updated_at": date,
+ },
+ ]
+ )
self.assertTrue(result)
self.assertEqual(2, OratorTestPost.count())
def test_multi_insert_with_same_values(self):
date = pendulum.utcnow()._datetime
- result = OratorTestPost.insert([
- {
- 'user_id': 1, 'name': 'Post', 'created_at': date, 'updated_at': date
- }, {
- 'user_id': 1, 'name': 'Post', 'created_at': date, 'updated_at': date
- }
- ])
+ user1 = OratorTestUser.create(email="john@doe.com")
+ result = OratorTestPost.insert(
+ [
+ {
+ "user_id": user1.id,
+ "name": "Post",
+ "created_at": date,
+ "updated_at": date,
+ },
+ {
+ "user_id": user1.id,
+ "name": "Post",
+ "created_at": date,
+ "updated_at": date,
+ },
+ ]
+ )
self.assertTrue(result)
self.assertEqual(2, OratorTestPost.count())
def test_belongs_to_many_further_query(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- friend = OratorTestUser.create(id=2, email='jane@doe.com')
- another_friend = OratorTestUser.create(id=3, email='another@doe.com')
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ friend = OratorTestUser.create(id=2, email="jane@doe.com")
+ another_friend = OratorTestUser.create(id=3, email="another@doe.com")
user.friends().attach(friend)
user.friends().attach(another_friend)
- related_friend = OratorTestUser.with_('friends').find(1).friends().where('test_users.id', 3).first()
+ related_friend = (
+ OratorTestUser.with_("friends")
+ .find(1)
+ .friends()
+ .where("test_users.id", 3)
+ .first()
+ )
self.assertEqual(3, related_friend.id)
- self.assertEqual('another@doe.com', related_friend.email)
- self.assertIn('pivot', related_friend.to_dict())
+ self.assertEqual("another@doe.com", related_friend.email)
+ self.assertIn("pivot", related_friend.to_dict())
self.assertEqual(1, related_friend.pivot.user_id)
self.assertEqual(3, related_friend.pivot.friend_id)
- self.assertTrue(hasattr(related_friend.pivot, 'id'))
+ self.assertTrue(hasattr(related_friend.pivot, "is_close"))
+
+ self.assertIsInstance(user.friends().with_pivot("is_close"), BelongsToMany)
- self.assertIsInstance(user.friends().with_pivot('id'), BelongsToMany)
+ self.assertEqual(2, user.friends().get().count())
+ user.friends().sync([friend.id])
+ self.assertEqual(1, user.friends().get().count())
def test_belongs_to_morph_many_eagerload(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- user.photos().create(name='Avatar 1')
- user.photos().create(name='Avatar 2')
- post = user.posts().create(name='First Post')
- post.photos().create(name='Hero 1')
- post.photos().create(name='Hero 2')
-
- posts = OratorTestPost.with_('user', 'photos').get()
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ user.photos().create(name="Avatar 1")
+ user.photos().create(name="Avatar 2")
+ post = user.posts().create(name="First Post")
+ post.photos().create(name="Hero 1")
+ post.photos().create(name="Hero 2")
+
+ posts = OratorTestPost.with_("user", "photos").get()
self.assertIsInstance(posts[0].user, OratorTestUser)
self.assertEqual(user.id, posts[0].user().first().id)
self.assertIsInstance(posts[0].photos, Collection)
- self.assertEqual(posts[0].photos().where('name', 'Hero 2').first().name, 'Hero 2')
+ self.assertEqual(
+ posts[0].photos().where("name", "Hero 2").first().name, "Hero 2"
+ )
+
+ def test_belongs_to_associate(self):
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ post = OratorTestPost(name="Test Post")
+
+ post.user().associate(user)
+ post.save()
+
+ self.assertEqual(1, post.user.id)
+
+ def test_belongs_to_associate_new_instances(self):
+ user = OratorTestUser.create(email="john@doe.com")
+ post = user.posts().create(name="First Post")
+ comment1 = OratorTestComment.create(body="test1", post_id=post.id)
+
+ self.assertEqual(comment1.parent, None)
+
+ comment2 = OratorTestComment.create(body="test2", post_id=post.id)
+ comment2.parent().associate(comment1)
+
+ self.assertEqual(comment2.parent.id, comment1.id)
def test_has_many_eagerload(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- post1 = user.posts().create(name='First Post')
- post2 = user.posts().create(name='Second Post')
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ post1 = user.posts().create(name="First Post")
+ post2 = user.posts().create(name="Second Post")
- user = OratorTestUser.with_('posts').first()
+ user = OratorTestUser.with_("posts").first()
self.assertIsInstance(user.posts, Collection)
- self.assertEqual(user.posts().where('name', 'Second Post').first().id, post2.id)
+ self.assertEqual(user.posts().where("name", "Second Post").first().id, post2.id)
def test_relationships_properties_accept_builder(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- post1 = user.posts().create(name='First Post')
- post2 = user.posts().create(name='Second Post')
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ post1 = user.posts().create(name="First Post")
+ post2 = user.posts().create(name="Second Post")
- user = OratorTestUser.with_('posts').first()
- columns = ', '.join(self.connection().get_query_grammar().wrap_list(['id', 'name', 'user_id']))
+ user = OratorTestUser.with_("posts").first()
+ columns = ", ".join(
+ self.connection().get_query_grammar().wrap_list(["id", "name", "user_id"])
+ )
self.assertEqual(
- 'SELECT %(columns)s FROM %(table)s WHERE %(table)s.%(user_id)s = %(marker)s ORDER BY %(name)s DESC'
+ "SELECT %(columns)s FROM %(table)s WHERE %(table)s.%(user_id)s = %(marker)s ORDER BY %(name)s DESC"
% {
- 'columns': columns,
- 'marker': self.marker,
- 'table': self.grammar().wrap('test_posts'),
- 'user_id': self.grammar().wrap('user_id'),
- 'name': self.grammar().wrap('name')
+ "columns": columns,
+ "marker": self.marker,
+ "table": self.grammar().wrap("test_posts"),
+ "user_id": self.grammar().wrap("user_id"),
+ "name": self.grammar().wrap("name"),
},
- user.post().to_sql()
+ user.post().to_sql(),
)
user = OratorTestUser.first()
self.assertEqual(
- 'SELECT %(columns)s FROM %(table)s WHERE %(table)s.%(user_id)s = %(marker)s ORDER BY %(name)s DESC'
+ "SELECT %(columns)s FROM %(table)s WHERE %(table)s.%(user_id)s = %(marker)s ORDER BY %(name)s DESC"
% {
- 'columns': columns,
- 'marker': self.marker,
- 'table': self.grammar().wrap('test_posts'),
- 'user_id': self.grammar().wrap('user_id'),
- 'name': self.grammar().wrap('name')
+ "columns": columns,
+ "marker": self.marker,
+ "table": self.grammar().wrap("test_posts"),
+ "user_id": self.grammar().wrap("user_id"),
+ "name": self.grammar().wrap("name"),
},
- user.post().to_sql()
+ user.post().to_sql(),
)
def test_morph_to_eagerload(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- user.photos().create(name='Avatar 1')
- user.photos().create(name='Avatar 2')
- post = user.posts().create(name='First Post')
- post.photos().create(name='Hero 1')
- post.photos().create(name='Hero 2')
-
- photo = OratorTestPhoto.with_('imageable').where('name', 'Hero 2').first()
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ user.photos().create(name="Avatar 1")
+ user.photos().create(name="Avatar 2")
+ post = user.posts().create(name="First Post")
+ post.photos().create(name="Hero 1")
+ post.photos().create(name="Hero 2")
+
+ photo = OratorTestPhoto.with_("imageable").where("name", "Hero 2").first()
self.assertIsInstance(photo.imageable, OratorTestPost)
self.assertEqual(post.id, photo.imageable.id)
- self.assertEqual(post.id, photo.imageable().where('name', 'First Post').first().id)
+ self.assertEqual(
+ post.id, photo.imageable().where("name", "First Post").first().id
+ )
def test_json_type(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- photo = user.photos().create(name='Avatar 1', metadata={'foo': 'bar'})
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ photo = user.photos().create(name="Avatar 1", metadata={"foo": "bar"})
photo = OratorTestPhoto.find(photo.id)
- self.assertEqual('bar', photo.metadata['foo'])
+ self.assertEqual("bar", photo.metadata["foo"])
def test_local_scopes(self):
- yesterday = created_at=datetime.utcnow() - timedelta(days=1)
- john = OratorTestUser.create(id=1, email='john@doe.com', created_at=yesterday, updated_at=yesterday)
- jane = OratorTestUser.create(id=2, email='jane@doe.com')
+ yesterday = datetime.utcnow() - timedelta(days=1)
+ john = OratorTestUser.create(
+ id=1, email="john@doe.com", created_at=yesterday, updated_at=yesterday
+ )
+ jane = OratorTestUser.create(id=2, email="jane@doe.com")
result = OratorTestUser.older_than(minutes=30).get()
self.assertEqual(1, len(result))
- self.assertEqual('john@doe.com', result.first().email)
+ self.assertEqual("john@doe.com", result.first().email)
- result = OratorTestUser.where_not_null('id').older_than(minutes=30).get()
+ result = OratorTestUser.where_not_null("id").older_than(minutes=30).get()
self.assertEqual(1, len(result))
- self.assertEqual('john@doe.com', result.first().email)
+ self.assertEqual("john@doe.com", result.first().email)
def test_repr_relations(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- photo = user.photos().create(name='Avatar 1', metadata={'foo': 'bar'})
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ photo = user.photos().create(name="Avatar 1", metadata={"foo": "bar"})
repr(OratorTestUser.first().photos)
- repr(OratorTestUser.with_('photos').first().photos)
+ repr(OratorTestUser.with_("photos").first().photos)
def test_reconnection(self):
db = Model.get_connection_resolver()
@@ -340,65 +477,83 @@ def test_reconnection(self):
db.disconnect()
def test_raw_query(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- photo = user.photos().create(name='Avatar 1', metadata={'foo': 'bar'})
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ photo = user.photos().create(name="Avatar 1", metadata={"foo": "bar"})
- user = self.connection().table('test_users')\
- .where_raw('test_users.email = %s' % self.get_marker(), 'john@doe.com')\
+ user = (
+ self.connection()
+ .table("test_users")
+ .where_raw("test_users.email = %s" % self.get_marker(), "john@doe.com")
.first()
+ )
- self.assertEqual(1, user['id'])
+ self.assertEqual(1, user["id"])
photos = self.connection().select(
- 'SELECT * FROM test_photos WHERE imageable_id = %(marker)s and imageable_type = %(marker)s'
+ "SELECT * FROM test_photos WHERE imageable_id = %(marker)s and imageable_type = %(marker)s"
% {"marker": self.get_marker()},
- [str(user['id']), 'test_users']
+ [str(user["id"]), "test_users"],
)
- self.assertEqual('Avatar 1', photos[0]['name'])
+ self.assertEqual("Avatar 1", photos[0]["name"])
def test_pivot(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- friend = OratorTestUser.create(id=2, email='jane@doe.com')
- another_friend = OratorTestUser.create(id=3, email='another@doe.com')
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ friend = OratorTestUser.create(id=2, email="jane@doe.com")
+ another_friend = OratorTestUser.create(id=3, email="another@doe.com")
user.friends().attach(friend)
user.friends().attach(another_friend)
- user.friends().update_existing_pivot(friend.id, {'is_close': True})
- self.assertTrue(user.friends().where('test_users.email', 'jane@doe.com').first().pivot.is_close)
+ user.friends().update_existing_pivot(friend.id, {"is_close": True})
+ self.assertTrue(
+ user.friends()
+ .where("test_users.email", "jane@doe.com")
+ .first()
+ .pivot.is_close
+ )
def test_serialization(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- photo = user.photos().create(name='Avatar 1', metadata={'foo': 'bar'})
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ photo = user.photos().create(name="Avatar 1", metadata={"foo": "bar"})
serialized_user = OratorTestUser.first().serialize()
serialized_photo = OratorTestPhoto.first().serialize()
- self.assertEqual(1, serialized_user['id'])
- self.assertEqual('john@doe.com', serialized_user['email'])
- self.assertEqual('Avatar 1', serialized_photo['name'])
- self.assertEqual('bar', serialized_photo['metadata']['foo'])
- self.assertEqual('Avatar 1', json.loads(OratorTestPhoto.first().to_json())['name'])
+ self.assertEqual(1, serialized_user["id"])
+ self.assertEqual("john@doe.com", serialized_user["email"])
+ self.assertEqual("Avatar 1", serialized_photo["name"])
+ self.assertEqual("bar", serialized_photo["metadata"]["foo"])
+ self.assertEqual(
+ "Avatar 1", json.loads(OratorTestPhoto.first().to_json())["name"]
+ )
def test_query_builder_results_attribute_retrieval(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- users = self.connection().table('test_users').get()
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ users = self.connection().table("test_users").get()
- self.assertEqual('john@doe.com', users[0].email)
- self.assertEqual('john@doe.com', users[0]['email'])
+ self.assertEqual("john@doe.com", users[0].email)
+ self.assertEqual("john@doe.com", users[0]["email"])
self.assertEqual(1, users[0].id)
- self.assertEqual(1, users[0]['id'])
+ self.assertEqual(1, users[0]["id"])
+
+ def test_query_builder_results_serialization(self):
+ OratorTestUser.create(id=1, email="john@doe.com")
+ users = self.connection().table("test_users").get()
+
+ serialized = json.loads(users.to_json())[0]
+ self.assertEqual(1, serialized["id"])
+ self.assertEqual("john@doe.com", serialized["email"])
def test_connection_switching(self):
- OratorTestUser.create(id=1, email='john@doe.com')
+ OratorTestUser.create(id=1, email="john@doe.com")
- self.assertIsNone(OratorTestUser.on('test').first())
+ self.assertIsNone(OratorTestUser.on("test").first())
self.assertIsNotNone(OratorTestUser.first())
- OratorTestUser.on('test').insert(id=1, email='jane@doe.com')
- user = OratorTestUser.on('test').first()
+ OratorTestUser.on("test").insert(id=1, email="jane@doe.com")
+ user = OratorTestUser.on("test").first()
connection = user.get_connection()
- post = user.posts().create(name='Test')
+ post = user.posts().create(name="Test")
self.assertEqual(connection, post.get_connection())
def test_columns_listing(self):
@@ -408,11 +563,101 @@ def test_columns_listing(self):
.all()
)
- self.assertEqual(['created_at', 'email', 'id', 'updated_at'], column_names)
+ self.assertEqual(["created_at", "email", "id", "updated_at"], column_names)
def test_has_column(self):
+ self.assertTrue(self.schema().has_column(OratorTestUser().get_table(), "email"))
+
+ def test_table_exists(self):
+ self.assertTrue(self.schema().has_table(OratorTestUser().get_table()))
+
+ def test_transaction(self):
+ count = self.connection().table("test_users").count()
+
+ with self.connection().transaction():
+ OratorTestUser.create(id=1, email="jane@doe.com")
+ self.connection().rollback()
+
+ self.assertEqual(count, self.connection().table("test_users").count())
+
+ def test_date(self):
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ photo1 = user.photos().create(name="Photo 1", taken_on=pendulum.date.today())
+ photo2 = user.photos().create(name="Photo 2")
+
+ self.assertIsInstance(OratorTestPhoto.find(photo1.id).taken_on, date)
+ self.assertIsNone(OratorTestPhoto.find(photo2.id).taken_on)
+
+ def test_chunk_update_builder(self):
+ for i in range(20):
+ self.connection().table("test_users").insert(
+ id=i + 1, email="john{}@doe.com".format(i)
+ )
+
+ count = 0
+ for users in (
+ self.connection().table("test_users").where("id", "<", 50).chunk(10)
+ ):
+ for user in users:
+ count += 1
+
+ if count == 10:
+ self.connection().table("test_users").where("id", user.id).update(
+ id=60
+ )
+
+ self.assertEqual(count, 20)
+
+ def test_chunk_update_model(self):
+ for i in range(20):
+ OratorTestUser.create(id=i + 1, email="john{}@doe.com".format(i))
+
+ count = 0
+ for users in OratorTestUser.where("id", "<", 50).chunk(10):
+ for user in users:
+ count += 1
+
+ if count == 10:
+ OratorTestUser.where("id", user.id).update(id=60)
+
+ self.assertEqual(count, 20)
+
+ def test_timestamp_with_timezone(self):
+ now = pendulum.utcnow()
+ user = OratorTestUser.create(email="john@doe.com", created_at=now)
+ fresh_user = OratorTestUser.find(user.id)
+
+ self.assertEqual(user.created_at, fresh_user.created_at)
+ self.assertEqual(now, fresh_user.created_at)
+
+ def test_touches(self):
+ user = OratorTestUser.create(email="john@doe.com")
+ post = user.posts().create(name="Post")
+ comment1 = post.comments().create(body="Comment 1")
+ comment2 = post.comments().create(body="Comment 2")
+ comment3 = post.comments().create(body="Comment 3")
+ comment4 = comment3.children().create(body="Comment 4", post_id=post.id)
+
+ comment1_updated_at = comment1.updated_at
+ comment2_updated_at = comment2.updated_at
+ comment3_updated_at = comment3.updated_at
+ comment4_updated_at = comment4.updated_at
+
+ comment4.body = "Comment 4 updated"
+ comment4.save()
+
+ self.assertTrue(comment4.updated_at > comment4_updated_at)
+ self.assertEqual(
+ comment4.updated_at, OratorTestComment.find(comment4.id).updated_at
+ )
self.assertTrue(
- self.schema().has_column(OratorTestUser().get_table(), 'email')
+ comment3_updated_at < OratorTestComment.find(comment3.id).updated_at
+ )
+ self.assertEqual(
+ comment1_updated_at, OratorTestComment.find(comment1.id).updated_at
+ )
+ self.assertEqual(
+ comment2_updated_at, OratorTestComment.find(comment2.id).updated_at
)
def grammar(self):
@@ -425,93 +670,146 @@ def schema(self, connection=None):
return self.connection(connection).get_schema_builder()
def migrate(self, connection=None):
- self.schema(connection).drop_if_exists('test_users')
- self.schema(connection).drop_if_exists('test_friends')
- self.schema(connection).drop_if_exists('test_posts')
- self.schema(connection).drop_if_exists('test_photos')
-
- with self.schema(connection).create('test_users') as table:
- table.increments('id')
- table.string('email').unique()
+ self.schema(connection).drop_if_exists("test_users")
+ self.schema(connection).drop_if_exists("test_friends")
+ self.schema(connection).drop_if_exists("test_posts")
+ self.schema(connection).drop_if_exists("test_photos")
+
+ with self.schema(connection).create("test_users") as table:
+ table.increments("id")
+ table.string("email").unique()
+ table.timestamps(use_current=True)
+
+ with self.schema(connection).create("test_friends") as table:
+ table.increments("id")
+ table.integer("user_id").unsigned()
+ table.integer("friend_id").unsigned()
+ table.boolean("is_close").default(False)
+
+ table.foreign("user_id").references("id").on("test_users").on_delete(
+ "cascade"
+ )
+ table.foreign("friend_id").references("id").on("test_users").on_delete(
+ "cascade"
+ )
+
+ with self.schema(connection).create("test_posts") as table:
+ table.increments("id")
+ table.integer("user_id").unsigned()
+ table.string("name")
table.timestamps(use_current=True)
- with self.schema(connection).create('test_friends') as table:
- table.increments('id')
- table.integer('user_id')
- table.integer('friend_id')
- table.boolean('is_close').default(False)
+ table.foreign("user_id").references("id").on("test_users").on_delete(
+ "cascade"
+ )
- with self.schema(connection).create('test_posts') as table:
- table.increments('id')
- table.integer('user_id')
- table.string('name')
+ with self.schema(connection).create("test_comments") as table:
+ table.increments("id")
+ table.integer("post_id").unsigned()
+ table.integer("parent_id").unsigned().nullable()
+ table.text("body")
table.timestamps(use_current=True)
- with self.schema(connection).create('test_photos') as table:
- table.increments('id')
- table.morphs('imageable')
- table.string('name')
- table.json('metadata').nullable()
+ table.foreign("post_id").references("id").on("test_posts").on_delete(
+ "cascade"
+ )
+ table.foreign("parent_id").references("id").on("test_comments").on_delete(
+ "cascade"
+ )
+
+ with self.schema(connection).create("test_photos") as table:
+ table.increments("id")
+ table.morphs("imageable")
+ table.string("name")
+ table.json("metadata").nullable()
+ table.date("taken_on").nullable()
table.timestamps(use_current=True)
def revert(self, connection=None):
- self.schema(connection).drop_if_exists('test_users')
- self.schema(connection).drop_if_exists('test_friends')
- self.schema(connection).drop_if_exists('test_posts')
- self.schema(connection).drop_if_exists('test_photos')
+ self.schema(connection).drop_if_exists("test_photos")
+ self.schema(connection).drop_if_exists("test_comments")
+ self.schema(connection).drop_if_exists("test_posts")
+ self.schema(connection).drop_if_exists("test_friends")
+ self.schema(connection).drop_if_exists("test_users")
def get_marker(self):
- return '?'
+ return "?"
class OratorTestUser(Model):
- __table__ = 'test_users'
+ __table__ = "test_users"
__guarded__ = []
- @belongs_to_many('test_friends', 'user_id', 'friend_id', with_pivot=['id', 'is_close'])
+ @belongs_to_many("test_friends", "user_id", "friend_id", with_pivot=["is_close"])
def friends(self):
return OratorTestUser
- @has_many('user_id')
+ @has_many("user_id")
def posts(self):
- return 'test_posts'
+ return "test_posts"
- @has_one('user_id')
+ @has_one("user_id")
def post(self):
- return OratorTestPost.select('id', 'name', 'name', 'user_id').order_by('name', 'desc')
+ return OratorTestPost.select("id", "name", "name", "user_id").order_by(
+ "name", "desc"
+ )
- @morph_many('imageable')
+ @morph_many("imageable")
def photos(self):
- return OratorTestPhoto.order_by('name')
+ return OratorTestPhoto.order_by("name")
@scope
def older_than(self, query, **kwargs):
- query.where('updated_at', '<', (pendulum.utcnow() - timedelta(**kwargs))._datetime)
+ query.where("updated_at", "<", pendulum.utcnow().subtract(**kwargs))
class OratorTestPost(Model):
- __table__ = 'test_posts'
+ __table__ = "test_posts"
__guarded__ = []
- @belongs_to('user_id')
+ @belongs_to("user_id")
def user(self):
return OratorTestUser
- @morph_many('imageable')
+ @has_many("post_id")
+ def comments(self):
+ return OratorTestComment
+
+ @morph_many("imageable")
def photos(self):
- return OratorTestPhoto.order_by('name')
+ return OratorTestPhoto.order_by("name")
+
+
+class OratorTestComment(Model):
+
+ __touches__ = ["parent"]
+
+ __table__ = "test_comments"
+ __guarded__ = []
+
+ @belongs_to("post_id")
+ def post(self):
+ return OratorTestPost
+
+ @belongs_to("parent_id")
+ def parent(self):
+ return OratorTestComment
+
+ @has_many("parent_id")
+ def children(self):
+ return OratorTestComment
class OratorTestPhoto(Model):
- __table__ = 'test_photos'
+ __table__ = "test_photos"
__guarded__ = []
- __casts__ = {
- 'metadata': 'json'
- }
+ __casts__ = {"metadata": "json"}
+
+ __dates__ = ["taken_on"]
@morph_to
def imageable(self):
@@ -519,4 +817,4 @@ def imageable(self):
@accessor
def created_at(self):
- return pendulum.instance(self._attributes['created_at']).to('Europe/Paris')
+ return pendulum.instance(self._attributes["created_at"]).in_tz("Europe/Paris")
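Note on the touch-propagation coverage added above: `OratorTestComment` declares `__touches__ = ["parent"]`, so saving a child comment is expected to bump the parent's `updated_at` while leaving unrelated rows untouched, which is exactly what `test_touches` asserts. A minimal sketch of that pattern follows; `Comment` is a stand-in model, not part of this patch, and the imports are assumed to follow the project's public API.

    from orator import Model
    from orator.orm import belongs_to

    class Comment(Model):

        __table__ = "comments"
        __guarded__ = []
        # Saving any Comment also touches (re-stamps updated_at on) its parent.
        __touches__ = ["parent"]

        @belongs_to("parent_id")
        def parent(self):
            return Comment

    # reply = Comment.find(4)
    # reply.body = "edited"
    # reply.save()   # reply.updated_at and its parent's updated_at both move forward;
    #                # the parent's other children keep their previous timestamps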
diff --git a/tests/integrations/test_mysql.py b/tests/integrations/test_mysql.py
index 81bf8f5e..de945367 100644
--- a/tests/integrations/test_mysql.py
+++ b/tests/integrations/test_mysql.py
@@ -7,29 +7,28 @@
class MySQLIntegrationTestCase(IntegrationTestCase, OratorTestCase):
-
@classmethod
def get_manager_config(cls):
- ci = os.environ.get('CI', False)
+ ci = os.environ.get("CI", False)
if ci:
- database = 'orator_test'
- user = 'root'
- password = ''
+ database = "orator_test"
+ user = "root"
+ password = ""
else:
- database = 'orator_test'
- user = 'orator'
- password = 'orator'
+ database = "orator_test"
+ user = "orator"
+ password = "orator"
return {
- 'default': 'mysql',
- 'mysql': {
- 'driver': 'mysql',
- 'database': database,
- 'user': user,
- 'password': password
- }
+ "default": "mysql",
+ "mysql": {
+ "driver": "mysql",
+ "database": database,
+ "user": user,
+ "password": password,
+ },
}
def get_marker(self):
- return '%s'
+ return "%s"
diff --git a/tests/integrations/test_mysql_qmark.py b/tests/integrations/test_mysql_qmark.py
index 78553848..497a793b 100644
--- a/tests/integrations/test_mysql_qmark.py
+++ b/tests/integrations/test_mysql_qmark.py
@@ -7,30 +7,29 @@
class MySQLQmarkIntegrationTestCase(IntegrationTestCase, OratorTestCase):
-
@classmethod
def get_manager_config(cls):
- ci = os.environ.get('CI', False)
+ ci = os.environ.get("CI", False)
if ci:
- database = 'orator_test'
- user = 'root'
- password = ''
+ database = "orator_test"
+ user = "root"
+ password = ""
else:
- database = 'orator_test'
- user = 'orator'
- password = 'orator'
+ database = "orator_test"
+ user = "orator"
+ password = "orator"
return {
- 'default': 'mysql',
- 'mysql': {
- 'driver': 'mysql',
- 'database': database,
- 'user': user,
- 'password': password,
- 'use_qmark': True
- }
+ "default": "mysql",
+ "mysql": {
+ "driver": "mysql",
+ "database": database,
+ "user": user,
+ "password": password,
+ "use_qmark": True,
+ },
}
def get_marker(self):
- return '?'
+ return "?"
diff --git a/tests/integrations/test_postgres.py b/tests/integrations/test_postgres.py
index 03886acf..f4c1c886 100644
--- a/tests/integrations/test_postgres.py
+++ b/tests/integrations/test_postgres.py
@@ -7,29 +7,28 @@
class PostgresIntegrationTestCase(IntegrationTestCase, OratorTestCase):
-
@classmethod
def get_manager_config(cls):
- ci = os.environ.get('CI', False)
+ ci = os.environ.get("CI", False)
if ci:
- database = 'orator_test'
- user = 'postgres'
+ database = "orator_test"
+ user = "postgres"
password = None
else:
- database = 'orator_test'
- user = 'orator'
- password = 'orator'
+ database = "orator_test"
+ user = "orator"
+ password = "orator"
return {
- 'default': 'postgres',
- 'postgres': {
- 'driver': 'pgsql',
- 'database': database,
- 'user': user,
- 'password': password
- }
+ "default": "postgres",
+ "postgres": {
+ "driver": "pgsql",
+ "database": database,
+ "user": user,
+ "password": password,
+ },
}
def get_marker(self):
- return '%s'
+ return "%s"
diff --git a/tests/integrations/test_postgres_qmark.py b/tests/integrations/test_postgres_qmark.py
index 5f912322..25b038e2 100644
--- a/tests/integrations/test_postgres_qmark.py
+++ b/tests/integrations/test_postgres_qmark.py
@@ -7,30 +7,29 @@
class PostgresQmarkIntegrationTestCase(IntegrationTestCase, OratorTestCase):
-
@classmethod
def get_manager_config(cls):
- ci = os.environ.get('CI', False)
+ ci = os.environ.get("CI", False)
if ci:
- database = 'orator_test'
- user = 'postgres'
+ database = "orator_test"
+ user = "postgres"
password = None
else:
- database = 'orator_test'
- user = 'orator'
- password = 'orator'
+ database = "orator_test"
+ user = "orator"
+ password = "orator"
return {
- 'default': 'postgres',
- 'postgres': {
- 'driver': 'pgsql',
- 'database': database,
- 'user': user,
- 'password': password,
- 'use_qmark': True
- }
+ "default": "postgres",
+ "postgres": {
+ "driver": "pgsql",
+ "database": database,
+ "user": user,
+ "password": password,
+ "use_qmark": True,
+ },
}
def get_marker(self):
- return '?'
+ return "?"
diff --git a/tests/integrations/test_sqlite.py b/tests/integrations/test_sqlite.py
index a74615e6..116929d8 100644
--- a/tests/integrations/test_sqlite.py
+++ b/tests/integrations/test_sqlite.py
@@ -5,13 +5,9 @@
class SQLiteIntegrationTestCase(IntegrationTestCase, OratorTestCase):
-
@classmethod
def get_manager_config(cls):
return {
- 'default': 'sqlite',
- 'sqlite': {
- 'driver': 'sqlite',
- 'database': ':memory:'
- }
+ "default": "sqlite",
+ "sqlite": {"driver": "sqlite", "database": ":memory:"},
}
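Each backend-specific test case above differs only in the dictionary returned by `get_manager_config()`: a "default" connection name plus one sub-dictionary per connection (driver, database, credentials and, for the qmark variants, "use_qmark"). As a self-contained illustration of that shape, using the in-memory SQLite driver the suite already relies on:

    from orator import DatabaseManager

    config = {
        "default": "sqlite",
        "sqlite": {"driver": "sqlite", "database": ":memory:"},
    }

    # The integration tests hand an equivalent dictionary to the manager and
    # resolve connections from it lazily.
    db = DatabaseManager(config)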
diff --git a/tests/migrations/__init__.py b/tests/migrations/__init__.py
index 633f8661..40a96afc 100644
--- a/tests/migrations/__init__.py
+++ b/tests/migrations/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/tests/migrations/test_database_migration_repository.py b/tests/migrations/test_database_migration_repository.py
index c7a6fa39..6c7faca6 100644
--- a/tests/migrations/test_database_migration_repository.py
+++ b/tests/migrations/test_database_migration_repository.py
@@ -10,7 +10,6 @@
class DatabaseMigrationRepositoryTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
@@ -18,50 +17,70 @@ def test_get_ran_migrations_list_migrations_by_package(self):
repo = self.get_repository()
connection = flexmock(Connection(None))
query = flexmock(QueryBuilder(connection, None, None))
- repo.get_connection_resolver().should_receive('connection').with_args(None).and_return(connection)
- repo.get_connection().should_receive('table').once().with_args('migrations').and_return(query)
- query.should_receive('lists').once().with_args('migration').and_return('bar')
+ repo.get_connection_resolver().should_receive("connection").with_args(
+ None
+ ).and_return(connection)
+ repo.get_connection().should_receive("table").once().with_args(
+ "migrations"
+ ).and_return(query)
+ query.should_receive("lists").once().with_args("migration").and_return("bar")
- self.assertEqual('bar', repo.get_ran())
+ self.assertEqual("bar", repo.get_ran())
def test_get_last_migrations(self):
resolver_mock = flexmock(DatabaseManager)
- resolver_mock.should_receive('connection').and_return(None)
+ resolver_mock.should_receive("connection").and_return(None)
resolver = flexmock(resolver_mock({}))
- repo = flexmock(DatabaseMigrationRepository(resolver, 'migrations'))
+ repo = flexmock(DatabaseMigrationRepository(resolver, "migrations"))
connection = flexmock(Connection(None))
query = flexmock(QueryBuilder(connection, None, None))
- repo.should_receive('get_last_batch_number').and_return(1)
- repo.get_connection_resolver().should_receive('connection').with_args(None).and_return(connection)
- repo.get_connection().should_receive('table').once().with_args('migrations').and_return(query)
- query.should_receive('where').once().with_args('batch', 1).and_return(query)
- query.should_receive('order_by').once().with_args('migration', 'desc').and_return(query)
- query.should_receive('get').once().and_return('foo')
-
- self.assertEqual('foo', repo.get_last())
+ repo.should_receive("get_last_batch_number").and_return(1)
+ repo.get_connection_resolver().should_receive("connection").with_args(
+ None
+ ).and_return(connection)
+ repo.get_connection().should_receive("table").once().with_args(
+ "migrations"
+ ).and_return(query)
+ query.should_receive("where").once().with_args("batch", 1).and_return(query)
+ query.should_receive("order_by").once().with_args(
+ "migration", "desc"
+ ).and_return(query)
+ query.should_receive("get").once().and_return("foo")
+
+ self.assertEqual("foo", repo.get_last())
def test_log_inserts_record_into_migration_table(self):
repo = self.get_repository()
connection = flexmock(Connection(None))
query = flexmock(QueryBuilder(connection, None, None))
- repo.get_connection_resolver().should_receive('connection').with_args(None).and_return(connection)
- repo.get_connection().should_receive('table').once().with_args('migrations').and_return(query)
- query.should_receive('insert').once().with_args(migration='bar', batch=1)
+ repo.get_connection_resolver().should_receive("connection").with_args(
+ None
+ ).and_return(connection)
+ repo.get_connection().should_receive("table").once().with_args(
+ "migrations"
+ ).and_return(query)
+ query.should_receive("insert").once().with_args(migration="bar", batch=1)
- repo.log('bar', 1)
+ repo.log("bar", 1)
def test_delete_removes_migration_from_table(self):
repo = self.get_repository()
connection = flexmock(Connection(None))
query = flexmock(QueryBuilder(connection, None, None))
- repo.get_connection_resolver().should_receive('connection').with_args(None).and_return(connection)
- repo.get_connection().should_receive('table').once().with_args('migrations').and_return(query)
- query.should_receive('where').once().with_args('migration', 'foo').and_return(query)
- query.should_receive('delete').once()
+ repo.get_connection_resolver().should_receive("connection").with_args(
+ None
+ ).and_return(connection)
+ repo.get_connection().should_receive("table").once().with_args(
+ "migrations"
+ ).and_return(query)
+ query.should_receive("where").once().with_args("migration", "foo").and_return(
+ query
+ )
+ query.should_receive("delete").once()
class Migration(object):
- migration = 'foo'
+ migration = "foo"
def __getitem__(self, item):
return self.migration
@@ -70,10 +89,10 @@ def __getitem__(self, item):
def test_get_next_batch_number_returns_last_batch_number_plus_one(self):
resolver_mock = flexmock(DatabaseManager)
- resolver_mock.should_receive('connection').and_return(None)
+ resolver_mock.should_receive("connection").and_return(None)
resolver = flexmock(resolver_mock({}))
- repo = flexmock(DatabaseMigrationRepository(resolver, 'migrations'))
- repo.should_receive('get_last_batch_number').and_return(1)
+ repo = flexmock(DatabaseMigrationRepository(resolver, "migrations"))
+ repo.should_receive("get_last_batch_number").and_return(1)
self.assertEqual(2, repo.get_next_batch_number())
@@ -81,13 +100,17 @@ def test_get_last_batch_number_returns_max_batch(self):
repo = self.get_repository()
connection = flexmock(Connection(None))
query = flexmock(QueryBuilder(connection, None, None))
- repo.get_connection_resolver().should_receive('connection').with_args(None).and_return(connection)
- repo.get_connection().should_receive('table').once().with_args('migrations').and_return(query)
- query.should_receive('max').and_return(1)
+ repo.get_connection_resolver().should_receive("connection").with_args(
+ None
+ ).and_return(connection)
+ repo.get_connection().should_receive("table").once().with_args(
+ "migrations"
+ ).and_return(query)
+ query.should_receive("max").and_return(1)
self.assertEqual(1, repo.get_last_batch_number())
def get_repository(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
- return DatabaseMigrationRepository(flexmock(resolver({})), 'migrations')
+ resolver.should_receive("connection").and_return(None)
+ return DatabaseMigrationRepository(flexmock(resolver({})), "migrations")
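The repository tests above pin down a small contract: ran migrations are listed from the "migrations" table, each run is logged with a batch number, a rollback deletes the matching record, and the next batch number is simply the last one plus one. A framework-free sketch of that contract, for orientation only (this is not Orator's implementation, and `InMemoryMigrationRepository` is a made-up name):

    class InMemoryMigrationRepository:
        def __init__(self):
            self._rows = []  # each row: {"migration": name, "batch": number}

        def get_ran(self):
            return [row["migration"] for row in self._rows]

        def log(self, migration, batch):
            self._rows.append({"migration": migration, "batch": batch})

        def delete(self, name):
            # Orator passes a migration record here; a plain name keeps the sketch short.
            self._rows = [r for r in self._rows if r["migration"] != name]

        def get_last_batch_number(self):
            return max((r["batch"] for r in self._rows), default=0)

        def get_next_batch_number(self):
            return self.get_last_batch_number() + 1

    repo = InMemoryMigrationRepository()
    repo.log("1_foo", repo.get_next_batch_number())   # logged under batch 1
    assert repo.get_next_batch_number() == 2          # mirrors the "last batch + 1" expectation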
diff --git a/tests/migrations/test_migration_creator.py b/tests/migrations/test_migration_creator.py
index 9386b841..c9eebdf3 100644
--- a/tests/migrations/test_migration_creator.py
+++ b/tests/migrations/test_migration_creator.py
@@ -9,55 +9,64 @@
class MigrationCreatorTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_basic_create_method_stores_migration_file(self):
- expected = os.path.join(tempfile.gettempdir(), 'foo_create_bar.py')
+ expected = os.path.join(tempfile.gettempdir(), "foo_create_bar.py")
if os.path.exists(expected):
os.remove(expected)
creator = self.get_creator()
- creator.should_receive('_get_date_prefix').and_return('foo')
- creator.create('create_bar', tempfile.gettempdir())
+ creator.should_receive("_get_date_prefix").and_return("foo")
+ creator.create("create_bar", tempfile.gettempdir())
self.assertTrue(os.path.exists(expected))
with open(expected) as fh:
content = fh.read()
- self.assertEqual(content, BLANK_STUB.replace('DummyClass', 'CreateBar'))
+ self.assertEqual(content, BLANK_STUB.replace("DummyClass", "CreateBar"))
os.remove(expected)
def test_table_update_migration_stores_migration_file(self):
- expected = os.path.join(tempfile.gettempdir(), 'foo_create_bar.py')
+ expected = os.path.join(tempfile.gettempdir(), "foo_create_bar.py")
if os.path.exists(expected):
os.remove(expected)
creator = self.get_creator()
- creator.should_receive('_get_date_prefix').and_return('foo')
- creator.create('create_bar', tempfile.gettempdir(), 'baz')
+ creator.should_receive("_get_date_prefix").and_return("foo")
+ creator.create("create_bar", tempfile.gettempdir(), "baz")
self.assertTrue(os.path.exists(expected))
with open(expected) as fh:
content = fh.read()
- self.assertEqual(content, UPDATE_STUB.replace('DummyClass', 'CreateBar').replace('dummy_table', 'baz'))
+ self.assertEqual(
+ content,
+ UPDATE_STUB.replace("DummyClass", "CreateBar").replace(
+ "dummy_table", "baz"
+ ),
+ )
os.remove(expected)
def test_table_create_migration_stores_migration_file(self):
- expected = os.path.join(tempfile.gettempdir(), 'foo_create_bar.py')
+ expected = os.path.join(tempfile.gettempdir(), "foo_create_bar.py")
if os.path.exists(expected):
os.remove(expected)
creator = self.get_creator()
- creator.should_receive('_get_date_prefix').and_return('foo')
- creator.create('create_bar', tempfile.gettempdir(), 'baz', True)
+ creator.should_receive("_get_date_prefix").and_return("foo")
+ creator.create("create_bar", tempfile.gettempdir(), "baz", True)
self.assertTrue(os.path.exists(expected))
with open(expected) as fh:
content = fh.read()
- self.assertEqual(content, CREATE_STUB.replace('DummyClass', 'CreateBar').replace('dummy_table', 'baz'))
+ self.assertEqual(
+ content,
+ CREATE_STUB.replace("DummyClass", "CreateBar").replace(
+ "dummy_table", "baz"
+ ),
+ )
os.remove(expected)
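The creator tests above fix two behaviours: the output filename is "<date prefix>_<name>.py", and the stub contents get "DummyClass" swapped for the camel-cased migration name (plus "dummy_table" for table-scoped stubs). A tiny self-contained sketch of that substitution; the stub text and the `render_stub` helper are stand-ins, not the real templates or creator internals:

    def render_stub(stub, name, table=None):
        klass = "".join(part.title() for part in name.split("_"))  # create_bar -> CreateBar
        content = stub.replace("DummyClass", klass)
        if table is not None:
            content = content.replace("dummy_table", table)
        return content

    assert render_stub("class DummyClass: ...", "create_bar") == "class CreateBar: ..."
    assert "baz" in render_stub("table: dummy_table", "create_bar", table="baz")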
diff --git a/tests/migrations/test_migrator.py b/tests/migrations/test_migrator.py
index e7d7dfe0..c62c3d2d 100644
--- a/tests/migrations/test_migrator.py
+++ b/tests/migrations/test_migrator.py
@@ -12,7 +12,6 @@
class MigratorTestCase(OratorTestCase):
-
def setUp(self):
if PY3K:
self.orig = inspect.getargspec
@@ -26,271 +25,278 @@ def tearDown(self):
def test_migrations_are_run_up_when_outstanding_migrations_exist(self):
resolver_mock = flexmock(DatabaseManager)
- resolver_mock.should_receive('connection').and_return({})
+ resolver_mock.should_receive("connection").and_return({})
resolver = flexmock(DatabaseManager({}))
connection = flexmock()
- connection.should_receive('transaction').twice().and_return(connection)
- resolver.should_receive('connection').and_return(connection)
+ connection.should_receive("transaction").twice().and_return(connection)
+ resolver.should_receive("connection").and_return(connection)
migrator = flexmock(
Migrator(
- flexmock(
- DatabaseMigrationRepository(
- resolver,
- 'migrations'
- )
- ),
- resolver
+ flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver
)
)
g = flexmock(glob)
- g.should_receive('glob').with_args(os.path.join(os.getcwd(), '[0-9]*_*.py')).and_return([
- os.path.join(os.getcwd(), '2_bar.py'),
- os.path.join(os.getcwd(), '1_foo.py'),
- os.path.join(os.getcwd(), '3_baz.py')
- ])
-
- migrator.get_repository().should_receive('get_ran').once().and_return(['1_foo'])
- migrator.get_repository().should_receive('get_next_batch_number').once().and_return(1)
- migrator.get_repository().should_receive('log').once().with_args('2_bar', 1)
- migrator.get_repository().should_receive('log').once().with_args('3_baz', 1)
+ g.should_receive("glob").with_args(
+ os.path.join(os.getcwd(), "[0-9]*_*.py")
+ ).and_return(
+ [
+ os.path.join(os.getcwd(), "2_bar.py"),
+ os.path.join(os.getcwd(), "1_foo.py"),
+ os.path.join(os.getcwd(), "3_baz.py"),
+ ]
+ )
+
+ migrator.get_repository().should_receive("get_ran").once().and_return(["1_foo"])
+ migrator.get_repository().should_receive(
+ "get_next_batch_number"
+ ).once().and_return(1)
+ migrator.get_repository().should_receive("log").once().with_args("2_bar", 1)
+ migrator.get_repository().should_receive("log").once().with_args("3_baz", 1)
bar_mock = flexmock(MigrationStub())
bar_mock.set_connection(connection)
- bar_mock.should_receive('up').once()
+ bar_mock.should_receive("up").once()
baz_mock = flexmock(MigrationStub())
baz_mock.set_connection(connection)
- baz_mock.should_receive('up').once()
- migrator.should_receive('_resolve').with_args(os.getcwd(), '2_bar').once().and_return(bar_mock)
- migrator.should_receive('_resolve').with_args(os.getcwd(), '3_baz').once().and_return(baz_mock)
+ baz_mock.should_receive("up").once()
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "2_bar"
+ ).once().and_return(bar_mock)
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "3_baz"
+ ).once().and_return(baz_mock)
migrator.run(os.getcwd())
def test_migrations_are_run_up_directly_if_transactional_is_false(self):
resolver_mock = flexmock(DatabaseManager)
- resolver_mock.should_receive('connection').and_return({})
+ resolver_mock.should_receive("connection").and_return({})
resolver = flexmock(DatabaseManager({}))
connection = flexmock()
- connection.should_receive('transaction').never()
- resolver.should_receive('connection').and_return(connection)
+ connection.should_receive("transaction").never()
+ resolver.should_receive("connection").and_return(connection)
migrator = flexmock(
Migrator(
- flexmock(
- DatabaseMigrationRepository(
- resolver,
- 'migrations'
- )
- ),
- resolver
+ flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver
)
)
g = flexmock(glob)
- g.should_receive('glob').with_args(os.path.join(os.getcwd(), '[0-9]*_*.py')).and_return([
- os.path.join(os.getcwd(), '2_bar.py'),
- os.path.join(os.getcwd(), '1_foo.py'),
- os.path.join(os.getcwd(), '3_baz.py')
- ])
-
- migrator.get_repository().should_receive('get_ran').once().and_return(['1_foo'])
- migrator.get_repository().should_receive('get_next_batch_number').once().and_return(1)
- migrator.get_repository().should_receive('log').once().with_args('2_bar', 1)
- migrator.get_repository().should_receive('log').once().with_args('3_baz', 1)
+ g.should_receive("glob").with_args(
+ os.path.join(os.getcwd(), "[0-9]*_*.py")
+ ).and_return(
+ [
+ os.path.join(os.getcwd(), "2_bar.py"),
+ os.path.join(os.getcwd(), "1_foo.py"),
+ os.path.join(os.getcwd(), "3_baz.py"),
+ ]
+ )
+
+ migrator.get_repository().should_receive("get_ran").once().and_return(["1_foo"])
+ migrator.get_repository().should_receive(
+ "get_next_batch_number"
+ ).once().and_return(1)
+ migrator.get_repository().should_receive("log").once().with_args("2_bar", 1)
+ migrator.get_repository().should_receive("log").once().with_args("3_baz", 1)
bar_mock = flexmock(MigrationStub())
bar_mock.transactional = False
bar_mock.set_connection(connection)
- bar_mock.should_receive('up').once()
+ bar_mock.should_receive("up").once()
baz_mock = flexmock(MigrationStub())
baz_mock.transactional = False
baz_mock.set_connection(connection)
- baz_mock.should_receive('up').once()
- migrator.should_receive('_resolve').with_args(os.getcwd(), '2_bar').once().and_return(bar_mock)
- migrator.should_receive('_resolve').with_args(os.getcwd(), '3_baz').once().and_return(baz_mock)
+ baz_mock.should_receive("up").once()
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "2_bar"
+ ).once().and_return(bar_mock)
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "3_baz"
+ ).once().and_return(baz_mock)
migrator.run(os.getcwd())
def test_up_migration_can_be_pretended(self):
resolver_mock = flexmock(DatabaseManager)
- resolver_mock.should_receive('connection').and_return({})
+ resolver_mock.should_receive("connection").and_return({})
resolver = flexmock(DatabaseManager({}))
connection = flexmock(Connection(None))
- connection.should_receive('get_logged_queries').twice().and_return([])
- resolver.should_receive('connection').with_args(None).and_return(connection)
+ connection.should_receive("get_logged_queries").twice().and_return([])
+ resolver.should_receive("connection").with_args(None).and_return(connection)
migrator = flexmock(
Migrator(
- flexmock(
- DatabaseMigrationRepository(
- resolver,
- 'migrations'
- )
- ),
- resolver
+ flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver
)
)
g = flexmock(glob)
- g.should_receive('glob').with_args(os.path.join(os.getcwd(), '[0-9]*_*.py')).and_return([
- os.path.join(os.getcwd(), '2_bar.py'),
- os.path.join(os.getcwd(), '1_foo.py'),
- os.path.join(os.getcwd(), '3_baz.py')
- ])
-
- migrator.get_repository().should_receive('get_ran').once().and_return(['1_foo'])
- migrator.get_repository().should_receive('get_next_batch_number').once().and_return(1)
+ g.should_receive("glob").with_args(
+ os.path.join(os.getcwd(), "[0-9]*_*.py")
+ ).and_return(
+ [
+ os.path.join(os.getcwd(), "2_bar.py"),
+ os.path.join(os.getcwd(), "1_foo.py"),
+ os.path.join(os.getcwd(), "3_baz.py"),
+ ]
+ )
+
+ migrator.get_repository().should_receive("get_ran").once().and_return(["1_foo"])
+ migrator.get_repository().should_receive(
+ "get_next_batch_number"
+ ).once().and_return(1)
bar_mock = flexmock(MigrationStub())
- bar_mock.should_receive('get_connection').once().and_return(connection)
- bar_mock.should_receive('up').once()
+ bar_mock.should_receive("get_connection").once().and_return(connection)
+ bar_mock.should_receive("up").once()
baz_mock = flexmock(MigrationStub())
- baz_mock.should_receive('get_connection').once().and_return(connection)
- baz_mock.should_receive('up').once()
- migrator.should_receive('_resolve').with_args(os.getcwd(), '2_bar').once().and_return(bar_mock)
- migrator.should_receive('_resolve').with_args(os.getcwd(), '3_baz').once().and_return(baz_mock)
+ baz_mock.should_receive("get_connection").once().and_return(connection)
+ baz_mock.should_receive("up").once()
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "2_bar"
+ ).once().and_return(bar_mock)
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "3_baz"
+ ).once().and_return(baz_mock)
migrator.run(os.getcwd(), True)
def test_nothing_is_done_when_no_migrations_outstanding(self):
resolver_mock = flexmock(DatabaseManager)
- resolver_mock.should_receive('connection').and_return(None)
+ resolver_mock.should_receive("connection").and_return(None)
resolver = flexmock(DatabaseManager({}))
migrator = flexmock(
Migrator(
- flexmock(
- DatabaseMigrationRepository(
- resolver,
- 'migrations'
- )
- ),
- resolver
+ flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver
)
)
g = flexmock(glob)
- g.should_receive('glob').with_args(os.path.join(os.getcwd(), '[0-9]*_*.py')).and_return([
- os.path.join(os.getcwd(), '1_foo.py')
- ])
+ g.should_receive("glob").with_args(
+ os.path.join(os.getcwd(), "[0-9]*_*.py")
+ ).and_return([os.path.join(os.getcwd(), "1_foo.py")])
- migrator.get_repository().should_receive('get_ran').once().and_return(['1_foo'])
+ migrator.get_repository().should_receive("get_ran").once().and_return(["1_foo"])
migrator.run(os.getcwd())
def test_last_batch_of_migrations_can_be_rolled_back(self):
resolver_mock = flexmock(DatabaseManager)
- resolver_mock.should_receive('connection').and_return({})
+ resolver_mock.should_receive("connection").and_return({})
resolver = flexmock(DatabaseManager({}))
connection = flexmock()
- connection.should_receive('transaction').twice().and_return(connection)
- resolver.should_receive('connection').and_return(connection)
+ connection.should_receive("transaction").twice().and_return(connection)
+ resolver.should_receive("connection").and_return(connection)
migrator = flexmock(
Migrator(
- flexmock(
- DatabaseMigrationRepository(
- resolver,
- 'migrations'
- )
- ),
- resolver
+ flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver
)
)
- foo_migration = MigrationStub('foo')
- bar_migration = MigrationStub('bar')
- migrator.get_repository().should_receive('get_last').once().and_return([
- foo_migration,
- bar_migration
- ])
+ foo_migration = MigrationStub("foo")
+ bar_migration = MigrationStub("bar")
+ migrator.get_repository().should_receive("get_last").once().and_return(
+ [foo_migration, bar_migration]
+ )
bar_mock = flexmock(MigrationStub())
bar_mock.set_connection(connection)
- bar_mock.should_receive('down').once()
+ bar_mock.should_receive("down").once()
foo_mock = flexmock(MigrationStub())
foo_mock.set_connection(connection)
- foo_mock.should_receive('down').once()
- migrator.should_receive('_resolve').with_args(os.getcwd(), 'bar').once().and_return(bar_mock)
- migrator.should_receive('_resolve').with_args(os.getcwd(), 'foo').once().and_return(foo_mock)
-
- migrator.get_repository().should_receive('delete').once().with_args(bar_migration)
- migrator.get_repository().should_receive('delete').once().with_args(foo_migration)
+ foo_mock.should_receive("down").once()
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "bar"
+ ).once().and_return(bar_mock)
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "foo"
+ ).once().and_return(foo_mock)
+
+ migrator.get_repository().should_receive("delete").once().with_args(
+ bar_migration
+ )
+ migrator.get_repository().should_receive("delete").once().with_args(
+ foo_migration
+ )
migrator.rollback(os.getcwd())
- def test_last_batch_of_migrations_can_be_rolled_back_directly_if_transactional_is_false(self):
+ def test_last_batch_of_migrations_can_be_rolled_back_directly_if_transactional_is_false(
+ self
+ ):
resolver_mock = flexmock(DatabaseManager)
- resolver_mock.should_receive('connection').and_return({})
+ resolver_mock.should_receive("connection").and_return({})
resolver = flexmock(DatabaseManager({}))
connection = flexmock()
- connection.should_receive('transaction').never()
- resolver.should_receive('connection').and_return(connection)
+ connection.should_receive("transaction").never()
+ resolver.should_receive("connection").and_return(connection)
migrator = flexmock(
Migrator(
- flexmock(
- DatabaseMigrationRepository(
- resolver,
- 'migrations'
- )
- ),
- resolver
+ flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver
)
)
- foo_migration = MigrationStub('foo')
- bar_migration = MigrationStub('bar')
- migrator.get_repository().should_receive('get_last').once().and_return([
- foo_migration,
- bar_migration
- ])
+ foo_migration = MigrationStub("foo")
+ bar_migration = MigrationStub("bar")
+ migrator.get_repository().should_receive("get_last").once().and_return(
+ [foo_migration, bar_migration]
+ )
bar_mock = flexmock(MigrationStub())
bar_mock.transactional = False
bar_mock.set_connection(connection)
- bar_mock.should_receive('down').once()
+ bar_mock.should_receive("down").once()
foo_mock = flexmock(MigrationStub())
foo_mock.transactional = False
foo_mock.set_connection(connection)
- foo_mock.should_receive('down').once()
- migrator.should_receive('_resolve').with_args(os.getcwd(), 'bar').once().and_return(bar_mock)
- migrator.should_receive('_resolve').with_args(os.getcwd(), 'foo').once().and_return(foo_mock)
-
- migrator.get_repository().should_receive('delete').once().with_args(bar_migration)
- migrator.get_repository().should_receive('delete').once().with_args(foo_migration)
+ foo_mock.should_receive("down").once()
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "bar"
+ ).once().and_return(bar_mock)
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "foo"
+ ).once().and_return(foo_mock)
+
+ migrator.get_repository().should_receive("delete").once().with_args(
+ bar_migration
+ )
+ migrator.get_repository().should_receive("delete").once().with_args(
+ foo_migration
+ )
migrator.rollback(os.getcwd())
def test_rollback_migration_can_be_pretended(self):
resolver_mock = flexmock(DatabaseManager)
- resolver_mock.should_receive('connection').and_return({})
+ resolver_mock.should_receive("connection").and_return({})
resolver = flexmock(DatabaseManager({}))
connection = flexmock(Connection(None))
- connection.should_receive('get_logged_queries').twice().and_return([])
- resolver.should_receive('connection').with_args(None).and_return(connection)
+ connection.should_receive("get_logged_queries").twice().and_return([])
+ resolver.should_receive("connection").with_args(None).and_return(connection)
migrator = flexmock(
Migrator(
- flexmock(
- DatabaseMigrationRepository(
- resolver,
- 'migrations'
- )
- ),
- resolver
+ flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver
)
)
- foo_migration = flexmock(MigrationStub('foo'))
- foo_migration.should_receive('get_connection').and_return(connection)
- bar_migration = flexmock(MigrationStub('bar'))
- bar_migration.should_receive('get_connection').and_return(connection)
- migrator.get_repository().should_receive('get_last').once().and_return([
- foo_migration,
- bar_migration
- ])
+ foo_migration = flexmock(MigrationStub("foo"))
+ foo_migration.should_receive("get_connection").and_return(connection)
+ bar_migration = flexmock(MigrationStub("bar"))
+ bar_migration.should_receive("get_connection").and_return(connection)
+ migrator.get_repository().should_receive("get_last").once().and_return(
+ [foo_migration, bar_migration]
+ )
- migrator.should_receive('_resolve').with_args(os.getcwd(), 'bar').once().and_return(bar_migration)
- migrator.should_receive('_resolve').with_args(os.getcwd(), 'foo').once().and_return(foo_migration)
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "bar"
+ ).once().and_return(bar_migration)
+ migrator.should_receive("_resolve").with_args(
+ os.getcwd(), "foo"
+ ).once().and_return(foo_migration)
migrator.rollback(os.getcwd(), True)
@@ -301,27 +307,20 @@ def test_rollback_migration_can_be_pretended(self):
def test_nothing_is_rolled_back_when_nothing_in_repository(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(None)
+ resolver.should_receive("connection").and_return(None)
migrator = flexmock(
Migrator(
- flexmock(
- DatabaseMigrationRepository(
- resolver,
- 'migrations'
- )
- ),
- resolver
+ flexmock(DatabaseMigrationRepository(resolver, "migrations")), resolver
)
)
- migrator.get_repository().should_receive('get_last').once().and_return([])
+ migrator.get_repository().should_receive("get_last").once().and_return([])
migrator.rollback(os.getcwd())
class MigrationStub(Migration):
-
def __init__(self, migration=None):
self.migration = migration
self.upped = False
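All of the migrator tests above build on the same discovery step: candidate files are globbed with the "[0-9]*_*.py" pattern, anything already reported by `get_ran()` is filtered out, and the remainder runs in name order under a single batch number. A short sketch of that selection; the helper name and return value are assumptions for illustration, not Orator internals:

    import glob
    import os

    def outstanding_migrations(path, ran):
        """Return migration names found in `path` that have not run yet."""
        files = glob.glob(os.path.join(path, "[0-9]*_*.py"))
        names = sorted(os.path.splitext(os.path.basename(f))[0] for f in files)
        return [name for name in names if name not in ran]

    # With 1_foo.py, 2_bar.py and 3_baz.py on disk and ran == ["1_foo"], this yields
    # ["2_bar", "3_baz"], the exact pair the mocks above expect to be resolved,
    # upped and logged under the next batch number.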
diff --git a/tests/orm/__init__.py b/tests/orm/__init__.py
index 633f8661..40a96afc 100644
--- a/tests/orm/__init__.py
+++ b/tests/orm/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/tests/orm/mixins/__init__.py b/tests/orm/mixins/__init__.py
index 633f8661..40a96afc 100644
--- a/tests/orm/mixins/__init__.py
+++ b/tests/orm/mixins/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/tests/orm/mixins/test_soft_deletes.py b/tests/orm/mixins/test_soft_deletes.py
index 768bead2..03c037a0 100644
--- a/tests/orm/mixins/test_soft_deletes.py
+++ b/tests/orm/mixins/test_soft_deletes.py
@@ -13,7 +13,6 @@
class SoftDeletesTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
@@ -23,9 +22,9 @@ def test_delete_sets_soft_deleted_column(self):
builder = flexmock(Builder)
query_builder = flexmock(QueryBuilder(None, None, None))
query = Builder(query_builder)
- model.should_receive('new_query').and_return(query)
- builder.should_receive('where').once().with_args('id', 1).and_return(query)
- builder.should_receive('update').once().with_args({'deleted_at': t})
+ model.should_receive("new_query").and_return(query)
+ builder.should_receive("where").once().with_args("id", 1).and_return(query)
+ builder.should_receive("update").once().with_args({"deleted_at": t})
model.delete()
self.assertIsInstance(model.deleted_at, datetime.datetime)
@@ -33,7 +32,7 @@ def test_delete_sets_soft_deleted_column(self):
def test_restore(self):
model = flexmock(SoftDeleteModelStub())
model.set_exists(True)
- model.should_receive('save').once()
+ model.should_receive("save").once()
model.restore()
@@ -41,12 +40,11 @@ def test_restore(self):
class SoftDeleteModelStub(SoftDeletes, Model):
-
def get_key(self):
return 1
def get_key_name(self):
- return 'id'
+ return "id"
def from_datetime(self, value):
return t
diff --git a/tests/orm/mixins/test_soft_deletes_integration.py b/tests/orm/mixins/test_soft_deletes_integration.py
index bbcf5636..6b806bba 100644
--- a/tests/orm/mixins/test_soft_deletes_integration.py
+++ b/tests/orm/mixins/test_soft_deletes_integration.py
@@ -9,12 +9,7 @@
class SoftDeletesIntegrationTestCase(OratorTestCase):
- databases = {
- 'test': {
- 'driver': 'sqlite',
- 'database': ':memory:'
- }
- }
+ databases = {"test": {"driver": "sqlite", "database": ":memory:"}}
def setUp(self):
self.db = DatabaseManager(self.databases)
@@ -24,30 +19,30 @@ def setUp(self):
self.create_schema()
def create_schema(self):
- with self.schema().create('users') as table:
- table.increments('id')
- table.string('email').unique()
+ with self.schema().create("users") as table:
+ table.increments("id")
+ table.string("email").unique()
table.timestamps()
table.soft_deletes()
- with self.schema().create('posts') as table:
- table.increments('id')
- table.string('title')
- table.integer('user_id')
+ with self.schema().create("posts") as table:
+ table.increments("id")
+ table.string("title")
+ table.integer("user_id")
table.timestamps()
table.soft_deletes()
- with self.schema().create('comments') as table:
- table.increments('id')
- table.string('body')
- table.integer('post_id')
+ with self.schema().create("comments") as table:
+ table.increments("id")
+ table.string("body")
+ table.integer("post_id")
table.timestamps()
table.soft_deletes()
def tearDown(self):
- self.schema().drop('users')
- self.schema().drop('posts')
- self.schema().drop('comments')
+ self.schema().drop("users")
+ self.schema().drop("posts")
+ self.schema().drop("comments")
Model.unset_connection_resolver()
@@ -79,7 +74,7 @@ def test_soft_deletes_are_not_retrieved_from_builder_helpers(self):
self.assertEqual(1, count)
query = SoftDeletesTestUser.query()
- self.assertEqual(1, len(query.lists('email')))
+ self.assertEqual(1, len(query.lists("email")))
Paginator.current_page_resolver(lambda: 1)
query = SoftDeletesTestUser.query()
@@ -89,15 +84,20 @@ def test_soft_deletes_are_not_retrieved_from_builder_helpers(self):
query = SoftDeletesTestUser.query()
self.assertEqual(1, len(query.simple_paginate(2).items))
- self.assertEqual(0, SoftDeletesTestUser.where('email', 'john@doe.com').increment('id'))
- self.assertEqual(0, SoftDeletesTestUser.where('email', 'john@doe.com').decrement('id'))
-
+ self.assertEqual(
+ 0, SoftDeletesTestUser.where("email", "john@doe.com").increment("id")
+ )
+ self.assertEqual(
+ 0, SoftDeletesTestUser.where("email", "john@doe.com").decrement("id")
+ )
def test_with_trashed_returns_all_records(self):
self.create_users()
self.assertEqual(2, SoftDeletesTestUser.with_trashed().get().count())
- self.assertIsInstance(SoftDeletesTestUser.with_trashed().find(1), SoftDeletesTestUser)
+ self.assertIsInstance(
+ SoftDeletesTestUser.with_trashed().find(1), SoftDeletesTestUser
+ )
def test_delete_sets_deleted_column(self):
self.create_users()
@@ -142,63 +142,83 @@ def test_first_or_new_ignores_soft_deletes(self):
self.create_users()
john = SoftDeletesTestUser.first_or_new(id=1)
- self.assertEqual('john@doe.com', john.email)
+ self.assertEqual("john@doe.com", john.email)
def test_where_has_with_deleted_relationship(self):
self.create_users()
- jane = SoftDeletesTestUser.where('email', 'jane@doe.com').first()
- post = jane.posts().create(title='First Title')
+ jane = SoftDeletesTestUser.where("email", "jane@doe.com").first()
+ post = jane.posts().create(title="First Title")
- users = SoftDeletesTestUser.where('email', 'john@doe.com').has('posts').get()
+ users = SoftDeletesTestUser.where("email", "john@doe.com").has("posts").get()
self.assertEqual(0, len(users))
- users = SoftDeletesTestUser.where('email', 'jane@doe.com').has('posts').get()
+ users = SoftDeletesTestUser.where("email", "jane@doe.com").has("posts").get()
self.assertEqual(1, len(users))
- users = SoftDeletesTestUser.where('email', 'doesnt@exist.com').or_has('posts').get()
+ users = (
+ SoftDeletesTestUser.where("email", "doesnt@exist.com").or_has("posts").get()
+ )
self.assertEqual(1, len(users))
- users = SoftDeletesTestUser.where_has('posts', lambda q: q.where('title', 'First Title')).get()
+ users = SoftDeletesTestUser.where_has(
+ "posts", lambda q: q.where("title", "First Title")
+ ).get()
self.assertEqual(1, len(users))
- users = SoftDeletesTestUser.where_has('posts', lambda q: q.where('title', 'Another Title')).get()
+ users = SoftDeletesTestUser.where_has(
+ "posts", lambda q: q.where("title", "Another Title")
+ ).get()
self.assertEqual(0, len(users))
- users = SoftDeletesTestUser.where('email', 'doesnt@exist.com')\
- .or_where_has('posts', lambda q: q.where('title', 'First Title'))\
+ users = (
+ SoftDeletesTestUser.where("email", "doesnt@exist.com")
+ .or_where_has("posts", lambda q: q.where("title", "First Title"))
.get()
+ )
self.assertEqual(1, len(users))
# With post delete
post.delete()
- users = SoftDeletesTestUser.has('posts').get()
+ users = SoftDeletesTestUser.has("posts").get()
self.assertEqual(0, len(users))
def test_where_has_with_nested_deleted_relationship(self):
self.create_users()
- jane = SoftDeletesTestUser.where('email', 'jane@doe.com').first()
- post = jane.posts().create(title='First Title')
- comment = post.comments().create(body='Comment Body')
+ jane = SoftDeletesTestUser.where("email", "jane@doe.com").first()
+ post = jane.posts().create(title="First Title")
+ comment = post.comments().create(body="Comment Body")
comment.delete()
- users = SoftDeletesTestUser.has('posts.comments').get()
+ users = SoftDeletesTestUser.has("posts.comments").get()
self.assertEqual(0, len(users))
- users = SoftDeletesTestUser.doesnt_have('posts.comments').get()
+ users = SoftDeletesTestUser.doesnt_have("posts.comments").get()
self.assertEqual(1, len(users))
def test_or_where_with_soft_deletes_constraint(self):
self.create_users()
- users = SoftDeletesTestUser.where('email', 'john@doe.com').or_where('email', 'jane@doe.com')
+ users = SoftDeletesTestUser.where("email", "john@doe.com").or_where(
+ "email", "jane@doe.com"
+ )
+ self.assertEqual(1, len(users.get()))
+ self.assertEqual(["jane@doe.com"], users.order_by("id").lists("email"))
+
+ def test_where_exists_on_soft_delete_model(self):
+ self.create_users()
+
+ users = SoftDeletesTestUser.where_exists(
+ SoftDeletesTestUser.where("email", "jane@doe.com")
+ )
+
self.assertEqual(1, len(users.get()))
- self.assertEqual(['jane@doe.com'], users.order_by('id').lists('email'))
+ self.assertEqual(["jane@doe.com"], users.order_by("id").lists("email"))
def create_users(self):
- john = SoftDeletesTestUser.create(email='john@doe.com')
- jane = SoftDeletesTestUser.create(email='jane@doe.com')
+ john = SoftDeletesTestUser.create(email="john@doe.com")
+ jane = SoftDeletesTestUser.create(email="jane@doe.com")
john.delete()
@@ -211,9 +231,9 @@ def schema(self):
class SoftDeletesTestUser(SoftDeletes, Model):
- __table__ = 'users'
+ __table__ = "users"
- __dates__ = ['deleted_at']
+ __dates__ = ["deleted_at"]
__guarded__ = []
@@ -224,9 +244,9 @@ def posts(self):
class SoftDeletesTestPost(SoftDeletes, Model):
- __table__ = 'posts'
+ __table__ = "posts"
- __dates__ = ['deleted_at']
+ __dates__ = ["deleted_at"]
__guarded__ = []
@@ -237,8 +257,8 @@ def comments(self):
class SoftDeletesTestComment(SoftDeletes, Model):
- __table__ = 'comments'
+ __table__ = "comments"
- __dates__ = ['deleted_at']
+ __dates__ = ["deleted_at"]
__guarded__ = []
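Taken together, the integration tests above document the soft-delete contract: `delete()` stamps `deleted_at` instead of removing the row, default queries hide trashed rows, `with_trashed()` opts back in, and the new `test_where_exists_on_soft_delete_model` guards against `where_exists()` recursing on a soft-deletable model. A condensed usage sketch distilled from those tests, reusing the `SoftDeletesTestUser` model defined above (database and schema setup omitted, so this is not runnable on its own):

    john = SoftDeletesTestUser.create(email="john@doe.com")
    jane = SoftDeletesTestUser.create(email="jane@doe.com")
    john.delete()                                   # soft delete: the row remains, deleted_at is set

    SoftDeletesTestUser.query().lists("email")      # -> ["jane@doe.com"], trashed rows are hidden
    SoftDeletesTestUser.with_trashed().get().count()  # -> 2, trashed rows included

    # Regression guard exercised by the new test:
    users = SoftDeletesTestUser.where_exists(
        SoftDeletesTestUser.where("email", "jane@doe.com")
    )
    users.order_by("id").lists("email")             # -> ["jane@doe.com"]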
diff --git a/tests/orm/models.py b/tests/orm/models.py
index 668ce86e..75b600d5 100644
--- a/tests/orm/models.py
+++ b/tests/orm/models.py
@@ -5,4 +5,4 @@
class User(Model):
- __fillable__ = ['name']
+ __fillable__ = ["name"]
diff --git a/tests/orm/relations/__init__.py b/tests/orm/relations/__init__.py
index 633f8661..40a96afc 100644
--- a/tests/orm/relations/__init__.py
+++ b/tests/orm/relations/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/tests/orm/relations/test_belongs_to.py b/tests/orm/relations/test_belongs_to.py
index bba99566..e6325046 100644
--- a/tests/orm/relations/test_belongs_to.py
+++ b/tests/orm/relations/test_belongs_to.py
@@ -15,37 +15,37 @@
class OrmBelongsToTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_update_retrieve_model_and_updates(self):
relation = self._get_relation()
mock = flexmock(Model())
- mock.should_receive('fill').once().with_args({'foo': 'bar'}).and_return(mock)
- mock.should_receive('save').once().and_return(True)
- relation.get_query().should_receive('first').once().and_return(mock)
+ mock.should_receive("fill").once().with_args({"foo": "bar"}).and_return(mock)
+ mock.should_receive("save").once().and_return(True)
+ relation.get_query().should_receive("first").once().and_return(mock)
- self.assertTrue(relation.update({'foo': 'bar'}))
+ self.assertTrue(relation.update({"foo": "bar"}))
def test_relation_is_properly_initialized(self):
relation = self._get_relation()
model = flexmock(Model())
- model.should_receive('set_relation').once().with_args('foo', None)
- models = relation.init_relation([model], 'foo')
+ model.should_receive("set_relation").once().with_args("foo", None)
+ models = relation.init_relation([model], "foo")
self.assertEqual([model], models)
def test_eager_constraints_are_properly_added(self):
relation = self._get_relation()
- relation.get_query().get_query().should_receive('where_in').once()\
- .with_args('relation.id', ['foreign.value', 'foreign.value.two'])
+ relation.get_query().get_query().should_receive("where_in").once().with_args(
+ "relation.id", ["foreign.value", "foreign.value.two"]
+ )
model1 = OrmBelongsToModelStub()
model2 = OrmBelongsToModelStub()
- model2.foreign_key = 'foreign.value'
+ model2.foreign_key = "foreign.value"
model3 = AnotherOrmBelongsToModelStub()
- model3.foreign_key = 'foreign.value.two'
+ model3.foreign_key = "foreign.value.two"
models = [model1, model2, model3]
relation.add_eager_constraints(models)
@@ -54,63 +54,60 @@ def test_models_are_properly_matched_to_parents(self):
relation = self._get_relation()
result1 = flexmock()
- result1.should_receive('get_attribute').with_args('id').and_return(1)
+ result1.should_receive("get_attribute").with_args("id").and_return(1)
result2 = flexmock()
- result2.should_receive('get_attribute').with_args('id').and_return(2)
+ result2.should_receive("get_attribute").with_args("id").and_return(2)
model1 = OrmBelongsToModelStub()
model1.foreign_key = 1
model2 = OrmBelongsToModelStub()
model2.foreign_key = 2
- models = relation.match([model1, model2], Collection([result1, result2]), 'foo')
+ models = relation.match([model1, model2], Collection([result1, result2]), "foo")
- self.assertEqual(1, models[0].foo.get_attribute('id'))
- self.assertEqual(2, models[1].foo.get_attribute('id'))
+ self.assertEqual(1, models[0].foo.get_attribute("id"))
+ self.assertEqual(2, models[1].foo.get_attribute("id"))
def test_associate_sets_foreign_key_on_model(self):
parent = Model()
- parent.foreign_key = 'foreign.value'
- parent.get_attribute = mock.MagicMock(return_value='foreign.value')
+ parent.foreign_key = "foreign.value"
+ parent.get_attribute = mock.MagicMock(return_value="foreign.value")
parent.set_attribute = mock.MagicMock()
parent.set_relation = mock.MagicMock()
relation = self._get_relation(parent)
associate = flexmock(Model())
- associate.should_receive('get_attribute').once().with_args('id').and_return(1)
+ associate.should_receive("get_attribute").once().with_args("id").and_return(1)
relation.associate(associate)
- parent.get_attribute.assert_has_calls([
- mock.call('foreign_key'),
- mock.call('foreign_key')
- ])
- parent.set_attribute.assert_has_calls([
- mock.call('foreign_key', 1)
- ])
- parent.set_relation.assert_called_once_with('relation', associate)
+ parent.get_attribute.assert_has_calls(
+ [mock.call("foreign_key"), mock.call("foreign_key")]
+ )
+ parent.set_attribute.assert_has_calls([mock.call("foreign_key", 1)])
+ parent.set_relation.assert_called_once_with("relation", associate)
def _get_relation(self, parent=None):
flexmock(Builder)
query = flexmock(QueryBuilder(None, QueryGrammar(), None))
builder = Builder(query)
- builder.should_receive('where').with_args('relation.id', '=', 'foreign.value')
+ builder.should_receive("where").with_args("relation.id", "=", "foreign.value")
related = flexmock(Model())
- related.should_receive('new_query').and_return(builder)
- related.should_receive('get_key_name').and_return('id')
- related.should_receive('get_table').and_return('relation')
- builder.should_receive('get_model').and_return(related)
+ related.should_receive("new_query").and_return(builder)
+ related.should_receive("get_key_name").and_return("id")
+ related.should_receive("get_table").and_return("relation")
+ builder.should_receive("get_model").and_return(related)
if parent is None:
parent = OrmBelongsToModelStub()
- parent.foreign_key = 'foreign.value'
+ parent.foreign_key = "foreign.value"
- return BelongsTo(builder, parent, 'foreign_key', 'id', 'relation')
+ return BelongsTo(builder, parent, "foreign_key", "id", "relation")
class OrmBelongsToModelStub(Model):
- foreign_key = 'foreign.value'
+ foreign_key = "foreign.value"
class AnotherOrmBelongsToModelStub(Model):
- foreign_key = 'foreign.value.two'
+ foreign_key = "foreign.value.two"
diff --git a/tests/orm/relations/test_belongs_to_many.py b/tests/orm/relations/test_belongs_to_many.py
index 0bb0dd01..ef3a8588 100644
--- a/tests/orm/relations/test_belongs_to_many.py
+++ b/tests/orm/relations/test_belongs_to_many.py
@@ -18,76 +18,99 @@
class OrmBelongsToTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_models_are_properly_hydrated(self):
model1 = OrmBelongsToManyModelStub()
- model1.fill(name='john', pivot_user_id=1, pivot_role_id=2)
+ model1.fill(name="john", pivot_user_id=1, pivot_role_id=2)
model2 = OrmBelongsToManyModelStub()
- model2.fill(name='jane', pivot_user_id=3, pivot_role_id=4)
+ model2.fill(name="jane", pivot_user_id=3, pivot_role_id=4)
models = [model1, model2]
- base_builder = flexmock(Builder(QueryBuilder(MockConnection().prepare_mock(),
- QueryGrammar(), QueryProcessor())))
+ base_builder = flexmock(
+ Builder(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
+ )
relation = self._get_relation()
- relation.get_parent().should_receive('get_connection_name').and_return('foo.connection')
- relation.get_query().get_query().should_receive('add_select').once()\
- .with_args(*['roles.*', 'user_role.user_id AS pivot_user_id', 'user_role.role_id AS pivot_role_id'])\
- .and_return(relation.get_query())
- relation.get_query().should_receive('get_models').once().and_return(models)
- relation.get_query().should_receive('eager_load_relations').once().with_args(models).and_return(models)
- relation.get_related().should_receive('new_collection').replace_with(lambda l: Collection(l))
- relation.get_query().should_receive('get_query').once().and_return(base_builder)
+ relation.get_parent().should_receive("get_connection_name").and_return(
+ "foo.connection"
+ )
+ relation.get_query().get_query().should_receive("add_select").once().with_args(
+ *[
+ "roles.*",
+ "user_role.user_id AS pivot_user_id",
+ "user_role.role_id AS pivot_role_id",
+ ]
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("get_models").once().and_return(models)
+ relation.get_query().should_receive("eager_load_relations").once().with_args(
+ models
+ ).and_return(models)
+ relation.get_related().should_receive("new_collection").replace_with(
+ lambda l: Collection(l)
+ )
+ relation.get_query().should_receive("get_query").once().and_return(base_builder)
results = relation.get()
self.assertIsInstance(results, Collection)
# Make sure the foreign keys were set on the pivot models
- self.assertEqual('user_id', results[0].pivot.get_foreign_key())
- self.assertEqual('role_id', results[0].pivot.get_other_key())
+ self.assertEqual("user_id", results[0].pivot.get_foreign_key())
+ self.assertEqual("role_id", results[0].pivot.get_other_key())
- self.assertEqual('john', results[0].name)
+ self.assertEqual("john", results[0].name)
self.assertEqual(1, results[0].pivot.user_id)
self.assertEqual(2, results[0].pivot.role_id)
- self.assertEqual('foo.connection', results[0].pivot.get_connection_name())
+ self.assertEqual("foo.connection", results[0].pivot.get_connection_name())
- self.assertEqual('jane', results[1].name)
+ self.assertEqual("jane", results[1].name)
self.assertEqual(3, results[1].pivot.user_id)
self.assertEqual(4, results[1].pivot.role_id)
- self.assertEqual('foo.connection', results[1].pivot.get_connection_name())
+ self.assertEqual("foo.connection", results[1].pivot.get_connection_name())
- self.assertEqual('user_role', results[0].pivot.get_table())
+ self.assertEqual("user_role", results[0].pivot.get_table())
self.assertTrue(results[0].pivot.exists)
def test_timestamps_can_be_retrieved_properly(self):
model1 = OrmBelongsToManyModelStub()
- model1.fill(name='john', pivot_user_id=1, pivot_role_id=2)
+ model1.fill(name="john", pivot_user_id=1, pivot_role_id=2)
model2 = OrmBelongsToManyModelStub()
- model2.fill(name='jane', pivot_user_id=3, pivot_role_id=4)
+ model2.fill(name="jane", pivot_user_id=3, pivot_role_id=4)
models = [model1, model2]
- base_builder = flexmock(Builder(QueryBuilder(MockConnection().prepare_mock(),
- QueryGrammar(), QueryProcessor())))
+ base_builder = flexmock(
+ Builder(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
+ )
relation = self._get_relation().with_timestamps()
- relation.get_parent().should_receive('get_connection_name').and_return('foo.connection')
- relation.get_query().get_query().should_receive('add_select').once()\
- .with_args(
- 'roles.*',
- 'user_role.user_id AS pivot_user_id',
- 'user_role.role_id AS pivot_role_id',
- 'user_role.created_at AS pivot_created_at',
- 'user_role.updated_at AS pivot_updated_at'
- )\
- .and_return(relation.get_query())
- relation.get_query().should_receive('get_models').once().and_return(models)
- relation.get_query().should_receive('eager_load_relations').once().with_args(models).and_return(models)
- relation.get_related().should_receive('new_collection').replace_with(lambda l: Collection(l))
- relation.get_query().should_receive('get_query').once().and_return(base_builder)
+ relation.get_parent().should_receive("get_connection_name").and_return(
+ "foo.connection"
+ )
+ relation.get_query().get_query().should_receive("add_select").once().with_args(
+ "roles.*",
+ "user_role.user_id AS pivot_user_id",
+ "user_role.role_id AS pivot_role_id",
+ "user_role.created_at AS pivot_created_at",
+ "user_role.updated_at AS pivot_updated_at",
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("get_models").once().and_return(models)
+ relation.get_query().should_receive("eager_load_relations").once().with_args(
+ models
+ ).and_return(models)
+ relation.get_related().should_receive("new_collection").replace_with(
+ lambda l: Collection(l)
+ )
+ relation.get_query().should_receive("get_query").once().and_return(base_builder)
results = relation.get()
@@ -110,8 +133,12 @@ def test_models_are_properly_matched_to_parents(self):
model3 = OrmBelongsToManyModelStub()
model3.id = 3
- relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l))
- models = relation.match([model1, model2, model3], Collection([result1, result2, result3]), 'foo')
+ relation.get_related().should_receive("new_collection").replace_with(
+ lambda l=None: Collection(l)
+ )
+ models = relation.match(
+ [model1, model2, model3], Collection([result1, result2, result3]), "foo"
+ )
self.assertEqual(1, models[0].foo[0].pivot.user_id)
self.assertEqual(1, len(models[0].foo))
@@ -122,16 +149,20 @@ def test_models_are_properly_matched_to_parents(self):
def test_relation_is_properly_initialized(self):
relation = self._get_relation()
- relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l or []))
+ relation.get_related().should_receive("new_collection").replace_with(
+ lambda l=None: Collection(l or [])
+ )
model = flexmock(Model())
- model.should_receive('set_relation').once().with_args('foo', Collection)
- models = relation.init_relation([model], 'foo')
+ model.should_receive("set_relation").once().with_args("foo", Collection)
+ models = relation.init_relation([model], "foo")
self.assertEqual([model], models)
def test_eager_constraints_are_properly_added(self):
relation = self._get_relation()
- relation.get_query().get_query().should_receive('where_in').once().with_args('user_role.user_id', [1, 2])
+ relation.get_query().get_query().should_receive("where_in").once().with_args(
+ "user_role.user_id", [1, 2]
+ )
model1 = OrmBelongsToManyModelStub()
model1.id = 1
model2 = OrmBelongsToManyModelStub()
@@ -143,102 +174,106 @@ def test_attach_inserts_pivot_table_record(self):
flexmock(BelongsToMany, touch_if_touching=lambda: True)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
- query.should_receive('insert').once().with_args([{'user_id': 1, 'role_id': 2, 'foo': 'bar'}]).and_return(True)
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
+ query.should_receive("insert").once().with_args(
+ [{"user_id": 1, "role_id": 2, "foo": "bar"}]
+ ).and_return(True)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- relation.should_receive('touch_if_touching').once()
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ relation.should_receive("touch_if_touching").once()
- relation.attach(2, {'foo': 'bar'})
+ relation.attach(2, {"foo": "bar"})
def test_attach_multiple_inserts_pivot_table_record(self):
flexmock(BelongsToMany, touch_if_touching=lambda: True)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
- query.should_receive('insert').once().with_args(
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
+ query.should_receive("insert").once().with_args(
[
- {'user_id': 1, 'role_id': 2, 'foo': 'bar'},
- {'user_id': 1, 'role_id': 3, 'bar': 'baz', 'foo': 'bar'}
+ {"user_id": 1, "role_id": 2, "foo": "bar"},
+ {"user_id": 1, "role_id": 3, "bar": "baz", "foo": "bar"},
]
).and_return(True)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- relation.should_receive('touch_if_touching').once()
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ relation.should_receive("touch_if_touching").once()
- relation.attach([2, {3: {'bar': 'baz'}}], {'foo': 'bar'})
+ relation.attach([2, {3: {"bar": "baz"}}], {"foo": "bar"})
def test_attach_inserts_pivot_table_records_with_timestamps_when_necessary(self):
flexmock(BelongsToMany, touch_if_touching=lambda: True)
relation = self._get_relation().with_timestamps()
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
now = pendulum.now()
- query.should_receive('insert').once().with_args(
+ query.should_receive("insert").once().with_args(
[
- {'user_id': 1, 'role_id': 2, 'foo': 'bar', 'created_at': now, 'updated_at': now}
+ {
+ "user_id": 1,
+ "role_id": 2,
+ "foo": "bar",
+ "created_at": now,
+ "updated_at": now,
+ }
]
).and_return(True)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- relation.get_parent().should_receive('fresh_timestamp').once().and_return(now)
- relation.should_receive('touch_if_touching').once()
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ relation.get_parent().should_receive("fresh_timestamp").once().and_return(now)
+ relation.should_receive("touch_if_touching").once()
- relation.attach(2, {'foo': 'bar'})
+ relation.attach(2, {"foo": "bar"})
def test_attach_inserts_pivot_table_records_with_a_created_at_timestamp(self):
flexmock(BelongsToMany, touch_if_touching=lambda: True)
- relation = self._get_relation().with_pivot('created_at')
+ relation = self._get_relation().with_pivot("created_at")
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
now = pendulum.now()
- query.should_receive('insert').once().with_args(
- [
- {'user_id': 1, 'role_id': 2, 'foo': 'bar', 'created_at': now}
- ]
+ query.should_receive("insert").once().with_args(
+ [{"user_id": 1, "role_id": 2, "foo": "bar", "created_at": now}]
).and_return(True)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- relation.get_parent().should_receive('fresh_timestamp').once().and_return(now)
- relation.should_receive('touch_if_touching').once()
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ relation.get_parent().should_receive("fresh_timestamp").once().and_return(now)
+ relation.should_receive("touch_if_touching").once()
- relation.attach(2, {'foo': 'bar'})
+ relation.attach(2, {"foo": "bar"})
def test_attach_inserts_pivot_table_records_with_an_updated_at_timestamp(self):
flexmock(BelongsToMany, touch_if_touching=lambda: True)
- relation = self._get_relation().with_pivot('updated_at')
+ relation = self._get_relation().with_pivot("updated_at")
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
now = pendulum.now()
- query.should_receive('insert').once().with_args(
- [
- {'user_id': 1, 'role_id': 2, 'foo': 'bar', 'updated_at': now}
- ]
+ query.should_receive("insert").once().with_args(
+ [{"user_id": 1, "role_id": 2, "foo": "bar", "updated_at": now}]
).and_return(True)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- relation.get_parent().should_receive('fresh_timestamp').once().and_return(now)
- relation.should_receive('touch_if_touching').once()
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ relation.get_parent().should_receive("fresh_timestamp").once().and_return(now)
+ relation.should_receive("touch_if_touching").once()
- relation.attach(2, {'foo': 'bar'})
+ relation.attach(2, {"foo": "bar"})
def test_detach_remove_pivot_table_record(self):
flexmock(BelongsToMany, touch_if_touching=lambda: True)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
- query.should_receive('where').once().with_args('user_id', 1).and_return(query)
- query.should_receive('where_in').once().with_args('role_id', [1, 2, 3])
- query.should_receive('delete').once().and_return(True)
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
+ query.should_receive("where").once().with_args("user_id", 1).and_return(query)
+ query.should_receive("where_in").once().with_args("role_id", [1, 2, 3])
+ query.should_receive("delete").once().and_return(True)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- relation.should_receive('touch_if_touching').once()
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ relation.should_receive("touch_if_touching").once()
self.assertTrue(relation.detach([1, 2, 3]))
@@ -246,14 +281,14 @@ def test_detach_with_single_id_remove_pivot_table_record(self):
flexmock(BelongsToMany, touch_if_touching=lambda: True)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
- query.should_receive('where').once().with_args('user_id', 1).and_return(query)
- query.should_receive('where_in').once().with_args('role_id', [1])
- query.should_receive('delete').once().and_return(True)
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
+ query.should_receive("where").once().with_args("user_id", 1).and_return(query)
+ query.should_receive("where_in").once().with_args("role_id", [1])
+ query.should_receive("delete").once().and_return(True)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- relation.should_receive('touch_if_touching').once()
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ relation.should_receive("touch_if_touching").once()
self.assertTrue(relation.detach(1))
@@ -261,14 +296,14 @@ def test_detach_clears_all_records_when_no_ids(self):
flexmock(BelongsToMany, touch_if_touching=lambda: True)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
- query.should_receive('where').once().with_args('user_id', 1).and_return(query)
- query.should_receive('where_in').never()
- query.should_receive('delete').once().and_return(True)
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
+ query.should_receive("where").once().with_args("user_id", 1).and_return(query)
+ query.should_receive("where_in").never()
+ query.should_receive("delete").once().and_return(True)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- relation.should_receive('touch_if_touching').once()
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ relation.should_receive("touch_if_touching").once()
self.assertTrue(relation.detach())
@@ -276,201 +311,240 @@ def test_create_creates_new_model_and_insert_attachment_record(self):
flexmock(BelongsToMany, attach=lambda: True)
relation = self._get_relation()
model = flexmock()
- relation.get_related().should_receive('new_instance').once().and_return(model).with_args({'foo': 'bar'})
- model.should_receive('save').once()
- model.should_receive('get_key').and_return('foo')
- relation.should_receive('attach').once().with_args('foo', {'bar': 'baz'}, True)
+ relation.get_related().should_receive("new_instance").once().and_return(
+ model
+ ).with_args({"foo": "bar"})
+ model.should_receive("save").once()
+ model.should_receive("get_key").and_return("foo")
+ relation.should_receive("attach").once().with_args("foo", {"bar": "baz"}, True)
- self.assertEqual(model, relation.create({'foo': 'bar'}, {'bar': 'baz'}))
+ self.assertEqual(model, relation.create({"foo": "bar"}, {"bar": "baz"}))
def test_find_or_new_finds_model(self):
flexmock(BelongsToMany)
relation = self._get_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('find').with_args('foo', None).and_return(model)
- relation.get_related().should_receive('new_instance').never()
+ model.foo = "bar"
+ relation.get_query().should_receive("find").with_args("foo", None).and_return(
+ model
+ )
+ relation.get_related().should_receive("new_instance").never()
- self.assertEqual('bar', relation.find_or_new('foo').foo)
+ self.assertEqual("bar", relation.find_or_new("foo").foo)
def test_find_or_new_returns_new_model(self):
flexmock(BelongsToMany)
relation = self._get_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('find').with_args('foo', None).and_return(None)
- relation.get_related().should_receive('new_instance').once().and_return(model)
+ model.foo = "bar"
+ relation.get_query().should_receive("find").with_args("foo", None).and_return(
+ None
+ )
+ relation.get_related().should_receive("new_instance").once().and_return(model)
- self.assertEqual('bar', relation.find_or_new('foo').foo)
+ self.assertEqual("bar", relation.find_or_new("foo").foo)
def test_first_or_new_finds_first_model(self):
flexmock(BelongsToMany)
relation = self._get_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().and_return(model)
- relation.get_related().should_receive('new_instance').never()
+ model.foo = "bar"
+ relation.get_query().should_receive("where").with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().and_return(model)
+ relation.get_related().should_receive("new_instance").never()
- self.assertEqual('bar', relation.first_or_new({'foo': 'bar'}).foo)
+ self.assertEqual("bar", relation.first_or_new({"foo": "bar"}).foo)
def test_first_or_new_returns_new_model(self):
flexmock(BelongsToMany)
relation = self._get_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().and_return(None)
- relation.get_related().should_receive('new_instance').once().and_return(model)
+ model.foo = "bar"
+ relation.get_query().should_receive("where").with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().and_return(None)
+ relation.get_related().should_receive("new_instance").once().and_return(model)
- self.assertEqual('bar', relation.first_or_new({'foo': 'bar'}).foo)
+ self.assertEqual("bar", relation.first_or_new({"foo": "bar"}).foo)
def test_first_or_create_finds_first_model(self):
flexmock(BelongsToMany)
relation = self._get_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().and_return(model)
- relation.should_receive('create').never()
+ model.foo = "bar"
+ relation.get_query().should_receive("where").with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().and_return(model)
+ relation.should_receive("create").never()
- self.assertEqual('bar', relation.first_or_create({'foo': 'bar'}).foo)
+ self.assertEqual("bar", relation.first_or_create({"foo": "bar"}).foo)
def test_first_or_create_returns_new_model(self):
flexmock(BelongsToMany)
relation = self._get_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().and_return(None)
- relation.should_receive('create').once().with_args({'foo': 'bar'}, {}, True).and_return(model)
+ model.foo = "bar"
+ relation.get_query().should_receive("where").with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().and_return(None)
+ relation.should_receive("create").once().with_args(
+ {"foo": "bar"}, {}, True
+ ).and_return(model)
- self.assertEqual('bar', relation.first_or_create({'foo': 'bar'}).foo)
+ self.assertEqual("bar", relation.first_or_create({"foo": "bar"}).foo)
def test_update_or_create_finds_first_model_and_updates(self):
flexmock(BelongsToMany)
relation = self._get_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().and_return(model)
- model.should_receive('fill').once()
- model.should_receive('save').once()
- relation.should_receive('create').never()
+ model.foo = "bar"
+ relation.get_query().should_receive("where").with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().and_return(model)
+ model.should_receive("fill").once()
+ model.should_receive("save").once()
+ relation.should_receive("create").never()
- self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}).foo)
+ self.assertEqual("bar", relation.update_or_create({"foo": "bar"}).foo)
def test_update_or_create_returns_new_model(self):
flexmock(BelongsToMany)
relation = self._get_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().and_return(None)
- relation.should_receive('create').once().with_args({'bar': 'baz'}, None, True).and_return(model)
+ model.foo = "bar"
+ relation.get_query().should_receive("where").with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().and_return(None)
+ relation.should_receive("create").once().with_args(
+ {"bar": "baz"}, None, True
+ ).and_return(model)
- self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'bar': 'baz'}).foo)
+ self.assertEqual(
+ "bar", relation.update_or_create({"foo": "bar"}, {"bar": "baz"}).foo
+ )
def test_sync_syncs_intermediate_table_with_given_list(self):
- for list_ in [[2, 3, 4], ['2', '3', '4']]:
+ for list_ in [[2, 3, 4], ["2", "3", "4"]]:
flexmock(BelongsToMany)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
- query.should_receive('where').once().with_args('user_id', 1).and_return(query)
+ query.should_receive("from_").once().with_args("user_role").and_return(
+ query
+ )
+ query.should_receive("where").once().with_args("user_id", 1).and_return(
+ query
+ )
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- query.should_receive('lists').once().with_args('role_id').and_return(Collection([1, list_[0], list_[1]]))
- relation.should_receive('attach').once().with_args(list_[2], {}, False)
- relation.should_receive('detach').once().with_args([1])
- relation.get_related().should_receive('touches').and_return(False)
- relation.get_parent().should_receive('touches').and_return(False)
+ relation.get_query().should_receive("get_query").and_return(
+ mock_query_builder
+ )
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ query.should_receive("lists").once().with_args("role_id").and_return(
+ Collection([1, list_[0], list_[1]])
+ )
+ relation.should_receive("attach").once().with_args(list_[2], {}, False)
+ relation.should_receive("detach").once().with_args([1])
+ relation.get_related().should_receive("touches").and_return(False)
+ relation.get_parent().should_receive("touches").and_return(False)
self.assertEqual(
- {
- 'attached': [list_[2]],
- 'detached': [1],
- 'updated': []
- },
- relation.sync(list_)
+ {"attached": [list_[2]], "detached": [1], "updated": []},
+ relation.sync(list_),
)
def test_sync_syncs_intermediate_table_with_given_list_and_attributes(self):
flexmock(BelongsToMany)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
- query.should_receive('where').once().with_args('user_id', 1).and_return(query)
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
+ query.should_receive("where").once().with_args("user_id", 1).and_return(query)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- query.should_receive('lists').once().with_args('role_id').and_return(Collection([1, 2, 3]))
- relation.should_receive('attach').once().with_args(4, {'foo': 'bar'}, False)
- relation.should_receive('update_existing_pivot').once().with_args(3, {'bar': 'baz'}, False).and_return(True)
- relation.should_receive('detach').once().with_args([1])
- relation.should_receive('touch_if_touching').once()
- relation.get_related().should_receive('touches').and_return(False)
- relation.get_parent().should_receive('touches').and_return(False)
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ query.should_receive("lists").once().with_args("role_id").and_return(
+ Collection([1, 2, 3])
+ )
+ relation.should_receive("attach").once().with_args(4, {"foo": "bar"}, False)
+ relation.should_receive("update_existing_pivot").once().with_args(
+ 3, {"bar": "baz"}, False
+ ).and_return(True)
+ relation.should_receive("detach").once().with_args([1])
+ relation.should_receive("touch_if_touching").once()
+ relation.get_related().should_receive("touches").and_return(False)
+ relation.get_parent().should_receive("touches").and_return(False)
self.assertEqual(
- {
- 'attached': [4],
- 'detached': [1],
- 'updated': [3]
- },
- relation.sync([2, {3: {'bar': 'baz'}}, {4: {'foo': 'bar'}}], )
+ {"attached": [4], "detached": [1], "updated": [3]},
+ relation.sync([2, {3: {"bar": "baz"}}, {4: {"foo": "bar"}}]),
)
def test_sync_does_not_return_values_that_were_not_updated(self):
flexmock(BelongsToMany)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
- query.should_receive('where').once().with_args('user_id', 1).and_return(query)
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
+ query.should_receive("where").once().with_args("user_id", 1).and_return(query)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- query.should_receive('lists').once().with_args('role_id').and_return(Collection([1, 2, 3]))
- relation.should_receive('attach').once().with_args(4, {'foo': 'bar'}, False)
- relation.should_receive('update_existing_pivot').once().with_args(3, {'bar': 'baz'}, False).and_return(False)
- relation.should_receive('detach').once().with_args([1])
- relation.should_receive('touch_if_touching').once()
- relation.get_related().should_receive('touches').and_return(False)
- relation.get_parent().should_receive('touches').and_return(False)
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ query.should_receive("lists").once().with_args("role_id").and_return(
+ Collection([1, 2, 3])
+ )
+ relation.should_receive("attach").once().with_args(4, {"foo": "bar"}, False)
+ relation.should_receive("update_existing_pivot").once().with_args(
+ 3, {"bar": "baz"}, False
+ ).and_return(False)
+ relation.should_receive("detach").once().with_args([1])
+ relation.should_receive("touch_if_touching").once()
+ relation.get_related().should_receive("touches").and_return(False)
+ relation.get_parent().should_receive("touches").and_return(False)
self.assertEqual(
- {
- 'attached': [4],
- 'detached': [1],
- 'updated': []
- },
- relation.sync([2, {3: {'bar': 'baz'}}, {4: {'foo': 'bar'}}], )
+ {"attached": [4], "detached": [1], "updated": []},
+ relation.sync([2, {3: {"bar": "baz"}}, {4: {"foo": "bar"}}]),
)
def test_touch_method_syncs_timestamps(self):
relation = self._get_relation()
- relation.get_related().should_receive('get_updated_at_column').and_return('updated_at')
+ relation.get_related().should_receive("get_updated_at_column").and_return(
+ "updated_at"
+ )
now = pendulum.now()
- relation.get_related().should_receive('fresh_timestamp').and_return(now)
- relation.get_related().should_receive('get_qualified_key_name').and_return('table.id')
- relation.get_query().get_query().should_receive('select').once().with_args('table.id')\
- .and_return(relation.get_query().get_query())
- relation.get_query().should_receive('lists').once().and_return(Collection([1, 2, 3]))
+ relation.get_related().should_receive("fresh_timestamp").and_return(now)
+ relation.get_related().should_receive("get_qualified_key_name").and_return(
+ "table.id"
+ )
+ relation.get_query().get_query().should_receive("select").once().with_args(
+ "table.id"
+ ).and_return(relation.get_query().get_query())
+ relation.get_query().should_receive("lists").once().and_return(
+ Collection([1, 2, 3])
+ )
query = flexmock()
- relation.get_related().should_receive('new_query').once().and_return(query)
- query.should_receive('where_in').once().with_args('id', [1, 2, 3]).and_return(query)
- query.should_receive('update').once().with_args({'updated_at': now})
+ relation.get_related().should_receive("new_query").once().and_return(query)
+ query.should_receive("where_in").once().with_args("id", [1, 2, 3]).and_return(
+ query
+ )
+ query.should_receive("update").once().with_args({"updated_at": now})
relation.touch()
def test_touch_if_touching(self):
flexmock(BelongsToMany)
relation = self._get_relation()
- relation.should_receive('_touching_parent').once().and_return(True)
- relation.get_parent().should_receive('touch').once()
- relation.get_parent().should_receive('touches').once().with_args('relation_name').and_return(True)
- relation.should_receive('touch').once()
+ relation.should_receive("_touching_parent").once().and_return(True)
+ relation.get_parent().should_receive("touch").once()
+ relation.get_parent().should_receive("touches").once().with_args(
+ "relation_name"
+ ).and_return(True)
+ relation.should_receive("touch").once()
relation.touch_if_touching()
@@ -478,16 +552,20 @@ def test_sync_method_converts_collection_to_list_of_keys(self):
flexmock(BelongsToMany)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('user_role').and_return(query)
- query.should_receive('where').once().with_args('user_id', 1).and_return(query)
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
+ query.should_receive("where").once().with_args("user_id", 1).and_return(query)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- query.should_receive('lists').once().with_args('role_id').and_return(Collection([1, 2, 3]))
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ query.should_receive("lists").once().with_args("role_id").and_return(
+ Collection([1, 2, 3])
+ )
collection = flexmock(Collection())
- collection.should_receive('model_keys').once().and_return([1, 2, 3])
- relation.should_receive('_format_sync_list').with_args([1, 2, 3]).and_return({1: {}, 2: {}, 3: {}})
+ collection.should_receive("model_keys").once().and_return([1, 2, 3])
+ relation.should_receive("_format_sync_list").with_args([1, 2, 3]).and_return(
+ {1: {}, 2: {}, 3: {}}
+ )
relation.sync(collection)
@@ -495,54 +573,70 @@ def test_where_pivot_params_used_for_new_queries(self):
flexmock(BelongsToMany)
relation = self._get_relation()
- relation.get_query().should_receive('where').once().and_return(relation)
+ relation.get_query().should_receive("where").once().and_return(relation)
query = flexmock()
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
- query.should_receive('from_').once().with_args('user_role').and_return(query)
+ query.should_receive("from_").once().with_args("user_role").and_return(query)
- query.should_receive('where').once().with_args('user_id', 1).and_return(query)
+ query.should_receive("where").once().with_args("user_id", 1).and_return(query)
- query.should_receive('where').once().with_args('foo', '=', 'bar', 'and').and_return(query)
+ query.should_receive("where").once().with_args(
+ "foo", "=", "bar", "and"
+ ).and_return(query)
- query.should_receive('lists').once().with_args('role_id').and_return(Collection([1, 2, 3]))
- relation.should_receive('_format_sync_list').with_args([1, 2, 3]).and_return({1: {}, 2: {}, 3: {}})
+ query.should_receive("lists").once().with_args("role_id").and_return(
+ Collection([1, 2, 3])
+ )
+ relation.should_receive("_format_sync_list").with_args([1, 2, 3]).and_return(
+ {1: {}, 2: {}, 3: {}}
+ )
- relation = relation.where_pivot('foo', '=', 'bar')
+ relation = relation.where_pivot("foo", "=", "bar")
relation.sync([1, 2, 3])
def _get_relation(self):
builder, parent = self._get_relation_arguments()[:2]
- return BelongsToMany(builder, parent, 'user_role', 'user_id', 'role_id', 'relation_name')
+ return BelongsToMany(
+ builder, parent, "user_role", "user_id", "role_id", "relation_name"
+ )
def _get_relation_arguments(self):
- flexmock(Model).should_receive('_boot_columns').and_return(['name'])
+ flexmock(Model).should_receive("_boot_columns").and_return(["name"])
parent = flexmock(Model())
- parent.should_receive('get_key').and_return(1)
- parent.should_receive('get_created_at_column').and_return('created_at')
- parent.should_receive('get_updated_at_column').and_return('updated_at')
+ parent.should_receive("get_key").and_return(1)
+ parent.should_receive("get_created_at_column").and_return("created_at")
+ parent.should_receive("get_updated_at_column").and_return("updated_at")
- query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()))
+ query = flexmock(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
flexmock(Builder)
builder = Builder(query)
- builder.should_receive('get_query').and_return(query)
+ builder.should_receive("get_query").and_return(query)
related = flexmock(Model())
builder.set_model(related)
- builder.should_receive('get_model').and_return(related)
+ builder.should_receive("get_model").and_return(related)
- related.should_receive('new_query').and_return(builder)
- related.should_receive('get_key_name').and_return('id')
- related.should_receive('get_table').and_return('roles')
- related.should_receive('new_pivot').replace_with(lambda *args: Pivot(*args))
+ related.should_receive("new_query").and_return(builder)
+ related.should_receive("get_key_name").and_return("id")
+ related.should_receive("get_table").and_return("roles")
+ related.should_receive("new_pivot").replace_with(lambda *args: Pivot(*args))
- builder.get_query().should_receive('join').at_least().once().with_args('user_role', 'roles.id', '=', 'user_role.role_id')
- builder.should_receive('where').at_least().once().with_args('user_role.user_id', '=', 1)
+ builder.get_query().should_receive("join").at_least().once().with_args(
+ "user_role", "roles.id", "=", "user_role.role_id"
+ )
+ builder.should_receive("where").at_least().once().with_args(
+ "user_role.user_id", "=", 1
+ )
- return builder, parent, 'user_role', 'user_id', 'role_id', 'relation_id'
+ return builder, parent, "user_role", "user_id", "role_id", "relation_id"
class OrmBelongsToManyModelStub(Model):
diff --git a/tests/orm/relations/test_decorators.py b/tests/orm/relations/test_decorators.py
index 20f62829..4e33c917 100644
--- a/tests/orm/relations/test_decorators.py
+++ b/tests/orm/relations/test_decorators.py
@@ -2,14 +2,20 @@
from ... import OratorTestCase
from orator import Model as BaseModel
-from orator.orm import morph_to, has_one, has_many, belongs_to_many, morph_many, belongs_to
+from orator.orm import (
+ morph_to,
+ has_one,
+ has_many,
+ belongs_to_many,
+ morph_many,
+ belongs_to,
+)
from orator.orm.model import ModelRegister
from orator.connections import SQLiteConnection
from orator.connectors.sqlite_connector import SQLiteConnector
class DecoratorsTestCase(OratorTestCase):
-
@classmethod
def setUpClass(cls):
Model.set_connection_resolver(DatabaseIntegrationConnectionResolver())
@@ -19,41 +25,41 @@ def tearDownClass(cls):
Model.unset_connection_resolver()
def setUp(self):
- with self.schema().create('test_users') as table:
- table.increments('id')
- table.string('email').unique()
+ with self.schema().create("test_users") as table:
+ table.increments("id")
+ table.string("email").unique()
table.timestamps()
- with self.schema().create('test_friends') as table:
- table.increments('id')
- table.integer('user_id')
- table.integer('friend_id')
+ with self.schema().create("test_friends") as table:
+ table.increments("id")
+ table.integer("user_id")
+ table.integer("friend_id")
- with self.schema().create('test_posts') as table:
- table.increments('id')
- table.integer('user_id')
- table.string('name')
+ with self.schema().create("test_posts") as table:
+ table.increments("id")
+ table.integer("user_id")
+ table.string("name")
table.timestamps()
table.soft_deletes()
- with self.schema().create('test_photos') as table:
- table.increments('id')
- table.morphs('imageable')
- table.string('name')
+ with self.schema().create("test_photos") as table:
+ table.increments("id")
+ table.morphs("imageable")
+ table.string("name")
table.timestamps()
def tearDown(self):
- self.schema().drop('test_users')
- self.schema().drop('test_friends')
- self.schema().drop('test_posts')
- self.schema().drop('test_photos')
+ self.schema().drop("test_users")
+ self.schema().drop("test_friends")
+ self.schema().drop("test_posts")
+ self.schema().drop("test_photos")
def test_extra_queries_are_properly_set_on_relations(self):
self.create()
# With eager loading
- user = OratorTestUser.with_('friends', 'posts', 'post', 'photos').find(1)
- post = OratorTestPost.with_('user', 'photos').find(1)
+ user = OratorTestUser.with_("friends", "posts", "post", "photos").find(1)
+ post = OratorTestPost.with_("user", "photos").find(1)
self.assertEqual(1, len(user.friends))
self.assertEqual(2, len(user.posts))
self.assertIsInstance(user.post, OratorTestPost)
@@ -63,27 +69,27 @@ def test_extra_queries_are_properly_set_on_relations(self):
self.assertEqual(
'SELECT * FROM "test_users" INNER JOIN "test_friends" ON "test_users"."id" = "test_friends"."friend_id" '
'WHERE "test_friends"."user_id" = ? ORDER BY "friend_id" ASC',
- user.friends().to_sql()
+ user.friends().to_sql(),
)
self.assertEqual(
'SELECT * FROM "test_posts" WHERE "deleted_at" IS NULL AND "test_posts"."user_id" = ?',
- user.posts().to_sql()
+ user.posts().to_sql(),
)
self.assertEqual(
'SELECT * FROM "test_posts" WHERE "test_posts"."user_id" = ? ORDER BY "name" DESC',
- user.post().to_sql()
+ user.post().to_sql(),
)
self.assertEqual(
'SELECT * FROM "test_photos" WHERE "name" IS NOT NULL AND "test_photos"."imageable_id" = ? AND "test_photos"."imageable_type" = ?',
- user.photos().to_sql()
+ user.photos().to_sql(),
)
self.assertEqual(
'SELECT * FROM "test_users" WHERE "test_users"."id" = ? ORDER BY "id" ASC',
- post.user().to_sql()
+ post.user().to_sql(),
)
self.assertEqual(
'SELECT * FROM "test_photos" WHERE "test_photos"."imageable_id" = ? AND "test_photos"."imageable_type" = ?',
- post.photos().to_sql()
+ post.photos().to_sql(),
)
# Without eager loading
@@ -98,45 +104,48 @@ def test_extra_queries_are_properly_set_on_relations(self):
self.assertEqual(
'SELECT * FROM "test_users" INNER JOIN "test_friends" ON "test_users"."id" = "test_friends"."friend_id" '
'WHERE "test_friends"."user_id" = ? ORDER BY "friend_id" ASC',
- user.friends().to_sql()
+ user.friends().to_sql(),
)
self.assertEqual(
'SELECT * FROM "test_posts" WHERE "deleted_at" IS NULL AND "test_posts"."user_id" = ?',
- user.posts().to_sql()
+ user.posts().to_sql(),
)
self.assertEqual(
'SELECT * FROM "test_posts" WHERE "test_posts"."user_id" = ? ORDER BY "name" DESC',
- user.post().to_sql()
+ user.post().to_sql(),
)
self.assertEqual(
'SELECT * FROM "test_photos" WHERE "name" IS NOT NULL AND "test_photos"."imageable_id" = ? AND "test_photos"."imageable_type" = ?',
- user.photos().to_sql()
+ user.photos().to_sql(),
)
self.assertEqual(
'SELECT * FROM "test_users" WHERE "test_users"."id" = ? ORDER BY "id" ASC',
- post.user().to_sql()
+ post.user().to_sql(),
)
self.assertEqual(
'SELECT * FROM "test_photos" WHERE "test_photos"."imageable_id" = ? AND "test_photos"."imageable_type" = ?',
- post.photos().to_sql()
+ post.photos().to_sql(),
)
+ self.assertEqual(
+ 'SELECT DISTINCT * FROM "test_posts" WHERE "deleted_at" IS NULL AND "test_posts"."user_id" = ? ORDER BY "user_id" ASC',
+ user.posts().order_by("user_id").distinct().to_sql(),
+ )
def create(self):
- user = OratorTestUser.create(id=1, email='john@doe.com')
- friend = OratorTestUser.create(id=2, email='jane@doe.com')
+ user = OratorTestUser.create(id=1, email="john@doe.com")
+ friend = OratorTestUser.create(id=2, email="jane@doe.com")
user.friends().attach(friend)
- post1 = user.posts().create(name='First Post')
- post2 = user.posts().create(name='Second Post')
+ post1 = user.posts().create(name="First Post")
+ post2 = user.posts().create(name="Second Post")
- user.photos().create(name='Avatar 1')
- user.photos().create(name='Avatar 2')
- user.photos().create(name='Avatar 3')
-
- post1.photos().create(name='Hero 1')
- post1.photos().create(name='Hero 2')
+ user.photos().create(name="Avatar 1")
+ user.photos().create(name="Avatar 2")
+ user.photos().create(name="Avatar 3")
+ post1.photos().create(name="Hero 1")
+ post1.photos().create(name="Hero 2")
def connection(self):
return Model.get_connection_resolver().connection()
@@ -152,43 +161,43 @@ class Model(BaseModel):
class OratorTestUser(Model):
- __table__ = 'test_users'
+ __table__ = "test_users"
__guarded__ = []
- @belongs_to_many('test_friends', 'user_id', 'friend_id', with_pivot=['id'])
+ @belongs_to_many("test_friends", "user_id", "friend_id", with_pivot=["id"])
def friends(self):
- return OratorTestUser.order_by('friend_id')
+ return OratorTestUser.order_by("friend_id")
- @has_many('user_id')
+ @has_many("user_id")
def posts(self):
- return OratorTestPost.where_null('deleted_at')
+ return OratorTestPost.where_null("deleted_at")
- @has_one('user_id')
+ @has_one("user_id")
def post(self):
- return OratorTestPost.order_by('name', 'desc')
+ return OratorTestPost.order_by("name", "desc")
- @morph_many('imageable')
+ @morph_many("imageable")
def photos(self):
- return OratorTestPhoto.where_not_null('name')
+ return OratorTestPhoto.where_not_null("name")
class OratorTestPost(Model):
- __table__ = 'test_posts'
+ __table__ = "test_posts"
__guarded__ = []
- @belongs_to('user_id')
+ @belongs_to("user_id")
def user(self):
- return OratorTestUser.order_by('id')
+ return OratorTestUser.order_by("id")
- @morph_many('imageable')
+ @morph_many("imageable")
def photos(self):
- return 'test_photos'
+ return "test_photos"
class OratorTestPhoto(Model):
- __table__ = 'test_photos'
+ __table__ = "test_photos"
__guarded__ = []
@morph_to
@@ -204,12 +213,14 @@ def connection(self, name=None):
if self._connection:
return self._connection
- self._connection = SQLiteConnection(SQLiteConnector().connect({'database': ':memory:'}))
+ self._connection = SQLiteConnection(
+ SQLiteConnector().connect({"database": ":memory:"})
+ )
return self._connection
def get_default_connection(self):
- return 'default'
+ return "default"
def set_default_connection(self, name):
pass
diff --git a/tests/orm/relations/test_has_many.py b/tests/orm/relations/test_has_many.py
index e25b3879..16a1e0ea 100644
--- a/tests/orm/relations/test_has_many.py
+++ b/tests/orm/relations/test_has_many.py
@@ -15,131 +15,168 @@
class OrmHasManyTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_create_properly_creates_new_model(self):
relation = self._get_relation()
created = flexmock(Model(), save=lambda: True, set_attribute=lambda: None)
- created.should_receive('save').once().and_return(True)
- relation.get_related().should_receive('new_instance').once().with_args({'name': 'john'}).and_return(created)
- created.should_receive('set_attribute').with_args('foreign_key', 1)
+ created.should_receive("save").once().and_return(True)
+ relation.get_related().should_receive("new_instance").once().with_args(
+ {"name": "john"}
+ ).and_return(created)
+ created.should_receive("set_attribute").with_args("foreign_key", 1)
- self.assertEqual(created, relation.create(name='john'))
+ self.assertEqual(created, relation.create(name="john"))
def test_find_or_new_finds_model(self):
relation = self._get_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('find').once().with_args('foo', ['*']).and_return(model)
- model.should_receive('set_attribute').never()
+ model.foo = "bar"
+ relation.get_query().should_receive("find").once().with_args(
+ "foo", ["*"]
+ ).and_return(model)
+ model.should_receive("set_attribute").never()
- self.assertEqual('bar', relation.find_or_new('foo').foo)
+ self.assertEqual("bar", relation.find_or_new("foo").foo)
def test_find_or_new_returns_new_model_with_foreign_key_set(self):
relation = self._get_relation()
- relation.get_query().should_receive('find').once().with_args('foo', ['*']).and_return(None)
+ relation.get_query().should_receive("find").once().with_args(
+ "foo", ["*"]
+ ).and_return(None)
model = flexmock()
- model.foo = 'bar'
- relation.get_related().should_receive('new_instance').once().with_args().and_return(model)
- model.should_receive('set_attribute').once().with_args('foreign_key', 1)
+ model.foo = "bar"
+ relation.get_related().should_receive(
+ "new_instance"
+ ).once().with_args().and_return(model)
+ model.should_receive("set_attribute").once().with_args("foreign_key", 1)
- self.assertEqual('bar', relation.find_or_new('foo').foo)
+ self.assertEqual("bar", relation.find_or_new("foo").foo)
def test_first_or_new_finds_first_model(self):
relation = self._get_relation()
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('first').once().with_args().and_return(model)
- model.should_receive('set_attribute').never()
+ model.foo = "bar"
+ relation.get_query().should_receive("first").once().with_args().and_return(
+ model
+ )
+ model.should_receive("set_attribute").never()
- self.assertEqual('bar', relation.first_or_new(foo='bar').foo)
+ self.assertEqual("bar", relation.first_or_new(foo="bar").foo)
def test_first_or_new_returns_new_model_with_foreign_key_set(self):
relation = self._get_relation()
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().with_args().and_return(None)
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().with_args().and_return(None)
model = flexmock()
- model.foo = 'bar'
- relation.get_related().should_receive('new_instance').once().with_args().and_return(model)
- model.should_receive('set_attribute').once().with_args('foreign_key', 1)
+ model.foo = "bar"
+ relation.get_related().should_receive(
+ "new_instance"
+ ).once().with_args().and_return(model)
+ model.should_receive("set_attribute").once().with_args("foreign_key", 1)
- self.assertEqual('bar', relation.first_or_new(foo='bar').foo)
+ self.assertEqual("bar", relation.first_or_new(foo="bar").foo)
def test_first_or_create_finds_first_model(self):
relation = self._get_relation()
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('first').once().with_args().and_return(model)
- model.should_receive('set_attribute').never()
+ model.foo = "bar"
+ relation.get_query().should_receive("first").once().with_args().and_return(
+ model
+ )
+ model.should_receive("set_attribute").never()
- self.assertEqual('bar', relation.first_or_create(foo='bar').foo)
+ self.assertEqual("bar", relation.first_or_create(foo="bar").foo)
def test_first_or_create_returns_new_model_with_foreign_key_set(self):
relation = self._get_relation()
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().with_args().and_return(None)
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().with_args().and_return(None)
model = flexmock()
- model.foo = 'bar'
- relation.get_related().should_receive('new_instance').once().with_args({'foo': 'bar'}).and_return(model)
- model.should_receive('save').once().and_return(True)
- model.should_receive('set_attribute').once().with_args('foreign_key', 1)
+ model.foo = "bar"
+ relation.get_related().should_receive("new_instance").once().with_args(
+ {"foo": "bar"}
+ ).and_return(model)
+ model.should_receive("save").once().and_return(True)
+ model.should_receive("set_attribute").once().with_args("foreign_key", 1)
- self.assertEqual('bar', relation.first_or_create(foo='bar').foo)
+ self.assertEqual("bar", relation.first_or_create(foo="bar").foo)
def test_update_or_create_finds_first_model_and_updates(self):
relation = self._get_relation()
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('first').once().with_args().and_return(model)
- relation.get_related().should_receive('new_instance').never()
- model.should_receive('fill').once().with_args({'foo': 'baz'})
- model.should_receive('save').once()
-
- self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'foo': 'baz'}).foo)
+ model.foo = "bar"
+ relation.get_query().should_receive("first").once().with_args().and_return(
+ model
+ )
+ relation.get_related().should_receive("new_instance").never()
+ model.should_receive("fill").once().with_args({"foo": "baz"})
+ model.should_receive("save").once()
+
+ self.assertEqual(
+ "bar", relation.update_or_create({"foo": "bar"}, {"foo": "baz"}).foo
+ )
def test_update_or_create_creates_new_model_with_foreign_key_set(self):
relation = self._get_relation()
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().with_args().and_return(None)
+ relation.get_query().should_receive("first").once().with_args().and_return(None)
model = flexmock()
- model.foo = 'bar'
- relation.get_related().should_receive('new_instance').once().and_return(model)
- model.should_receive('fill').once().with_args({'foo': 'baz'})
- model.should_receive('save').once()
- model.should_receive('set_attribute').once().with_args('foreign_key', 1)
+ model.foo = "bar"
+ relation.get_related().should_receive("new_instance").once().and_return(model)
+ model.should_receive("fill").once().with_args({"foo": "baz"})
+ model.should_receive("save").once()
+ model.should_receive("set_attribute").once().with_args("foreign_key", 1)
- self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'foo': 'baz'}).foo)
+ self.assertEqual(
+ "bar", relation.update_or_create({"foo": "bar"}, {"foo": "baz"}).foo
+ )
def test_update_updates_models_with_timestamps(self):
relation = self._get_relation()
- relation.get_related().should_receive('uses_timestamps').once().and_return(True)
+ relation.get_related().should_receive("uses_timestamps").once().and_return(True)
now = pendulum.now()
- relation.get_related().should_receive('fresh_timestamp').once().and_return(now)
- relation.get_query().should_receive('update').once().with_args({'foo': 'bar', 'updated_at': now}).and_return('results')
+ relation.get_related().should_receive("fresh_timestamp").once().and_return(now)
+ relation.get_query().should_receive("update").once().with_args(
+ {"foo": "bar", "updated_at": now}
+ ).and_return("results")
- self.assertEqual('results', relation.update(foo='bar'))
+ self.assertEqual("results", relation.update(foo="bar"))
def test_relation_is_properly_initialized(self):
relation = self._get_relation()
model = flexmock(Model())
- model.should_receive('set_relation').once().with_args('foo', Collection)
- models = relation.init_relation([model], 'foo')
+ model.should_receive("set_relation").once().with_args("foo", Collection)
+ models = relation.init_relation([model], "foo")
self.assertEqual([model], models)
def test_eager_constraints_are_properly_added(self):
relation = self._get_relation()
- relation.get_query().get_query().should_receive('where_in').once().with_args('table.foreign_key', [1, 2])
+ relation.get_query().get_query().should_receive("where_in").once().with_args(
+ "table.foreign_key", [1, 2]
+ )
model1 = OrmHasOneModelStub()
model1.id = 1
@@ -148,6 +185,21 @@ def test_eager_constraints_are_properly_added(self):
relation.add_eager_constraints([model1, model2])
+ def test_save_many_returns_list_of_models(self):
+ relation = self._get_relation()
+
+ model1 = flexmock()
+ model1.foo = "foo"
+ model1.should_receive("save").once().and_return(True)
+ model1.should_receive("set_attribute").once().with_args("foreign_key", 1)
+
+ model2 = flexmock()
+ model2.foo = "bar"
+ model2.should_receive("save").once().and_return(True)
+ model2.should_receive("set_attribute").once().with_args("foreign_key", 1)
+
+ self.assertEqual([model1, model2], relation.save_many([model1, model2]))
+
def test_models_are_properly_matched_to_parents(self):
relation = self._get_relation()
@@ -165,11 +217,19 @@ def test_models_are_properly_matched_to_parents(self):
model3 = OrmHasOneModelStub()
model3.id = 3
- relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l))
- relation.get_query().should_receive('where').with_args('table.foreign_key', '=', 2)
- relation.get_query().should_receive('where').with_args('table.foreign_key', '=', 3)
+ relation.get_related().should_receive("new_collection").replace_with(
+ lambda l=None: Collection(l)
+ )
+ relation.get_query().should_receive("where").with_args(
+ "table.foreign_key", "=", 2
+ )
+ relation.get_query().should_receive("where").with_args(
+ "table.foreign_key", "=", 3
+ )
- models = relation.match([model1, model2, model3], Collection([result1, result2, result3]), 'foo')
+ models = relation.match(
+ [model1, model2, model3], Collection([result1, result2, result3]), "foo"
+ )
self.assertEqual(1, models[0].foo[0].foreign_key)
self.assertEqual(1, len(models[0].foo))
@@ -182,14 +242,16 @@ def test_relation_count_query_can_be_built(self):
relation = self._get_relation()
query = flexmock(QueryBuilder(None, QueryGrammar(), None))
builder = Builder(query)
- builder.get_query().should_receive('select').once()
- relation.get_parent().should_receive('get_table').and_return('table')
- builder.should_receive('where').once().with_args('table.foreign_key', '=', QueryExpression)
+ builder.get_query().should_receive("select").once()
+ relation.get_parent().should_receive("get_table").and_return("table")
+ builder.should_receive("where").once().with_args(
+ "table.foreign_key", "=", QueryExpression
+ )
parent_query = flexmock(QueryBuilder(None, None, None))
- relation.get_query().should_receive('get_query').and_return(parent_query)
+ relation.get_query().should_receive("get_query").and_return(parent_query)
grammar = flexmock()
- parent_query.should_receive('get_grammar').once().and_return(grammar)
- grammar.should_receive('wrap').once().with_args('table.id')
+ parent_query.should_receive("get_grammar").once().and_return(grammar)
+ grammar.should_receive("wrap").once().with_args("table.id")
relation.get_relation_count_query(builder, builder)
@@ -197,17 +259,17 @@ def _get_relation(self):
flexmock(Builder)
query = flexmock(QueryBuilder(None, QueryGrammar(), None))
builder = Builder(query)
- builder.should_receive('where').with_args('table.foreign_key', '=', 1)
+ builder.should_receive("where").with_args("table.foreign_key", "=", 1)
related = flexmock(Model())
- related.should_receive('new_query').and_return(builder)
- builder.should_receive('get_model').and_return(related)
+ related.should_receive("new_query").and_return(builder)
+ builder.should_receive("get_model").and_return(related)
parent = flexmock(Model())
- parent.should_receive('get_attribute').with_args('id').and_return(1)
- parent.should_receive('get_created_at_column').and_return('created_at')
- parent.should_receive('get_updated_at_column').and_return('updated_at')
- parent.should_receive('new_query').and_return(builder)
+ parent.should_receive("get_attribute").with_args("id").and_return(1)
+ parent.should_receive("get_created_at_column").and_return("created_at")
+ parent.should_receive("get_updated_at_column").and_return("updated_at")
+ parent.should_receive("new_query").and_return(builder)
- return HasMany(builder, parent, 'table.foreign_key', 'id')
+ return HasMany(builder, parent, "table.foreign_key", "id")
class OrmHasOneModelStub(Model):
diff --git a/tests/orm/relations/test_has_many_through.py b/tests/orm/relations/test_has_many_through.py
index ed2110c9..6a1c8f6b 100644
--- a/tests/orm/relations/test_has_many_through.py
+++ b/tests/orm/relations/test_has_many_through.py
@@ -16,22 +16,25 @@
class OrmHasManyThroughTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_relation_is_properly_initialized(self):
relation = self._get_relation()
model = flexmock(Model())
- relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l or []))
- model.should_receive('set_relation').once().with_args('foo', Collection)
- models = relation.init_relation([model], 'foo')
+ relation.get_related().should_receive("new_collection").replace_with(
+ lambda l=None: Collection(l or [])
+ )
+ model.should_receive("set_relation").once().with_args("foo", Collection)
+ models = relation.init_relation([model], "foo")
self.assertEqual([model], models)
def test_eager_constraints_are_properly_added(self):
relation = self._get_relation()
- relation.get_query().get_query().should_receive('where_in').once().with_args('users.country_id', [1, 2])
+ relation.get_query().get_query().should_receive("where_in").once().with_args(
+ "users.country_id", [1, 2]
+ )
model1 = OrmHasManyThroughModelStub()
model1.id = 1
model2 = OrmHasManyThroughModelStub()
@@ -55,10 +58,18 @@ def test_models_are_properly_matched_to_parents(self):
model3 = OrmHasManyThroughModelStub()
model3.id = 3
- relation.get_related().should_receive('new_collection').replace_with(lambda l=None: Collection(l or []))
- relation.get_query().should_receive('where').with_args('users.country_id', '=', 2)
- relation.get_query().should_receive('where').with_args('users.country_id', '=', 3)
- models = relation.match([model1, model2, model3], Collection([result1, result2, result3]), 'foo')
+ relation.get_related().should_receive("new_collection").replace_with(
+ lambda l=None: Collection(l or [])
+ )
+ relation.get_query().should_receive("where").with_args(
+ "users.country_id", "=", 2
+ )
+ relation.get_query().should_receive("where").with_args(
+ "users.country_id", "=", 3
+ )
+ models = relation.match(
+ [model1, model2, model3], Collection([result1, result2, result3]), "foo"
+ )
self.assertEqual(1, models[0].foo[0].country_id)
self.assertEqual(1, len(models[0].foo))
@@ -71,8 +82,10 @@ def test_get(self):
relation = self._get_relation()
query = relation.get_query()
- query.get_query().should_receive('add_select').once().with_args('posts.*', 'users.country_id').and_return(query)
- query.should_receive('get_models').and_return([])
+ query.get_query().should_receive("add_select").once().with_args(
+ "posts.*", "users.country_id"
+ ).and_return(query)
+ query.should_receive("get_models").and_return([])
relation.get()
@@ -80,31 +93,33 @@ def _get_relation(self):
flexmock(Builder)
query = flexmock(QueryBuilder(None, QueryGrammar(), None))
builder = Builder(query)
- builder.get_query().should_receive('join').at_least().once().with_args('users', 'users.id', '=', 'posts.user_id')
- builder.should_receive('where').with_args('users.country_id', '=', 1)
+ builder.get_query().should_receive("join").at_least().once().with_args(
+ "users", "users.id", "=", "posts.user_id"
+ )
+ builder.should_receive("where").with_args("users.country_id", "=", 1)
country = flexmock(Model())
- country.should_receive('get_key').and_return(1)
- country.should_receive('get_foreign_key').and_return('country_id')
+ country.should_receive("get_key").and_return(1)
+ country.should_receive("get_foreign_key").and_return("country_id")
user = flexmock(Model())
- user.should_receive('get_table').and_return('users')
- user.should_receive('get_qualified_key_name').and_return('users.id')
+ user.should_receive("get_table").and_return("users")
+ user.should_receive("get_qualified_key_name").and_return("users.id")
post = flexmock(Model())
- post.should_receive('get_table').and_return('posts')
- builder.should_receive('get_model').and_return(post)
+ post.should_receive("get_table").and_return("posts")
+ builder.should_receive("get_model").and_return(post)
- post.should_receive('new_query').and_return(builder)
+ post.should_receive("new_query").and_return(builder)
- user.should_receive('get_key').and_return(1)
- user.should_receive('get_created_at_column').and_return('created_at')
- user.should_receive('get_updated_at_column').and_return('updated_at')
+ user.should_receive("get_key").and_return(1)
+ user.should_receive("get_created_at_column").and_return("created_at")
+ user.should_receive("get_updated_at_column").and_return("updated_at")
parent = flexmock(Model())
- parent.should_receive('get_attribute').with_args('id').and_return(1)
- parent.should_receive('get_created_at_column').and_return('created_at')
- parent.should_receive('get_updated_at_column').and_return('updated_at')
- parent.should_receive('new_query').and_return(builder)
+ parent.should_receive("get_attribute").with_args("id").and_return(1)
+ parent.should_receive("get_created_at_column").and_return("created_at")
+ parent.should_receive("get_updated_at_column").and_return("updated_at")
+ parent.should_receive("new_query").and_return(builder)
- return HasManyThrough(builder, country, user, 'country_id', 'user_id')
+ return HasManyThrough(builder, country, user, "country_id", "user_id")
class OrmHasManyThroughModelStub(Model):
diff --git a/tests/orm/relations/test_has_one.py b/tests/orm/relations/test_has_one.py
index 93e21ed8..20fb48a0 100644
--- a/tests/orm/relations/test_has_one.py
+++ b/tests/orm/relations/test_has_one.py
@@ -15,48 +15,53 @@
class OrmHasOneTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_save_method_set_foreign_key_on_model(self):
relation = self._get_relation()
mock_model = flexmock(Model(), save=lambda: True)
- mock_model.should_receive('save').once().and_return(True)
+ mock_model.should_receive("save").once().and_return(True)
result = relation.save(mock_model)
attributes = result.get_attributes()
- self.assertEqual(1, attributes['foreign_key'])
+ self.assertEqual(1, attributes["foreign_key"])
def test_create_properly_creates_new_model(self):
relation = self._get_relation()
created = flexmock(Model(), save=lambda: True, set_attribute=lambda: None)
- created.should_receive('save').once().and_return(True)
- relation.get_related().should_receive('new_instance').once().with_args({'name': 'john'}).and_return(created)
- created.should_receive('set_attribute').with_args('foreign_key', 1)
+ created.should_receive("save").once().and_return(True)
+ relation.get_related().should_receive("new_instance").once().with_args(
+ {"name": "john"}
+ ).and_return(created)
+ created.should_receive("set_attribute").with_args("foreign_key", 1)
- self.assertEqual(created, relation.create(name='john'))
+ self.assertEqual(created, relation.create(name="john"))
def test_update_updates_models_with_timestamps(self):
relation = self._get_relation()
- relation.get_related().should_receive('uses_timestamps').once().and_return(True)
+ relation.get_related().should_receive("uses_timestamps").once().and_return(True)
now = pendulum.now()
- relation.get_related().should_receive('fresh_timestamp').once().and_return(now)
- relation.get_query().should_receive('update').once().with_args({'foo': 'bar', 'updated_at': now}).and_return('results')
+ relation.get_related().should_receive("fresh_timestamp").once().and_return(now)
+ relation.get_query().should_receive("update").once().with_args(
+ {"foo": "bar", "updated_at": now}
+ ).and_return("results")
- self.assertEqual('results', relation.update(foo='bar'))
+ self.assertEqual("results", relation.update(foo="bar"))
def test_relation_is_properly_initialized(self):
relation = self._get_relation()
model = flexmock(Model())
- model.should_receive('set_relation').once().with_args('foo', None)
- models = relation.init_relation([model], 'foo')
+ model.should_receive("set_relation").once().with_args("foo", None)
+ models = relation.init_relation([model], "foo")
self.assertEqual([model], models)
def test_eager_constraints_are_properly_added(self):
relation = self._get_relation()
- relation.get_query().get_query().should_receive('where_in').once().with_args('table.foreign_key', [1, 2])
+ relation.get_query().get_query().should_receive("where_in").once().with_args(
+ "table.foreign_key", [1, 2]
+ )
model1 = OrmHasOneModelStub()
model1.id = 1
@@ -80,10 +85,16 @@ def test_models_are_properly_matched_to_parents(self):
model3 = OrmHasOneModelStub()
model3.id = 3
- relation.get_query().should_receive('where').with_args('table.foreign_key', '=', 2)
- relation.get_query().should_receive('where').with_args('table.foreign_key', '=', 3)
+ relation.get_query().should_receive("where").with_args(
+ "table.foreign_key", "=", 2
+ )
+ relation.get_query().should_receive("where").with_args(
+ "table.foreign_key", "=", 3
+ )
- models = relation.match([model1, model2, model3], Collection([result1, result2]), 'foo')
+ models = relation.match(
+ [model1, model2, model3], Collection([result1, result2]), "foo"
+ )
self.assertEqual(1, models[0].foo.foreign_key)
self.assertEqual(2, models[1].foo.foreign_key)
@@ -93,14 +104,16 @@ def test_relation_count_query_can_be_built(self):
relation = self._get_relation()
query = flexmock(QueryBuilder(None, QueryGrammar(), None))
builder = Builder(query)
- builder.get_query().should_receive('select').once()
- relation.get_parent().should_receive('get_table').and_return('table')
- builder.should_receive('where').once().with_args('table.foreign_key', '=', QueryExpression)
+ builder.get_query().should_receive("select").once()
+ relation.get_parent().should_receive("get_table").and_return("table")
+ builder.should_receive("where").once().with_args(
+ "table.foreign_key", "=", QueryExpression
+ )
parent_query = flexmock(QueryBuilder(None, None, None))
- relation.get_query().should_receive('get_query').and_return(parent_query)
+ relation.get_query().should_receive("get_query").and_return(parent_query)
grammar = flexmock()
- parent_query.should_receive('get_grammar').once().and_return(grammar)
- grammar.should_receive('wrap').once().with_args('table.id')
+ parent_query.should_receive("get_grammar").once().and_return(grammar)
+ grammar.should_receive("wrap").once().with_args("table.id")
relation.get_relation_count_query(builder, builder)
@@ -108,18 +121,18 @@ def _get_relation(self):
flexmock(Builder)
query = flexmock(QueryBuilder(None, QueryGrammar(), None))
builder = Builder(query)
- builder.should_receive('where').with_args('table.foreign_key', '=', 1)
+ builder.should_receive("where").with_args("table.foreign_key", "=", 1)
related = flexmock(Model())
related_query = QueryBuilder(None, QueryGrammar(), None)
- related.should_receive('new_query').and_return(Builder(related_query))
- builder.should_receive('get_model').and_return(related)
+ related.should_receive("new_query").and_return(Builder(related_query))
+ builder.should_receive("get_model").and_return(related)
parent = flexmock(Model())
- parent.should_receive('get_attribute').with_args('id').and_return(1)
- parent.should_receive('get_created_at_column').and_return('created_at')
- parent.should_receive('get_updated_at_column').and_return('updated_at')
- parent.should_receive('new_query').and_return(builder)
+ parent.should_receive("get_attribute").with_args("id").and_return(1)
+ parent.should_receive("get_created_at_column").and_return("created_at")
+ parent.should_receive("get_updated_at_column").and_return("updated_at")
+ parent.should_receive("new_query").and_return(builder)
- return HasOne(builder, parent, 'table.foreign_key', 'id')
+ return HasOne(builder, parent, "table.foreign_key", "id")
class OrmHasOneModelStub(Model):
diff --git a/tests/orm/relations/test_morph.py b/tests/orm/relations/test_morph.py
index 714a2bb2..29b41c02 100644
--- a/tests/orm/relations/test_morph.py
+++ b/tests/orm/relations/test_morph.py
@@ -14,7 +14,6 @@
class OrmMorphTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
@@ -23,9 +22,12 @@ def test_morph_one_sets_proper_constraints(self):
def test_morph_one_eager_constraints_are_properly_added(self):
relation = self._get_one_relation()
- relation.get_query().get_query().should_receive('where_in').once().with_args('table.morph_id', [1, 2])
- relation.get_query().should_receive('where').once()\
- .with_args('table.morph_type', relation.get_parent().__class__.__name__)
+ relation.get_query().get_query().should_receive("where_in").once().with_args(
+ "table.morph_id", [1, 2]
+ )
+ relation.get_query().should_receive("where").once().with_args(
+ "table.morph_type", relation.get_parent().__class__.__name__
+ )
model1 = Model()
model1.id = 1
@@ -38,9 +40,12 @@ def test_morph_many_sets_proper_constraints(self):
def test_morph_many_eager_constraints_are_properly_added(self):
relation = self._get_many_relation()
- relation.get_query().get_query().should_receive('where_in').once().with_args('table.morph_id', [1, 2])
- relation.get_query().should_receive('where').once()\
- .with_args('table.morph_type', relation.get_parent().__class__.__name__)
+ relation.get_query().get_query().should_receive("where_in").once().with_args(
+ "table.morph_id", [1, 2]
+ )
+ relation.get_query().should_receive("where").once().with_args(
+ "table.morph_type", relation.get_parent().__class__.__name__
+ )
model1 = Model()
model1.id = 1
@@ -51,141 +56,190 @@ def test_morph_many_eager_constraints_are_properly_added(self):
def test_create(self):
relation = self._get_one_relation()
created = flexmock(Model())
- created.should_receive('set_attribute').once().with_args('morph_id', 1)
- created.should_receive('set_attribute').once()\
- .with_args('morph_type', relation.get_parent().__class__.__name__)
- relation.get_related().should_receive('new_instance').once().with_args({'name': 'john'}).and_return(created)
- created.should_receive('save').once().and_return(True)
+ created.should_receive("set_attribute").once().with_args("morph_id", 1)
+ created.should_receive("set_attribute").once().with_args(
+ "morph_type", relation.get_parent().__class__.__name__
+ )
+ relation.get_related().should_receive("new_instance").once().with_args(
+ {"name": "john"}
+ ).and_return(created)
+ created.should_receive("save").once().and_return(True)
- self.assertEqual(created, relation.create(name='john'))
+ self.assertEqual(created, relation.create(name="john"))
def test_find_or_new_finds_model(self):
relation = self._get_one_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('find').once().with_args('foo', ['*']).and_return(model)
- relation.get_related().should_receive('new_instance').never()
- model.should_receive('set_attribute').never()
- model.should_receive('save').never()
+ model.foo = "bar"
+ relation.get_query().should_receive("find").once().with_args(
+ "foo", ["*"]
+ ).and_return(model)
+ relation.get_related().should_receive("new_instance").never()
+ model.should_receive("set_attribute").never()
+ model.should_receive("save").never()
- self.assertEqual('bar', relation.find_or_new('foo').foo)
+ self.assertEqual("bar", relation.find_or_new("foo").foo)
def test_find_or_new_returns_new_model_with_morph_keys_set(self):
relation = self._get_one_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('find').once().with_args('foo', ['*']).and_return(None)
- relation.get_related().should_receive('new_instance').once().with_args().and_return(model)
- model.should_receive('set_attribute').once().with_args('morph_id', 1)
- model.should_receive('set_attribute').once().with_args('morph_type', relation.get_parent().__class__.__name__)
- model.should_receive('save').never()
-
- self.assertEqual('bar', relation.find_or_new('foo').foo)
+ model.foo = "bar"
+ relation.get_query().should_receive("find").once().with_args(
+ "foo", ["*"]
+ ).and_return(None)
+ relation.get_related().should_receive(
+ "new_instance"
+ ).once().with_args().and_return(model)
+ model.should_receive("set_attribute").once().with_args("morph_id", 1)
+ model.should_receive("set_attribute").once().with_args(
+ "morph_type", relation.get_parent().__class__.__name__
+ )
+ model.should_receive("save").never()
+
+ self.assertEqual("bar", relation.find_or_new("foo").foo)
def test_first_or_new_returns_first_model(self):
relation = self._get_one_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().with_args().and_return(model)
- relation.get_related().should_receive('new_instance').never()
- model.should_receive('set_attribute').never()
- model.should_receive('set_attribute').never()
- model.should_receive('save').never()
-
- self.assertEqual('bar', relation.first_or_new({'foo': 'bar'}).foo)
+ model.foo = "bar"
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().with_args().and_return(
+ model
+ )
+ relation.get_related().should_receive("new_instance").never()
+ model.should_receive("set_attribute").never()
+ model.should_receive("set_attribute").never()
+ model.should_receive("save").never()
+
+ self.assertEqual("bar", relation.first_or_new({"foo": "bar"}).foo)
def test_first_or_new_returns_new_model_with_morph_keys_set(self):
relation = self._get_one_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().with_args().and_return(None)
- relation.get_related().should_receive('new_instance').once().with_args().and_return(model)
- model.should_receive('set_attribute').once().with_args('morph_id', 1)
- model.should_receive('set_attribute').once().with_args('morph_type', relation.get_parent().__class__.__name__)
- model.should_receive('save').never()
-
- self.assertEqual('bar', relation.first_or_new({'foo': 'bar'}).foo)
+ model.foo = "bar"
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().with_args().and_return(None)
+ relation.get_related().should_receive(
+ "new_instance"
+ ).once().with_args().and_return(model)
+ model.should_receive("set_attribute").once().with_args("morph_id", 1)
+ model.should_receive("set_attribute").once().with_args(
+ "morph_type", relation.get_parent().__class__.__name__
+ )
+ model.should_receive("save").never()
+
+ self.assertEqual("bar", relation.first_or_new({"foo": "bar"}).foo)
def test_first_or_create_returns_first_model(self):
relation = self._get_one_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().with_args().and_return(model)
- relation.get_related().should_receive('new_instance').never()
- model.should_receive('set_attribute').never()
- model.should_receive('set_attribute').never()
- model.should_receive('save').never()
-
- self.assertEqual('bar', relation.first_or_create({'foo': 'bar'}).foo)
+ model.foo = "bar"
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().with_args().and_return(
+ model
+ )
+ relation.get_related().should_receive("new_instance").never()
+ model.should_receive("set_attribute").never()
+ model.should_receive("set_attribute").never()
+ model.should_receive("save").never()
+
+ self.assertEqual("bar", relation.first_or_create({"foo": "bar"}).foo)
def test_first_or_create_creates_new_morph_model(self):
relation = self._get_one_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().with_args().and_return(None)
- relation.get_related().should_receive('new_instance').once().with_args({'foo': 'bar'}).and_return(model)
- model.should_receive('set_attribute').once().with_args('morph_id', 1)
- model.should_receive('set_attribute').once().with_args('morph_type', relation.get_parent().__class__.__name__)
- model.should_receive('save').once().and_return()
-
- self.assertEqual('bar', relation.first_or_create({'foo': 'bar'}).foo)
+ model.foo = "bar"
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().with_args().and_return(None)
+ relation.get_related().should_receive("new_instance").once().with_args(
+ {"foo": "bar"}
+ ).and_return(model)
+ model.should_receive("set_attribute").once().with_args("morph_id", 1)
+ model.should_receive("set_attribute").once().with_args(
+ "morph_type", relation.get_parent().__class__.__name__
+ )
+ model.should_receive("save").once().and_return()
+
+ self.assertEqual("bar", relation.first_or_create({"foo": "bar"}).foo)
def test_update_or_create_finds_first_model_and_updates(self):
relation = self._get_one_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().with_args().and_return(model)
- relation.get_related().should_receive('new_instance').never()
- model.should_receive('set_attribute').never()
- model.should_receive('set_attribute').never()
- model.should_receive('fill').once().with_args({'bar': 'baz'})
- model.should_receive('save').once().and_return(True)
-
- self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'bar': 'baz'}).foo)
+ model.foo = "bar"
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().with_args().and_return(
+ model
+ )
+ relation.get_related().should_receive("new_instance").never()
+ model.should_receive("set_attribute").never()
+ model.should_receive("set_attribute").never()
+ model.should_receive("fill").once().with_args({"bar": "baz"})
+ model.should_receive("save").once().and_return(True)
+
+ self.assertEqual(
+ "bar", relation.update_or_create({"foo": "bar"}, {"bar": "baz"}).foo
+ )
def test_update_or_create_finds_creates_new_morph_model(self):
relation = self._get_one_relation()
model = flexmock()
- model.foo = 'bar'
- relation.get_query().should_receive('where').once().with_args({'foo': 'bar'}).and_return(relation.get_query())
- relation.get_query().should_receive('first').once().with_args().and_return(None)
- relation.get_related().should_receive('new_instance').once().with_args().and_return(model)
- model.should_receive('set_attribute').once().with_args('morph_id', 1)
- model.should_receive('set_attribute').once().with_args('morph_type', relation.get_parent().__class__.__name__)
- model.should_receive('fill').once().with_args({'bar': 'baz'})
- model.should_receive('save').once().and_return(True)
-
- self.assertEqual('bar', relation.update_or_create({'foo': 'bar'}, {'bar': 'baz'}).foo)
+ model.foo = "bar"
+ relation.get_query().should_receive("where").once().with_args(
+ {"foo": "bar"}
+ ).and_return(relation.get_query())
+ relation.get_query().should_receive("first").once().with_args().and_return(None)
+ relation.get_related().should_receive(
+ "new_instance"
+ ).once().with_args().and_return(model)
+ model.should_receive("set_attribute").once().with_args("morph_id", 1)
+ model.should_receive("set_attribute").once().with_args(
+ "morph_type", relation.get_parent().__class__.__name__
+ )
+ model.should_receive("fill").once().with_args({"bar": "baz"})
+ model.should_receive("save").once().and_return(True)
+
+ self.assertEqual(
+ "bar", relation.update_or_create({"foo": "bar"}, {"bar": "baz"}).foo
+ )
def _get_many_relation(self):
flexmock(Builder)
query = flexmock(QueryBuilder(None, QueryGrammar(), None))
builder = Builder(query)
- builder.should_receive('where').with_args('table.morph_id', '=', 1)
+ builder.should_receive("where").with_args("table.morph_id", "=", 1)
related = flexmock(Model())
- builder.should_receive('get_model').and_return(related)
+ builder.should_receive("get_model").and_return(related)
parent = flexmock(Model())
- parent.should_receive('get_attribute').with_args('id').and_return(1)
- parent.should_receive('get_morph_name').and_return(parent.__class__.__name__)
- builder.should_receive('where').once().with_args('table.morph_type', parent.__class__.__name__)
+ parent.should_receive("get_attribute").with_args("id").and_return(1)
+ parent.should_receive("get_morph_name").and_return(parent.__class__.__name__)
+ builder.should_receive("where").once().with_args(
+ "table.morph_type", parent.__class__.__name__
+ )
- return MorphMany(builder, parent, 'table.morph_type', 'table.morph_id', 'id')
+ return MorphMany(builder, parent, "table.morph_type", "table.morph_id", "id")
def _get_one_relation(self):
flexmock(Builder)
query = flexmock(QueryBuilder(None, QueryGrammar(), None))
builder = Builder(query)
- builder.should_receive('where').with_args('table.morph_id', '=', 1)
+ builder.should_receive("where").with_args("table.morph_id", "=", 1)
related = flexmock(Model())
- builder.should_receive('get_model').and_return(related)
+ builder.should_receive("get_model").and_return(related)
parent = flexmock(Model())
- parent.should_receive('get_attribute').with_args('id').and_return(1)
- parent.should_receive('get_morph_name').and_return(parent.__class__.__name__)
- builder.should_receive('where').once().with_args('table.morph_type', parent.__class__.__name__)
+ parent.should_receive("get_attribute").with_args("id").and_return(1)
+ parent.should_receive("get_morph_name").and_return(parent.__class__.__name__)
+ builder.should_receive("where").once().with_args(
+ "table.morph_type", parent.__class__.__name__
+ )
- return MorphOne(builder, parent, 'table.morph_type', 'table.morph_id', 'id')
+ return MorphOne(builder, parent, "table.morph_type", "table.morph_id", "id")
diff --git a/tests/orm/relations/test_morph_to.py b/tests/orm/relations/test_morph_to.py
index d761b2ea..cbe12b32 100644
--- a/tests/orm/relations/test_morph_to.py
+++ b/tests/orm/relations/test_morph_to.py
@@ -14,7 +14,6 @@
class OrmMorphToTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
@@ -22,73 +21,79 @@ def test_lookup_dictionary_is_properly_constructed(self):
relation = self._get_relation()
one = flexmock()
- one.morph_type = 'morph_type_1'
- one.foreign_key = 'foreign_key_1'
+ one.morph_type = "morph_type_1"
+ one.foreign_key = "foreign_key_1"
two = flexmock()
- two.morph_type = 'morph_type_1'
- two.foreign_key = 'foreign_key_1'
+ two.morph_type = "morph_type_1"
+ two.foreign_key = "foreign_key_1"
three = flexmock()
- three.morph_type = 'morph_type_2'
- three.foreign_key = 'foreign_key_2'
+ three.morph_type = "morph_type_2"
+ three.foreign_key = "foreign_key_2"
relation.add_eager_constraints([one, two, three])
dictionary = relation.get_dictionary()
- self.assertEqual({
- 'morph_type_1': {
- 'foreign_key_1': [
- one,
- two
- ]
+ self.assertEqual(
+ {
+ "morph_type_1": {"foreign_key_1": [one, two]},
+ "morph_type_2": {"foreign_key_2": [three]},
},
- 'morph_type_2': {
- 'foreign_key_2': [three]
- }
- }, dictionary)
+ dictionary,
+ )
def test_models_are_properly_pulled_and_matched(self):
relation = self._get_relation()
one = flexmock(Model())
- one.morph_type = 'morph_type_1'
- one.foreign_key = 'foreign_key_1'
+ one.morph_type = "morph_type_1"
+ one.foreign_key = "foreign_key_1"
two = flexmock(Model())
- two.morph_type = 'morph_type_1'
- two.foreign_key = 'foreign_key_1'
+ two.morph_type = "morph_type_1"
+ two.foreign_key = "foreign_key_1"
three = flexmock(Model())
- three.morph_type = 'morph_type_2'
- three.foreign_key = 'foreign_key_2'
+ three.morph_type = "morph_type_2"
+ three.foreign_key = "foreign_key_2"
relation.add_eager_constraints([one, two, three])
- first_query = flexmock(Builder(flexmock(QueryBuilder(None, QueryGrammar(), None))))
- second_query = flexmock(Builder(flexmock(QueryBuilder(None, QueryGrammar(), None))))
+ first_query = flexmock(
+ Builder(flexmock(QueryBuilder(None, QueryGrammar(), None)))
+ )
+ second_query = flexmock(
+ Builder(flexmock(QueryBuilder(None, QueryGrammar(), None)))
+ )
first_model = flexmock(Model())
second_model = flexmock(Model())
- relation.should_receive('_create_model_by_type').once().with_args('morph_type_1').and_return(first_model)
- relation.should_receive('_create_model_by_type').once().with_args('morph_type_2').and_return(second_model)
- first_model.should_receive('get_key_name').and_return('id')
- second_model.should_receive('get_key_name').and_return('id')
-
- first_model.should_receive('new_query').once().and_return(first_query)
- second_model.should_receive('new_query').once().and_return(second_query)
-
- first_query.get_query().should_receive('where_in').once()\
- .with_args('id', ['foreign_key_1']).and_return(first_query)
+ relation.should_receive("_create_model_by_type").once().with_args(
+ "morph_type_1"
+ ).and_return(first_model)
+ relation.should_receive("_create_model_by_type").once().with_args(
+ "morph_type_2"
+ ).and_return(second_model)
+ first_model.should_receive("get_key_name").and_return("id")
+ second_model.should_receive("get_key_name").and_return("id")
+
+ first_model.should_receive("new_query").once().and_return(first_query)
+ second_model.should_receive("new_query").once().and_return(second_query)
+
+ first_query.get_query().should_receive("where_in").once().with_args(
+ "id", ["foreign_key_1"]
+ ).and_return(first_query)
result_one = flexmock()
- first_query.should_receive('get').and_return(Collection.make([result_one]))
- result_one.should_receive('get_key').and_return('foreign_key_1')
+ first_query.should_receive("get").and_return(Collection.make([result_one]))
+ result_one.should_receive("get_key").and_return("foreign_key_1")
- second_query.get_query().should_receive('where_in').once()\
- .with_args('id', ['foreign_key_2']).and_return(second_query)
+ second_query.get_query().should_receive("where_in").once().with_args(
+ "id", ["foreign_key_2"]
+ ).and_return(second_query)
result_two = flexmock()
- second_query.should_receive('get').and_return(Collection.make([result_two]))
- result_two.should_receive('get_key').and_return('foreign_key_2')
+ second_query.should_receive("get").and_return(Collection.make([result_two]))
+ result_two.should_receive("get_key").and_return("foreign_key_2")
- one.should_receive('set_relation').once().with_args('relation', result_one)
- two.should_receive('set_relation').once().with_args('relation', result_one)
- three.should_receive('set_relation').once().with_args('relation', result_two)
+ one.should_receive("set_relation").once().with_args("relation", result_one)
+ two.should_receive("set_relation").once().with_args("relation", result_one)
+ three.should_receive("set_relation").once().with_args("relation", result_two)
relation.get_eager()
@@ -96,17 +101,19 @@ def test_models_are_properly_pulled_and_matched(self):
def test_associate_sets_foreign_key_and_type_on_model(self):
parent = flexmock(Model())
- parent.should_receive('get_attribute').once().with_args('foreign_key').and_return('foreign.value')
+ parent.should_receive("get_attribute").once().with_args(
+ "foreign_key"
+ ).and_return("foreign.value")
relation = self._get_relation_associate(parent)
associate = flexmock(Model())
- associate.should_receive('get_key').once().and_return(1)
- associate.should_receive('get_morph_name').once().and_return('Model')
+ associate.should_receive("get_key").once().and_return(1)
+ associate.should_receive("get_morph_name").once().and_return("Model")
- parent.should_receive('set_attribute').once().with_args('foreign_key', 1)
- parent.should_receive('set_attribute').once().with_args('morph_type', 'Model')
- parent.should_receive('set_relation').once().with_args('relation', associate)
+ parent.should_receive("set_attribute").once().with_args("foreign_key", 1)
+ parent.should_receive("set_attribute").once().with_args("morph_type", "Model")
+ parent.should_receive("set_relation").once().with_args("relation", associate)
relation.associate(associate)
@@ -114,30 +121,30 @@ def _get_relation_associate(self, parent):
flexmock(Builder)
query = flexmock(QueryBuilder(None, QueryGrammar(), None))
builder = Builder(query)
- builder.should_receive('where').with_args('relation.id', '=', 'foreign.value')
+ builder.should_receive("where").with_args("relation.id", "=", "foreign.value")
related = flexmock(Model())
- related.should_receive('get_key').and_return(1)
- related.should_receive('get_table').and_return('relation')
- builder.should_receive('get_model').and_return(related)
+ related.should_receive("get_key").and_return(1)
+ related.should_receive("get_table").and_return("relation")
+ builder.should_receive("get_model").and_return(related)
- return MorphTo(builder, parent, 'foreign_key', 'id', 'morph_type', 'relation')
+ return MorphTo(builder, parent, "foreign_key", "id", "morph_type", "relation")
def _get_relation(self, parent=None, builder=None):
flexmock(Builder)
query = flexmock(QueryBuilder(None, QueryGrammar(), None))
builder = builder or Builder(query)
- builder.should_receive('where').with_args('relation.id', '=', 'foreign.value')
+ builder.should_receive("where").with_args("relation.id", "=", "foreign.value")
related = flexmock(Model())
- related.should_receive('get_key').and_return(1)
- related.should_receive('get_table').and_return('relation')
- builder.should_receive('get_model').and_return(related)
+ related.should_receive("get_key").and_return(1)
+ related.should_receive("get_table").and_return("relation")
+ builder.should_receive("get_model").and_return(related)
parent = parent or OrmMorphToModelStub()
flexmock(MorphTo)
- return MorphTo(builder, parent, 'foreign_key', 'id', 'morph_type', 'relation')
+ return MorphTo(builder, parent, "foreign_key", "id", "morph_type", "relation")
class OrmMorphToModelStub(Model):
- foreign_key = 'foreign.value'
+ foreign_key = "foreign.value"
diff --git a/tests/orm/relations/test_morph_to_many.py b/tests/orm/relations/test_morph_to_many.py
index 62e54add..609ce528 100644
--- a/tests/orm/relations/test_morph_to_many.py
+++ b/tests/orm/relations/test_morph_to_many.py
@@ -18,15 +18,17 @@
class OrmMorphToManyTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_eager_constraints_are_properly_added(self):
relation = self._get_relation()
- relation.get_query().get_query().should_receive('where_in').once().with_args('taggables.taggable_id', [1, 2])
- relation.get_query().should_receive('where').once()\
- .with_args('taggables.taggable_type', relation.get_parent().__class__.__name__)
+ relation.get_query().get_query().should_receive("where_in").once().with_args(
+ "taggables.taggable_id", [1, 2]
+ )
+ relation.get_query().should_receive("where").once().with_args(
+ "taggables.taggable_type", relation.get_parent().__class__.__name__
+ )
model1 = OrmMorphToManyModelStub()
model1.id = 1
model2 = OrmMorphToManyModelStub()
@@ -38,37 +40,41 @@ def test_attach_inserts_pivot_table_record(self):
flexmock(MorphToMany, touch_if_touching=lambda: True)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('taggables').and_return(query)
- query.should_receive('insert').once()\
- .with_args(
- [{
- 'taggable_id': 1,
- 'taggable_type': relation.get_parent().__class__.__name__,
- 'tag_id': 2,
- 'foo': 'bar',
- }])\
- .and_return(True)
+ query.should_receive("from_").once().with_args("taggables").and_return(query)
+ query.should_receive("insert").once().with_args(
+ [
+ {
+ "taggable_id": 1,
+ "taggable_type": relation.get_parent().__class__.__name__,
+ "tag_id": 2,
+ "foo": "bar",
+ }
+ ]
+ ).and_return(True)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- relation.should_receive('touch_if_touching').once()
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ relation.should_receive("touch_if_touching").once()
- relation.attach(2, {'foo': 'bar'})
+ relation.attach(2, {"foo": "bar"})
def test_detach_remove_pivot_table_record(self):
flexmock(MorphToMany, touch_if_touching=lambda: True)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('taggables').and_return(query)
- query.should_receive('where').once().with_args('taggable_id', 1).and_return(query)
- query.should_receive('where').once()\
- .with_args('taggable_type', relation.get_parent().__class__.__name__).and_return(query)
- query.should_receive('where_in').once().with_args('tag_id', [1, 2, 3])
- query.should_receive('delete').once().and_return(True)
+ query.should_receive("from_").once().with_args("taggables").and_return(query)
+ query.should_receive("where").once().with_args("taggable_id", 1).and_return(
+ query
+ )
+ query.should_receive("where").once().with_args(
+ "taggable_type", relation.get_parent().__class__.__name__
+ ).and_return(query)
+ query.should_receive("where_in").once().with_args("tag_id", [1, 2, 3])
+ query.should_receive("delete").once().and_return(True)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- relation.should_receive('touch_if_touching').once()
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ relation.should_receive("touch_if_touching").once()
self.assertTrue(relation.detach([1, 2, 3]))
@@ -76,48 +82,72 @@ def test_detach_clears_all_records_when_no_ids(self):
flexmock(MorphToMany, touch_if_touching=lambda: True)
relation = self._get_relation()
query = flexmock()
- query.should_receive('from_').once().with_args('taggables').and_return(query)
- query.should_receive('where').once().with_args('taggable_id', 1).and_return(query)
- query.should_receive('where').once()\
- .with_args('taggable_type', relation.get_parent().__class__.__name__).and_return(query)
- query.should_receive('where_in').never()
- query.should_receive('delete').once().and_return(True)
+ query.should_receive("from_").once().with_args("taggables").and_return(query)
+ query.should_receive("where").once().with_args("taggable_id", 1).and_return(
+ query
+ )
+ query.should_receive("where").once().with_args(
+ "taggable_type", relation.get_parent().__class__.__name__
+ ).and_return(query)
+ query.should_receive("where_in").never()
+ query.should_receive("delete").once().and_return(True)
mock_query_builder = flexmock()
- relation.get_query().should_receive('get_query').and_return(mock_query_builder)
- mock_query_builder.should_receive('new_query').once().and_return(query)
- relation.should_receive('touch_if_touching').once()
+ relation.get_query().should_receive("get_query").and_return(mock_query_builder)
+ mock_query_builder.should_receive("new_query").once().and_return(query)
+ relation.should_receive("touch_if_touching").once()
self.assertTrue(relation.detach())
def _get_relation(self):
builder, parent = self._get_relation_arguments()[:2]
- return MorphToMany(builder, parent, 'taggable', 'taggables', 'taggable_id', 'tag_id')
+ return MorphToMany(
+ builder, parent, "taggable", "taggables", "taggable_id", "tag_id"
+ )
def _get_relation_arguments(self):
parent = flexmock(Model())
- parent.should_receive('get_morph_name').and_return(parent.__class__.__name__)
- parent.should_receive('get_key').and_return(1)
- parent.should_receive('get_created_at_column').and_return('created_at')
- parent.should_receive('get_updated_at_column').and_return('updated_at')
-
- query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()))
+ parent.should_receive("get_morph_name").and_return(parent.__class__.__name__)
+ parent.should_receive("get_key").and_return(1)
+ parent.should_receive("get_created_at_column").and_return("created_at")
+ parent.should_receive("get_updated_at_column").and_return("updated_at")
+
+ query = flexmock(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
flexmock(Builder)
builder = Builder(query)
- builder.should_receive('get_query').and_return(query)
+ builder.should_receive("get_query").and_return(query)
related = flexmock(Model())
builder.set_model(related)
- builder.should_receive('get_model').and_return(related)
-
- related.should_receive('get_key_name').and_return('id')
- related.should_receive('get_table').and_return('tags')
- related.should_receive('get_morph_name').and_return(parent.__class__.__name__)
-
- builder.get_query().should_receive('join').once().with_args('taggables', 'tags.id', '=', 'taggables.tag_id')
- builder.should_receive('where').once().with_args('taggables.taggable_id', '=', 1)
- builder.should_receive('where').once().with_args('taggables.taggable_type', parent.__class__.__name__)
-
- return builder, parent, 'taggable', 'taggables', 'taggable_id', 'tag_id', 'relation_name', False
+ builder.should_receive("get_model").and_return(related)
+
+ related.should_receive("get_key_name").and_return("id")
+ related.should_receive("get_table").and_return("tags")
+ related.should_receive("get_morph_name").and_return(parent.__class__.__name__)
+
+ builder.get_query().should_receive("join").once().with_args(
+ "taggables", "tags.id", "=", "taggables.tag_id"
+ )
+ builder.should_receive("where").once().with_args(
+ "taggables.taggable_id", "=", 1
+ )
+ builder.should_receive("where").once().with_args(
+ "taggables.taggable_type", parent.__class__.__name__
+ )
+
+ return (
+ builder,
+ parent,
+ "taggable",
+ "taggables",
+ "taggable_id",
+ "tag_id",
+ "relation_name",
+ False,
+ )
class OrmMorphToManyModelStub(Model):
diff --git a/tests/orm/relations/test_relation.py b/tests/orm/relations/test_relation.py
index 2778cf12..7188eddf 100644
--- a/tests/orm/relations/test_relation.py
+++ b/tests/orm/relations/test_relation.py
@@ -12,37 +12,37 @@
class OrmRelationTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_set_relation_fail(self):
parent = OrmRelationResetModelStub()
relation = OrmRelationResetModelStub()
- parent.set_relation('test', relation)
- parent.set_relation('foo', 'bar')
- self.assertFalse('foo' in parent.to_dict())
+ parent.set_relation("test", relation)
+ parent.set_relation("foo", "bar")
+ self.assertFalse("foo" in parent.to_dict())
def test_touch_method_updates_related_timestamps(self):
builder = flexmock(Builder, get_model=None, where=None)
parent = Model()
parent = flexmock(parent)
- parent.should_receive('get_attribute').with_args('id').and_return(1)
+ parent.should_receive("get_attribute").with_args("id").and_return(1)
related = Model()
related = flexmock(related)
- builder.should_receive('get_model').and_return(related)
- builder.should_receive('where')
- relation = HasOne(Builder(QueryBuilder(None, None, None)), parent, 'foreign_key', 'id')
- related.should_receive('get_table').and_return('table')
- related.should_receive('get_updated_at_column').and_return('updated_at')
+ builder.should_receive("get_model").and_return(related)
+ builder.should_receive("where")
+ relation = HasOne(
+ Builder(QueryBuilder(None, None, None)), parent, "foreign_key", "id"
+ )
+ related.should_receive("get_table").and_return("table")
+ related.should_receive("get_updated_at_column").and_return("updated_at")
now = pendulum.now()
- related.should_receive('fresh_timestamp').and_return(now)
- builder.should_receive('update').once().with_args({'updated_at': now})
+ related.should_receive("fresh_timestamp").and_return(now)
+ builder.should_receive("update").once().with_args({"updated_at": now})
relation.touch()
class OrmRelationResetModelStub(Model):
-
def get_query(self):
return self.new_query().get_query()
diff --git a/tests/orm/scopes/__init__.py b/tests/orm/scopes/__init__.py
index 633f8661..40a96afc 100644
--- a/tests/orm/scopes/__init__.py
+++ b/tests/orm/scopes/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/tests/orm/scopes/test_soft_deleting.py b/tests/orm/scopes/test_soft_deleting.py
index 12468524..220e9924 100644
--- a/tests/orm/scopes/test_soft_deleting.py
+++ b/tests/orm/scopes/test_soft_deleting.py
@@ -8,7 +8,6 @@
class SoftDeletingScopeTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
@@ -17,8 +16,12 @@ def test_apply_scope_to_a_builder(self):
query = flexmock(QueryBuilder(None, None, None))
builder = Builder(query)
model = flexmock(ModelStub())
- model.should_receive('get_qualified_deleted_at_column').once().and_return('table.deleted_at')
- builder.get_query().should_receive('where_null').once().with_args('table.deleted_at')
+ model.should_receive("get_qualified_deleted_at_column").once().and_return(
+ "table.deleted_at"
+ )
+ builder.get_query().should_receive("where_null").once().with_args(
+ "table.deleted_at"
+ )
scope.apply(builder, model)
@@ -26,10 +29,10 @@ def test_force_delete_extension(self):
scope = SoftDeletingScope()
builder = Builder(None)
scope.extend(builder)
- callback = builder.get_macro('force_delete')
+ callback = builder.get_macro("force_delete")
query = flexmock(QueryBuilder(None, None, None))
given_builder = Builder(query)
- query.should_receive('delete').once()
+ query.should_receive("delete").once()
callback(given_builder)
@@ -37,13 +40,13 @@ def test_restore_extension(self):
scope = SoftDeletingScope()
builder = Builder(None)
scope.extend(builder)
- callback = builder.get_macro('restore')
+ callback = builder.get_macro("restore")
query = flexmock(QueryBuilder(None, None, None))
builder_mock = flexmock(BuilderWithTrashedStub)
given_builder = BuilderWithTrashedStub(query)
- builder_mock.should_receive('with_trashed').once()
- builder_mock.should_receive('get_model').once().and_return(ModelStub())
- builder_mock.should_receive('update').once().with_args({'deleted_at': None})
+ builder_mock.should_receive("with_trashed").once()
+ builder_mock.should_receive("get_model").once().and_return(ModelStub())
+ builder_mock.should_receive("update").once().with_args({"deleted_at": None})
callback(given_builder)
@@ -51,7 +54,7 @@ def test_with_trashed_extension(self):
scope = flexmock(SoftDeletingScope())
builder = Builder(None)
scope.extend(builder)
- callback = builder.get_macro('with_trashed')
+ callback = builder.get_macro("with_trashed")
query = flexmock(QueryBuilder(None, None, None))
given_builder = Builder(query)
model = flexmock(ModelStub())
@@ -65,15 +68,13 @@ def test_with_trashed_extension(self):
class ModelStub(Model):
-
def get_qualified_deleted_at_column(self):
- return 'table.deleted_at'
+ return "table.deleted_at"
def get_deleted_at_column(self):
- return 'deleted_at'
+ return "deleted_at"
class BuilderWithTrashedStub(Builder):
-
def with_trashed(self):
pass
diff --git a/tests/orm/test_builder.py b/tests/orm/test_builder.py
index b8975d7f..6e9c6556 100644
--- a/tests/orm/test_builder.py
+++ b/tests/orm/test_builder.py
@@ -15,7 +15,6 @@
class BuilderTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
@@ -23,30 +22,26 @@ def test_find_method(self):
builder = Builder(self.get_mock_query_builder())
builder.set_model(self.get_mock_model())
builder.get_query().where = mock.MagicMock()
- builder.first = mock.MagicMock(return_value='baz')
+ builder.first = mock.MagicMock(return_value="baz")
- result = builder.find('bar', ['column'])
+ result = builder.find("bar", ["column"])
- builder.get_query().where.assert_called_once_with(
- 'foo_table.foo', '=', 'bar'
- )
- self.assertEqual('baz', result)
+ builder.get_query().where.assert_called_once_with("foo_table.foo", "=", "bar")
+ self.assertEqual("baz", result)
def test_find_or_new_model_found(self):
model = self.get_mock_model()
- model.find_or_new = mock.MagicMock(return_value='baz')
+ model.find_or_new = mock.MagicMock(return_value="baz")
builder = Builder(self.get_mock_query_builder())
builder.set_model(model)
builder.get_query().where = mock.MagicMock()
- builder.first = mock.MagicMock(return_value='baz')
+ builder.first = mock.MagicMock(return_value="baz")
- expected = model.find_or_new('bar', ['column'])
- result = builder.find('bar', ['column'])
+ expected = model.find_or_new("bar", ["column"])
+ result = builder.find("bar", ["column"])
- builder.get_query().where.assert_called_once_with(
- 'foo_table.foo', '=', 'bar'
- )
+ builder.get_query().where.assert_called_once_with("foo_table.foo", "=", "bar")
self.assertEqual(expected, result)
def test_find_or_new_model_not_found(self):
@@ -58,12 +53,10 @@ def test_find_or_new_model_not_found(self):
builder.get_query().where = mock.MagicMock()
builder.first = mock.MagicMock(return_value=None)
- result = model.find_or_new('bar', ['column'])
- find_result = builder.find('bar', ['column'])
+ result = model.find_or_new("bar", ["column"])
+ find_result = builder.find("bar", ["column"])
- builder.get_query().where.assert_called_once_with(
- 'foo_table.foo', '=', 'bar'
- )
+ builder.get_query().where.assert_called_once_with("foo_table.foo", "=", "bar")
self.assertIsNone(find_result)
self.assertIsInstance(result, Model)
@@ -75,20 +68,11 @@ def test_find_or_fail_raises_model_not_found_exception(self):
builder.get_query().where = mock.MagicMock()
builder.first = mock.MagicMock(return_value=None)
- self.assertRaises(
- ModelNotFound,
- builder.find_or_fail,
- 'bar',
- ['column']
- )
+ self.assertRaises(ModelNotFound, builder.find_or_fail, "bar", ["column"])
- builder.get_query().where.assert_called_once_with(
- 'foo_table.foo', '=', 'bar'
- )
+ builder.get_query().where.assert_called_once_with("foo_table.foo", "=", "bar")
- builder.first.assert_called_once_with(
- ['column']
- )
+ builder.first.assert_called_once_with(["column"])
def test_find_or_fail_with_many_raises_model_not_found_exception(self):
model = self.get_mock_model()
@@ -98,20 +82,11 @@ def test_find_or_fail_with_many_raises_model_not_found_exception(self):
builder.get_query().where_in = mock.MagicMock()
builder.get = mock.MagicMock(return_value=Collection([1]))
- self.assertRaises(
- ModelNotFound,
- builder.find_or_fail,
- [1, 2],
- ['column']
- )
+ self.assertRaises(ModelNotFound, builder.find_or_fail, [1, 2], ["column"])
- builder.get_query().where_in.assert_called_once_with(
- 'foo_table.foo', [1, 2]
- )
+ builder.get_query().where_in.assert_called_once_with("foo_table.foo", [1, 2])
- builder.get.assert_called_once_with(
- ['column']
- )
+ builder.get.assert_called_once_with(["column"])
def test_first_or_fail_raises_model_not_found_exception(self):
model = self.get_mock_model()
@@ -120,15 +95,9 @@ def test_first_or_fail_raises_model_not_found_exception(self):
builder.set_model(model)
builder.first = mock.MagicMock(return_value=None)
- self.assertRaises(
- ModelNotFound,
- builder.first_or_fail,
- ['column']
- )
+ self.assertRaises(ModelNotFound, builder.first_or_fail, ["column"])
- builder.first.assert_called_once_with(
- ['column']
- )
+ builder.first.assert_called_once_with(["column"])
def test_find_with_many(self):
model = self.get_mock_model()
@@ -136,18 +105,14 @@ def test_find_with_many(self):
builder = Builder(self.get_mock_query_builder())
builder.set_model(model)
builder.get_query().where_in = mock.MagicMock()
- builder.get = mock.MagicMock(return_value='baz')
+ builder.get = mock.MagicMock(return_value="baz")
- result = builder.find([1, 2], ['column'])
- self.assertEqual('baz', result)
+ result = builder.find([1, 2], ["column"])
+ self.assertEqual("baz", result)
- builder.get_query().where_in.assert_called_once_with(
- 'foo_table.foo', [1, 2]
- )
+ builder.get_query().where_in.assert_called_once_with("foo_table.foo", [1, 2])
- builder.get.assert_called_once_with(
- ['column']
- )
+ builder.get.assert_called_once_with(["column"])
def test_first(self):
model = self.get_mock_model()
@@ -155,41 +120,41 @@ def test_first(self):
builder = Builder(self.get_mock_query_builder())
builder.set_model(model)
builder.take = mock.MagicMock(return_value=builder)
- builder.get = mock.MagicMock(return_value=Collection(['bar']))
+ builder.get = mock.MagicMock(return_value=Collection(["bar"]))
result = builder.first()
- self.assertEqual('bar', result)
+ self.assertEqual("bar", result)
- builder.take.assert_called_once_with(
- 1
- )
+ builder.take.assert_called_once_with(1)
- builder.get.assert_called_once_with(
- ['*']
- )
+ builder.get.assert_called_once_with(["*"])
def test_get_loads_models_and_hydrates_eager_relations(self):
flexmock(Builder)
builder = Builder(self.get_mock_query_builder())
- builder.should_receive('get_models').with_args(['foo']).and_return(['bar'])
- builder.should_receive('eager_load_relations').with_args(['bar']).and_return(['bar', 'baz'])
+ builder.should_receive("get_models").with_args(["foo"]).and_return(["bar"])
+ builder.should_receive("eager_load_relations").with_args(["bar"]).and_return(
+ ["bar", "baz"]
+ )
builder.set_model(self.get_mock_model())
- builder.get_model().new_collection = mock.MagicMock(return_value=Collection(['bar', 'baz']))
+ builder.get_model().new_collection = mock.MagicMock(
+ return_value=Collection(["bar", "baz"])
+ )
- results = builder.get(['foo'])
- self.assertEqual(['bar', 'baz'], results.all())
+ results = builder.get(["foo"])
+ self.assertEqual(["bar", "baz"], results.all())
- builder.get_model().new_collection.assert_called_with(['bar', 'baz'])
+ builder.get_model().new_collection.assert_called_with(["bar", "baz"])
def test_get_does_not_eager_relations_when_no_results_are_returned(self):
flexmock(Builder)
builder = Builder(self.get_mock_query_builder())
- builder.should_receive('get_models').with_args(['foo']).and_return(['bar'])
- builder.should_receive('eager_load_relations').with_args(['bar']).and_return([])
+ builder.should_receive("get_models").with_args(["foo"]).and_return(["bar"])
+ builder.should_receive("eager_load_relations").with_args(["bar"]).and_return([])
builder.set_model(self.get_mock_model())
builder.get_model().new_collection = mock.MagicMock(return_value=Collection([]))
- results = builder.get(['foo'])
+ results = builder.get(["foo"])
self.assertEqual([], results.all())
builder.get_model().new_collection.assert_called_with([])
@@ -197,27 +162,34 @@ def test_get_does_not_eager_relations_when_no_results_are_returned(self):
def test_pluck_with_model_found(self):
builder = Builder(self.get_mock_query_builder())
- model = {'name': 'foo'}
+ model = {"name": "foo"}
builder.first = mock.MagicMock(return_value=model)
- self.assertEqual('foo', builder.pluck('name'))
+ self.assertEqual("foo", builder.pluck("name"))
- builder.first.assert_called_once_with(
- ['name']
- )
+ builder.first.assert_called_once_with(["name"])
def test_pluck_with_model_not_found(self):
builder = Builder(self.get_mock_query_builder())
builder.first = mock.MagicMock(return_value=None)
- self.assertIsNone(builder.pluck('name'))
+ self.assertIsNone(builder.pluck("name"))
def test_chunk(self):
- builder = Builder(self.get_mock_query_builder())
- results = [Collection(['foo1', 'foo2']), Collection(['foo3']), Collection([])]
- builder.for_page = mock.MagicMock(return_value=builder)
- builder.get = mock.MagicMock(side_effect=results)
+ query_builder = self.get_mock_query_builder()
+ query_results = [["foo1", "foo2"], ["foo3"]]
+ query_builder.chunk = mock.MagicMock(return_value=query_results)
+
+ builder = Builder(query_builder)
+ model = self.get_mock_model()
+ builder.set_model(model)
+
+ results = [Collection(["foo1", "foo2"]), Collection(["foo3"])]
+
+ model.hydrate = mock.MagicMock(return_value=[])
+ model.new_collection = mock.MagicMock(side_effect=results)
+ model.get_connection_name = mock.MagicMock(return_value="foo")
i = 0
for result in builder.chunk(2):
@@ -225,61 +197,59 @@ def test_chunk(self):
i += 1
- builder.for_page.assert_has_calls([
- mock.call(1, 2),
- mock.call(2, 2),
- mock.call(3, 2)
- ])
+ self.assertEqual(i, 2)
+
+ query_builder.chunk.assert_has_calls([mock.call(2)])
+ model.hydrate.assert_has_calls(
+ [mock.call(["foo1", "foo2"], "foo"), mock.call(["foo3"], "foo")]
+ )
+ model.new_collection.assert_has_calls([mock.call([]), mock.call([])])
# TODO: lists with get mutators
def test_lists_without_model_getters(self):
builder = self.get_builder()
- builder.get_query().get = mock.MagicMock(return_value=[{'name': 'bar'}, {'name': 'baz'}])
+ builder.get_query().get = mock.MagicMock(
+ return_value=[{"name": "bar"}, {"name": "baz"}]
+ )
builder.set_model(self.get_mock_model())
builder.get_model().has_get_mutator = mock.MagicMock(return_value=False)
- result = builder.lists('name')
- self.assertEqual(['bar', 'baz'], result)
+ result = builder.lists("name")
+ self.assertEqual(["bar", "baz"], result)
- builder.get_query().get.assert_called_once_with(['name'])
+ builder.get_query().get.assert_called_once_with(["name"])
def test_get_models_hydrates_models(self):
builder = Builder(self.get_mock_query_builder())
- records = Collection([{
- 'name': 'john', 'age': 26
- }, {
- 'name': 'jane', 'age': 28
- }])
+ records = Collection([{"name": "john", "age": 26}, {"name": "jane", "age": 28}])
builder.get_query().get = mock.MagicMock(return_value=records)
model = self.get_mock_model()
builder.set_model(model)
- model.get_connection_name = mock.MagicMock(return_value='foo_connection')
- model.hydrate = mock.MagicMock(return_value=Collection(['hydrated']))
- models = builder.get_models(['foo'])
+ model.get_connection_name = mock.MagicMock(return_value="foo_connection")
+ model.hydrate = mock.MagicMock(return_value=Collection(["hydrated"]))
+ models = builder.get_models(["foo"])
- self.assertEqual(models.all(), ['hydrated'])
+ self.assertEqual(models.all(), ["hydrated"])
model.get_table.assert_called_once_with()
model.get_connection_name.assert_called_once_with()
- model.hydrate.assert_called_once_with(
- records, 'foo_connection'
- )
+ model.hydrate.assert_called_once_with(records, "foo_connection")
def test_macros_are_called_on_builder(self):
- builder = Builder(QueryBuilder(
- flexmock(Connection),
- flexmock(QueryGrammar),
- flexmock(QueryProcessor)
- ))
+ builder = Builder(
+ QueryBuilder(
+ flexmock(Connection), flexmock(QueryGrammar), flexmock(QueryProcessor)
+ )
+ )
def foo_bar(builder):
builder.foobar = True
return builder
- builder.macro('foo_bar', foo_bar)
+ builder.macro("foo_bar", foo_bar)
result = builder.foo_bar()
self.assertEqual(result, builder)
@@ -290,29 +260,36 @@ def test_eager_load_relations_load_top_level_relationships(self):
builder = Builder(flexmock(QueryBuilder(None, None, None)))
nop1 = lambda: None
nop2 = lambda: None
- builder.set_eager_loads({'foo': nop1, 'foo.bar': nop2})
- builder.should_receive('_load_relation').with_args(['models'], 'foo', nop1).and_return(['foo'])
+ builder.set_eager_loads({"foo": nop1, "foo.bar": nop2})
+ builder.should_receive("_load_relation").with_args(
+ ["models"], "foo", nop1
+ ).and_return(["foo"])
- results = builder.eager_load_relations(['models'])
- self.assertEqual(['foo'], results)
+ results = builder.eager_load_relations(["models"])
+ self.assertEqual(["foo"], results)
def test_eager_load_accept_queries(self):
model = OrmBuilderTestModelCloseRelated()
flexmock(Builder)
builder = Builder(flexmock(QueryBuilder(None, None, None)))
- nop1 = OrmBuilderTestModelFarRelatedStub.where('id', 5)
- builder.set_eager_loads({'foo': nop1})
+ nop1 = OrmBuilderTestModelFarRelatedStub.where("id", 5)
+ builder.set_eager_loads({"foo": nop1})
relation = flexmock()
- relation.should_receive('add_eager_constraints').once().with_args(['models'])
- relation.should_receive('init_relation').once().with_args(['models'], 'foo').and_return(['models'])
- relation.should_receive('get_eager').once().and_return(['results'])
- relation.should_receive('match').once()\
- .with_args(['models'], ['results'], 'foo').and_return(['foo'])
- builder.should_receive('get_relation').once().with_args('foo').and_return(relation)
- relation.should_receive('merge_query').with_args(nop1).and_return(relation)
+ relation.should_receive("add_eager_constraints").once().with_args(["models"])
+ relation.should_receive("init_relation").once().with_args(
+ ["models"], "foo"
+ ).and_return(["models"])
+ relation.should_receive("get_eager").once().and_return(["results"])
+ relation.should_receive("match").once().with_args(
+ ["models"], ["results"], "foo"
+ ).and_return(["foo"])
+ builder.should_receive("get_relation").once().with_args("foo").and_return(
+ relation
+ )
+ relation.should_receive("merge_query").with_args(nop1).and_return(relation)
- results = builder.eager_load_relations(['models'])
- self.assertEqual(['foo'], results)
+ results = builder.eager_load_relations(["models"])
+ self.assertEqual(["foo"], results)
def test_relationship_eager_load_process(self):
proof = flexmock()
@@ -322,18 +299,23 @@ def test_relationship_eager_load_process(self):
def callback(q):
proof.foo = q
- builder.set_eager_loads({'orders': callback})
+ builder.set_eager_loads({"orders": callback})
relation = flexmock()
- relation.should_receive('add_eager_constraints').once().with_args(['models'])
- relation.should_receive('init_relation').once().with_args(['models'], 'orders').and_return(['models'])
- relation.should_receive('get_eager').once().and_return(['results'])
- relation.should_receive('get_query').once().and_return(relation)
- relation.should_receive('match').once()\
- .with_args(['models'], ['results'], 'orders').and_return(['models.matched'])
- builder.should_receive('get_relation').once().with_args('orders').and_return(relation)
- results = builder.eager_load_relations(['models'])
-
- self.assertEqual(['models.matched'], results)
+ relation.should_receive("add_eager_constraints").once().with_args(["models"])
+ relation.should_receive("init_relation").once().with_args(
+ ["models"], "orders"
+ ).and_return(["models"])
+ relation.should_receive("get_eager").once().and_return(["results"])
+ relation.should_receive("get_query").once().and_return(relation)
+ relation.should_receive("match").once().with_args(
+ ["models"], ["results"], "orders"
+ ).and_return(["models.matched"])
+ builder.should_receive("get_relation").once().with_args("orders").and_return(
+ relation
+ )
+ results = builder.eager_load_relations(["models"])
+
+ self.assertEqual(["models.matched"], results)
self.assertEqual(relation, proof.foo)
def test_get_relation_properly_sets_nested_relationships(self):
@@ -341,32 +323,32 @@ def test_get_relation_properly_sets_nested_relationships(self):
builder = Builder(flexmock(QueryBuilder(None, None, None)))
model = flexmock(Model())
relation = flexmock()
- model.set_relation('orders', relation)
+ model.set_relation("orders", relation)
builder.set_model(model)
relation_query = flexmock()
- relation.should_receive('get_query').and_return(relation_query)
- relation_query.should_receive('with_').once().with_args({'lines': None, 'lines.details': None})
- builder.set_eager_loads({
- 'orders': None,
- 'orders.lines': None,
- 'orders.lines.details': None
- })
+ relation.should_receive("get_query").and_return(relation_query)
+ relation_query.should_receive("with_").once().with_args(
+ {"lines": None, "lines.details": None}
+ )
+ builder.set_eager_loads(
+ {"orders": None, "orders.lines": None, "orders.lines.details": None}
+ )
- relation = builder.get_relation('orders')
+ relation = builder.get_relation("orders")
def test_query_passthru(self):
builder = self.get_builder()
- builder.get_query().foobar = mock.MagicMock(return_value='foo')
+ builder.get_query().foobar = mock.MagicMock(return_value="foo")
self.assertIsInstance(builder.foobar(), Builder)
self.assertEqual(builder.foobar(), builder)
builder = self.get_builder()
- builder.get_query().insert = mock.MagicMock(return_value='foo')
+ builder.get_query().insert = mock.MagicMock(return_value="foo")
- self.assertEqual('foo', builder.insert(['bar']))
+ self.assertEqual("foo", builder.insert(["bar"]))
- builder.get_query().insert.assert_called_once_with(['bar'])
+ builder.get_query().insert.assert_called_once_with(["bar"])
def test_query_scopes(self):
builder = self.get_builder()
@@ -381,11 +363,11 @@ def test_query_scopes(self):
def test_simple_where(self):
builder = self.get_builder()
builder.get_query().where = mock.MagicMock()
- result = builder.where('foo', '=', 'bar')
+ result = builder.where("foo", "=", "bar")
self.assertEqual(builder, result)
- builder.get_query().where.assert_called_once_with('foo', '=', 'bar', 'and')
+ builder.get_query().where.assert_called_once_with("foo", "=", "bar", "and")
def test_nested_where(self):
nested_query = self.get_builder()
@@ -399,42 +381,48 @@ def test_nested_where(self):
result = builder.where(nested_query)
self.assertEqual(builder, result)
- builder.get_query().add_nested_where_query.assert_called_once_with(nested_raw_query, 'and')
+ builder.get_query().add_nested_where_query.assert_called_once_with(
+ nested_raw_query, "and"
+ )
# TODO: nested query with scopes
def test_delete_override(self):
builder = self.get_builder()
- builder.on_delete(lambda builder_: {'foo': builder_})
+ builder.on_delete(lambda builder_: {"foo": builder_})
- self.assertEqual({'foo': builder}, builder.delete())
+ self.assertEqual({"foo": builder}, builder.delete())
def test_has_nested(self):
- builder = OrmBuilderTestModelParentStub.where_has('foo', lambda q: q.has('bar'))
+ builder = OrmBuilderTestModelParentStub.where_has("foo", lambda q: q.has("bar"))
- result = OrmBuilderTestModelParentStub.has('foo.bar').to_sql()
+ result = OrmBuilderTestModelParentStub.has("foo.bar").to_sql()
self.assertEqual(builder.to_sql(), result)
def test_has_nested_with_constraints(self):
model = OrmBuilderTestModelParentStub
- builder = model.where_has('foo', lambda q: q.where_has('bar', lambda q: q.where('baz', 'bam'))).to_sql()
+ builder = model.where_has(
+ "foo", lambda q: q.where_has("bar", lambda q: q.where("baz", "bam"))
+ ).to_sql()
- result = model.where_has('foo.bar', lambda q: q.where('baz', 'bam')).to_sql()
+ result = model.where_has("foo.bar", lambda q: q.where("baz", "bam")).to_sql()
self.assertEqual(builder, result)
def test_where_exists_accepts_builder_instance(self):
model = OrmBuilderTestModelCloseRelated
- builder = model.where_exists(OrmBuilderTestModelFarRelatedStub.where('foo', 'bar')).to_sql()
+ builder = model.where_exists(
+ OrmBuilderTestModelFarRelatedStub.where("foo", "bar")
+ ).to_sql()
self.assertEqual(
'SELECT * FROM "orm_builder_test_model_close_relateds" '
'WHERE EXISTS (SELECT * FROM "orm_builder_test_model_far_related_stubs" WHERE "foo" = ?)',
- builder
+ builder,
)
def get_builder(self):
@@ -449,17 +437,12 @@ def get_mock_query_builder(self):
connection = MockConnection().prepare_mock()
processor = MockProcessor().prepare_mock()
- builder = MockQueryBuilder(
- connection,
- QueryGrammar(),
- processor
- ).prepare_mock()
+ builder = MockQueryBuilder(connection, QueryGrammar(), processor).prepare_mock()
return builder
class OratorTestModel(Model):
-
@classmethod
def _boot_columns(cls):
return []
@@ -475,21 +458,18 @@ class OrmBuilderTestModelFarRelatedStub(OratorTestModel):
class OrmBuilderTestModelScopeStub(OratorTestModel):
-
@scope
def approved(self, query):
- query.where('foo', 'bar')
+ query.where("foo", "bar")
class OrmBuilderTestModelCloseRelated(OratorTestModel):
-
@has_many
def bar(self):
return OrmBuilderTestModelFarRelatedStub
class OrmBuilderTestModelParentStub(OratorTestModel):
-
@belongs_to
def foo(self):
return OrmBuilderTestModelCloseRelated
diff --git a/tests/orm/test_factory.py b/tests/orm/test_factory.py
index 1b751b78..6bb7f3b5 100644
--- a/tests/orm/test_factory.py
+++ b/tests/orm/test_factory.py
@@ -8,7 +8,6 @@
class FactoryTestCase(OratorTestCase):
-
@classmethod
def setUpClass(cls):
Model.set_connection_resolver(DatabaseConnectionResolver())
@@ -24,50 +23,43 @@ def schema(self):
return self.connection().get_schema_builder()
def setUp(self):
- with self.schema().create('users') as table:
- table.increments('id')
- table.string('name').unique()
- table.string('email').unique()
- table.boolean('admin').default(True)
+ with self.schema().create("users") as table:
+ table.increments("id")
+ table.string("name").unique()
+ table.string("email").unique()
+ table.boolean("admin").default(True)
table.timestamps()
- with self.schema().create('posts') as table:
- table.increments('id')
- table.integer('user_id')
- table.string('title').unique()
- table.text('content').unique()
+ with self.schema().create("posts") as table:
+ table.increments("id")
+ table.integer("user_id")
+ table.string("title").unique()
+ table.text("content").unique()
table.timestamps()
- table.foreign('user_id').references('id').on('users')
+ table.foreign("user_id").references("id").on("users")
self.factory = Factory()
@self.factory.define(User)
def users_factory(faker):
- return {
- 'name': faker.name(),
- 'email': faker.email(),
- 'admin': False
- }
+ return {"name": faker.name(), "email": faker.email(), "admin": False}
- @self.factory.define(User, 'admin')
+ @self.factory.define(User, "admin")
def users_factory(faker):
attributes = self.factory.raw(User)
- attributes.update({'admin': True})
+ attributes.update({"admin": True})
return attributes
@self.factory.define(Post)
def posts_factory(faker):
- return {
- 'title': faker.sentence(),
- 'content': faker.text()
- }
+ return {"title": faker.sentence(), "content": faker.text()}
def tearDown(self):
- self.schema().drop('posts')
- self.schema().drop('users')
+ self.schema().drop("posts")
+ self.schema().drop("users")
def test_factory_make(self):
user = self.factory.make(User)
@@ -75,7 +67,7 @@ def test_factory_make(self):
self.assertIsInstance(user, User)
self.assertIsNotNone(user.name)
self.assertIsNotNone(user.email)
- self.assertIsNone(User.where('name', user.name).first())
+ self.assertIsNone(User.where("name", user.name).first())
def test_factory_create(self):
user = self.factory.create(User)
@@ -83,15 +75,15 @@ def test_factory_create(self):
self.assertIsInstance(user, User)
self.assertIsNotNone(user.name)
self.assertIsNotNone(user.email)
- self.assertIsNotNone(User.where('name', user.name).first())
+ self.assertIsNotNone(User.where("name", user.name).first())
def test_factory_create_with_attributes(self):
- user = self.factory.create(User, name='foo', email='foo@bar.com')
+ user = self.factory.create(User, name="foo", email="foo@bar.com")
self.assertIsInstance(user, User)
- self.assertEqual('foo', user.name)
- self.assertEqual('foo@bar.com', user.email)
- self.assertIsNotNone(User.where('name', user.name).first())
+ self.assertEqual("foo", user.name)
+ self.assertEqual("foo@bar.com", user.email)
+ self.assertIsNotNone(User.where("name", user.name).first())
def test_factory_create_with_relations(self):
users = self.factory.build(User, 3)
@@ -111,19 +103,19 @@ def test_factory_call(self):
self.assertEqual(3, len(users))
self.assertFalse(users[0].admin)
- admin = self.factory(User, 'admin').create()
+ admin = self.factory(User, "admin").create()
self.assertTrue(admin.admin)
- admins = self.factory(User, 'admin', 3).create()
+ admins = self.factory(User, "admin", 3).create()
self.assertEqual(3, len(admins))
self.assertTrue(admins[0].admin)
class User(Model):
- __guarded__ = ['id']
+ __guarded__ = ["id"]
- @has_many('user_id')
+ @has_many("user_id")
def posts(self):
return Post
@@ -132,7 +124,7 @@ class Post(Model):
__guarded__ = []
- @belongs_to('user_id')
+ @belongs_to("user_id")
def user(self):
return User
@@ -145,12 +137,14 @@ def connection(self, name=None):
if self._connection:
return self._connection
- self._connection = SQLiteConnection(SQLiteConnector().connect({'database': ':memory:'}))
+ self._connection = SQLiteConnection(
+ SQLiteConnector().connect({"database": ":memory:"})
+ )
return self._connection
def get_default_connection(self):
- return 'default'
+ return "default"
def set_default_connection(self, name):
pass
diff --git a/tests/orm/test_model.py b/tests/orm/test_model.py
index 4bf7480a..d81abaab 100644
--- a/tests/orm/test_model.py
+++ b/tests/orm/test_model.py
@@ -24,56 +24,54 @@
class OrmModelTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_attributes_manipulation(self):
model = OrmModelStub()
- model.name = 'foo'
- self.assertEqual('foo', model.name)
+ model.name = "foo"
+ self.assertEqual("foo", model.name)
del model.name
- self.assertFalse(hasattr(model, 'name'))
+ self.assertFalse(hasattr(model, "name"))
- model.list_items = {'name': 'john'}
- self.assertEqual({'name': 'john'}, model.list_items)
+ model.list_items = {"name": "john"}
+ self.assertEqual({"name": "john"}, model.list_items)
attributes = model.get_attributes()
- self.assertEqual(json.dumps({'name': 'john'}), attributes['list_items'])
+ self.assertEqual(json.dumps({"name": "john"}), attributes["list_items"])
def test_dirty_attributes(self):
- model = OrmModelStub(foo='1', bar=2, baz=3)
+ model = OrmModelStub(foo="1", bar=2, baz=3)
model.foo = 1
model.bar = 20
model.baz = 30
self.assertTrue(model.is_dirty())
- self.assertTrue(model.is_dirty('foo'))
- self.assertTrue(model.is_dirty('bar'))
- self.assertTrue(model.is_dirty('baz'))
- self.assertTrue(model.is_dirty('foo', 'bar', 'baz'))
+ self.assertTrue(model.is_dirty("foo"))
+ self.assertTrue(model.is_dirty("bar"))
+ self.assertTrue(model.is_dirty("baz"))
+ self.assertTrue(model.is_dirty("foo", "bar", "baz"))
def test_calculated_attributes(self):
model = OrmModelStub()
- model.password = 'secret'
+ model.password = "secret"
attributes = model.get_attributes()
- self.assertFalse('password' in attributes)
- self.assertEqual('******', model.password)
- self.assertEqual('5ebe2294ecd0e0f08eab7690d2a6ee69', attributes['password_hash'])
- self.assertEqual('5ebe2294ecd0e0f08eab7690d2a6ee69', model.password_hash)
+ self.assertFalse("password" in attributes)
+ self.assertEqual("******", model.password)
+ self.assertEqual(
+ "5ebe2294ecd0e0f08eab7690d2a6ee69", attributes["password_hash"]
+ )
+ self.assertEqual("5ebe2294ecd0e0f08eab7690d2a6ee69", model.password_hash)
def test_new_instance_returns_instance_wit_attributes_set(self):
model = OrmModelStub()
- instance = model.new_instance({'name': 'john'})
+ instance = model.new_instance({"name": "john"})
self.assertIsInstance(instance, OrmModelStub)
- self.assertEqual('john', instance.name)
+ self.assertEqual("john", instance.name)
def test_hydrate_creates_collection_of_models(self):
- data = [
- {'name': 'john'},
- {'name': 'jane'}
- ]
- collection = OrmModelStub.hydrate(data, 'foo_connection')
+ data = [{"name": "john"}, {"name": "jane"}]
+ collection = OrmModelStub.hydrate(data, "foo_connection")
self.assertIsInstance(collection, Collection)
self.assertEqual(2, len(collection))
@@ -81,10 +79,10 @@ def test_hydrate_creates_collection_of_models(self):
self.assertIsInstance(collection[1], OrmModelStub)
self.assertEqual(collection[0].get_attributes(), collection[0].get_original())
self.assertEqual(collection[1].get_attributes(), collection[1].get_original())
- self.assertEqual('john', collection[0].name)
- self.assertEqual('jane', collection[1].name)
- self.assertEqual('foo_connection', collection[0].get_connection_name())
- self.assertEqual('foo_connection', collection[1].get_connection_name())
+ self.assertEqual("john", collection[0].name)
+ self.assertEqual("jane", collection[1].name)
+ self.assertEqual("foo_connection", collection[0].get_connection_name())
+ self.assertEqual("foo_connection", collection[1].get_connection_name())
def test_hydrate_raw_makes_raw_query(self):
model = OrmModelHydrateRawStub()
@@ -97,22 +95,22 @@ def _set_connection(name):
return model
- OrmModelHydrateRawStub.set_connection = mock.MagicMock(side_effect=_set_connection)
- collection = OrmModelHydrateRawStub.hydrate_raw('SELECT ?', ['foo'])
- self.assertEqual('hydrated', collection)
- connection.select.assert_called_once_with(
- 'SELECT ?', ['foo']
+ OrmModelHydrateRawStub.set_connection = mock.MagicMock(
+ side_effect=_set_connection
)
+ collection = OrmModelHydrateRawStub.hydrate_raw("SELECT ?", ["foo"])
+ self.assertEqual("hydrated", collection)
+ connection.select.assert_called_once_with("SELECT ?", ["foo"])
def test_create_saves_new_model(self):
- model = OrmModelSaveStub.create(name='john')
+ model = OrmModelSaveStub.create(name="john")
self.assertTrue(model.get_saved())
- self.assertEqual('john', model.name)
+ self.assertEqual("john", model.name)
def test_find_method_calls_query_builder_correctly(self):
result = OrmModelFindStub.find(1)
- self.assertEqual('foo', result)
+ self.assertEqual("foo", result)
def test_find_use_write_connection(self):
OrmModelFindWithWriteConnectionStub.on_write_connection().find(1)
@@ -120,42 +118,44 @@ def test_find_use_write_connection(self):
def test_find_with_list_calls_query_builder_correctly(self):
result = OrmModelFindManyStub.find([1, 2])
- self.assertEqual('foo', result)
+ self.assertEqual("foo", result)
def test_destroy_method_calls_query_builder_correctly(self):
OrmModelDestroyStub.destroy(1, 2, 3)
def test_with_calls_query_builder_correctly(self):
- result = OrmModelWithStub.with_('foo', 'bar')
- self.assertEqual('foo', result)
+ result = OrmModelWithStub.with_("foo", "bar")
+ self.assertEqual("foo", result)
def test_update_process(self):
query = flexmock(Builder)
- query.should_receive('where').once().with_args('id', 1)
- query.should_receive('update').once().with_args({'name': 'john'})
+ query.should_receive("where").once().with_args("id", 1)
+ query.should_receive("update").once().with_args({"name": "john"})
model = OrmModelStub()
- model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None)))
+ model.new_query = mock.MagicMock(
+ return_value=Builder(QueryBuilder(None, None, None))
+ )
model._update_timestamps = mock.MagicMock()
events = flexmock(Event())
model.__dispatcher__ = events
- events.should_receive('fire').once()\
- .with_args('saving: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('updating: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('updated: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('saved: %s' % model.__class__.__name__, model)\
- .and_return(True)
+ events.should_receive("fire").once().with_args(
+ "saving: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "updating: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "updated: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "saved: %s" % model.__class__.__name__, model
+ ).and_return(True)
model.id = 1
- model.foo = 'bar'
+ model.foo = "bar"
model.sync_original()
- model.name = 'john'
+ model.name = "john"
model.set_exists(True)
self.assertTrue(model.save())
@@ -164,32 +164,36 @@ def test_update_process(self):
def test_update_process_does_not_override_timestamps(self):
query = flexmock(Builder)
- query.should_receive('where').once().with_args('id', 1)
- query.should_receive('update').once().with_args({'created_at': 'foo', 'updated_at': 'bar'})
+ query.should_receive("where").once().with_args("id", 1)
+ query.should_receive("update").once().with_args(
+ {"created_at": "foo", "updated_at": "bar"}
+ )
model = OrmModelStub()
- model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None)))
+ model.new_query = mock.MagicMock(
+ return_value=Builder(QueryBuilder(None, None, None))
+ )
model._update_timestamps = mock.MagicMock()
events = flexmock(Event())
model.__dispatcher__ = events
- events.should_receive('fire').once()\
- .with_args('saving: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('updating: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('updated: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('saved: %s' % model.__class__.__name__, model)\
- .and_return(True)
+ events.should_receive("fire").once().with_args(
+ "saving: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "updating: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "updated: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "saved: %s" % model.__class__.__name__, model
+ ).and_return(True)
model.id = 1
model.sync_original()
- model.created_at = 'foo'
- model.updated_at = 'bar'
+ model.created_at = "foo"
+ model.updated_at = "bar"
model.set_exists(True)
self.assertTrue(model.save())
@@ -198,90 +202,104 @@ def test_update_process_does_not_override_timestamps(self):
def test_creating_with_only_created_at_column(self):
query_builder = flexmock(QueryBuilder)
- query_builder.should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1)
+ query_builder.should_receive("insert_get_id").once().with_args(
+ {"name": "john"}, "id"
+ ).and_return(1)
model = flexmock(OrmModelCreatedAt())
- model.should_receive('new_query').and_return(Builder(QueryBuilder(None, None, None)))
- model.should_receive('set_created_at').once()
- model.should_receive('set_updated_at').never()
- model.name = 'john'
+ model.should_receive("new_query").and_return(
+ Builder(QueryBuilder(None, None, None))
+ )
+ model.should_receive("set_created_at").once()
+ model.should_receive("set_updated_at").never()
+ model.name = "john"
model.save()
def test_creating_with_only_updated_at_column(self):
query_builder = flexmock(QueryBuilder)
- query_builder.should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1)
+ query_builder.should_receive("insert_get_id").once().with_args(
+ {"name": "john"}, "id"
+ ).and_return(1)
model = flexmock(OrmModelUpdatedAt())
- model.should_receive('new_query').and_return(Builder(QueryBuilder(None, None, None)))
- model.should_receive('set_created_at').never()
- model.should_receive('set_updated_at').once()
- model.name = 'john'
+ model.should_receive("new_query").and_return(
+ Builder(QueryBuilder(None, None, None))
+ )
+ model.should_receive("set_created_at").never()
+ model.should_receive("set_updated_at").once()
+ model.name = "john"
model.save()
def test_updating_with_only_created_at_column(self):
query = flexmock(Builder)
- query.should_receive('where').once().with_args('id', 1)
- query.should_receive('update').once().with_args({'name': 'john'})
+ query.should_receive("where").once().with_args("id", 1)
+ query.should_receive("update").once().with_args({"name": "john"})
model = flexmock(OrmModelCreatedAt())
model.id = 1
model.sync_original()
model.set_exists(True)
- model.should_receive('new_query').and_return(Builder(QueryBuilder(None, None, None)))
- model.should_receive('set_created_at').never()
- model.should_receive('set_updated_at').never()
- model.name = 'john'
+ model.should_receive("new_query").and_return(
+ Builder(QueryBuilder(None, None, None))
+ )
+ model.should_receive("set_created_at").never()
+ model.should_receive("set_updated_at").never()
+ model.name = "john"
model.save()
def test_updating_with_only_updated_at_column(self):
query = flexmock(Builder)
- query.should_receive('where').once().with_args('id', 1)
- query.should_receive('update').once().with_args({'name': 'john'})
+ query.should_receive("where").once().with_args("id", 1)
+ query.should_receive("update").once().with_args({"name": "john"})
model = flexmock(OrmModelUpdatedAt())
model.id = 1
model.sync_original()
model.set_exists(True)
- model.should_receive('new_query').and_return(Builder(QueryBuilder(None, None, None)))
- model.should_receive('set_created_at').never()
- model.should_receive('set_updated_at').once()
- model.name = 'john'
+ model.should_receive("new_query").and_return(
+ Builder(QueryBuilder(None, None, None))
+ )
+ model.should_receive("set_created_at").never()
+ model.should_receive("set_updated_at").once()
+ model.name = "john"
model.save()
def test_update_is_cancelled_if_updating_event_returns_false(self):
model = flexmock(OrmModelStub())
query = flexmock(Builder(flexmock(QueryBuilder(None, None, None))))
- model.should_receive('new_query_without_scopes').once().and_return(query)
+ model.should_receive("new_query_without_scopes").once().and_return(query)
events = flexmock(Event())
model.__dispatcher__ = events
- events.should_receive('fire').once()\
- .with_args('saving: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('updating: %s' % model.__class__.__name__, model)\
- .and_return(False)
+ events.should_receive("fire").once().with_args(
+ "saving: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "updating: %s" % model.__class__.__name__, model
+ ).and_return(False)
model.set_exists(True)
- model.foo = 'bar'
+ model.foo = "bar"
self.assertFalse(model.save())
def test_update_process_without_timestamps(self):
query = flexmock(Builder)
- query.should_receive('where').once().with_args('id', 1)
- query.should_receive('update').once().with_args({'name': 'john'})
+ query.should_receive("where").once().with_args("id", 1)
+ query.should_receive("update").once().with_args({"name": "john"})
model = flexmock(OrmModelStub())
model.__timestamps__ = False
- model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None)))
+ model.new_query = mock.MagicMock(
+ return_value=Builder(QueryBuilder(None, None, None))
+ )
model._update_timestamps = mock.MagicMock()
events = flexmock(Event())
model.__dispatcher__ = events
- model.should_receive('_fire_model_event').and_return(True)
+ model.should_receive("_fire_model_event").and_return(True)
model.id = 1
model.sync_original()
- model.name = 'john'
+ model.name = "john"
model.set_exists(True)
self.assertTrue(model.save())
@@ -290,32 +308,34 @@ def test_update_process_without_timestamps(self):
def test_update_process_uses_old_primary_key(self):
query = flexmock(Builder)
- query.should_receive('where').once().with_args('id', 1)
- query.should_receive('update').once().with_args({'id': 2, 'name': 'john'})
+ query.should_receive("where").once().with_args("id", 1)
+ query.should_receive("update").once().with_args({"id": 2, "name": "john"})
model = OrmModelStub()
- model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None)))
+ model.new_query = mock.MagicMock(
+ return_value=Builder(QueryBuilder(None, None, None))
+ )
model._update_timestamps = mock.MagicMock()
events = flexmock(Event())
model.__dispatcher__ = events
- events.should_receive('fire').once()\
- .with_args('saving: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('updating: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('updated: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('saved: %s' % model.__class__.__name__, model)\
- .and_return(True)
+ events.should_receive("fire").once().with_args(
+ "saving: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "updating: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "updated: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "saved: %s" % model.__class__.__name__, model
+ ).and_return(True)
model.id = 1
model.sync_original()
model.id = 2
- model.name = 'john'
+ model.name = "john"
model.set_exists(True)
self.assertTrue(model.save())
@@ -324,20 +344,18 @@ def test_update_process_uses_old_primary_key(self):
def test_timestamps_are_returned_as_objects(self):
model = Model()
- model.set_raw_attributes({
- 'created_at': '2015-03-24',
- 'updated_at': '2015-03-24'
- })
+ model.set_raw_attributes(
+ {"created_at": "2015-03-24", "updated_at": "2015-03-24"}
+ )
self.assertIsInstance(model.created_at, Pendulum)
self.assertIsInstance(model.updated_at, Pendulum)
def test_timestamps_are_returned_as_objects_from_timestamps_and_datetime(self):
model = Model()
- model.set_raw_attributes({
- 'created_at': datetime.datetime.utcnow(),
- 'updated_at': time.time()
- })
+ model.set_raw_attributes(
+ {"created_at": datetime.datetime.utcnow(), "updated_at": time.time()}
+ )
self.assertIsInstance(model.created_at, Pendulum)
self.assertIsInstance(model.updated_at, Pendulum)
@@ -347,8 +365,8 @@ def test_timestamps_are_returned_as_objects_on_create(self):
model.unguard()
timestamps = {
- 'created_at': datetime.datetime.now(),
- 'updated_at': datetime.datetime.now()
+ "created_at": datetime.datetime.now(),
+ "updated_at": datetime.datetime.now(),
}
instance = model.new_instance(timestamps)
@@ -363,8 +381,8 @@ def test_timestamps_return_none_if_set_to_none(self):
model.unguard()
timestamps = {
- 'created_at': datetime.datetime.now(),
- 'updated_at': datetime.datetime.now()
+ "created_at": datetime.datetime.now(),
+ "updated_at": datetime.datetime.now(),
}
instance = model.new_instance(timestamps)
@@ -379,26 +397,30 @@ def test_insert_process(self):
model = OrmModelStub()
query_builder = flexmock(QueryBuilder)
- query_builder.should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1)
- model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None)))
+ query_builder.should_receive("insert_get_id").once().with_args(
+ {"name": "john"}, "id"
+ ).and_return(1)
+ model.new_query = mock.MagicMock(
+ return_value=Builder(QueryBuilder(None, None, None))
+ )
model._update_timestamps = mock.MagicMock()
events = flexmock(Event())
model.__dispatcher__ = events
- events.should_receive('fire').once()\
- .with_args('saving: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('creating: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('created: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('saved: %s' % model.__class__.__name__, model)\
- .and_return(True)
-
- model.name = 'john'
+ events.should_receive("fire").once().with_args(
+ "saving: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "creating: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "created: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "saved: %s" % model.__class__.__name__, model
+ ).and_return(True)
+
+ model.name = "john"
model.set_exists(False)
self.assertTrue(model.save())
self.assertEqual(1, model.id)
@@ -406,45 +428,47 @@ def test_insert_process(self):
self.assertTrue(model._update_timestamps.called)
model = OrmModelStub()
- query_builder.should_receive('insert').once().with_args({'name': 'john'})
- model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None)))
+ query_builder.should_receive("insert").once().with_args({"name": "john"})
+ model.new_query = mock.MagicMock(
+ return_value=Builder(QueryBuilder(None, None, None))
+ )
model._update_timestamps = mock.MagicMock()
model.set_incrementing(False)
events = flexmock(Event())
model.__dispatcher__ = events
- events.should_receive('fire').once()\
- .with_args('saving: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('creating: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('created: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('saved: %s' % model.__class__.__name__, model)\
- .and_return(True)
-
- model.name = 'john'
+ events.should_receive("fire").once().with_args(
+ "saving: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "creating: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "created: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "saved: %s" % model.__class__.__name__, model
+ ).and_return(True)
+
+ model.name = "john"
model.set_exists(False)
self.assertTrue(model.save())
- self.assertFalse(hasattr(model, 'id'))
+ self.assertFalse(hasattr(model, "id"))
self.assertTrue(model.exists)
self.assertTrue(model._update_timestamps.called)
def test_insert_is_cancelled_if_creating_event_returns_false(self):
model = flexmock(OrmModelStub())
query = flexmock(Builder(flexmock(QueryBuilder(None, None, None))))
- model.should_receive('new_query_without_scopes').once().and_return(query)
+ model.should_receive("new_query_without_scopes").once().and_return(query)
events = flexmock(Event())
model.__dispatcher__ = events
- events.should_receive('fire').once()\
- .with_args('saving: %s' % model.__class__.__name__, model)\
- .and_return(True)
- events.should_receive('fire').once()\
- .with_args('creating: %s' % model.__class__.__name__, model)\
- .and_return(False)
+ events.should_receive("fire").once().with_args(
+ "saving: %s" % model.__class__.__name__, model
+ ).and_return(True)
+ events.should_receive("fire").once().with_args(
+ "creating: %s" % model.__class__.__name__, model
+ ).and_return(False)
self.assertFalse(model.save())
self.assertFalse(model.exists)
@@ -452,8 +476,8 @@ def test_insert_is_cancelled_if_creating_event_returns_false(self):
def test_delete_properly_deletes_model(self):
model = OrmModelStub()
builder = flexmock(Builder(QueryBuilder(None, None, None)))
- builder.should_receive('where').once().with_args('id', 1).and_return(builder)
- builder.should_receive('delete').once()
+ builder.should_receive("where").once().with_args("id", 1).and_return(builder)
+ builder.should_receive("delete").once()
model.new_query = mock.MagicMock(return_value=builder)
model.touch_owners = mock.MagicMock()
@@ -465,13 +489,19 @@ def test_delete_properly_deletes_model(self):
def test_push_no_relations(self):
model = flexmock(Model())
- query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()))
+ query = flexmock(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
builder = Builder(query)
- builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1)
- model.should_receive('new_query').once().and_return(builder)
- model.should_receive('_update_timestamps').once()
+ builder.get_query().should_receive("insert_get_id").once().with_args(
+ {"name": "john"}, "id"
+ ).and_return(1)
+ model.should_receive("new_query").once().and_return(builder)
+ model.should_receive("_update_timestamps").once()
- model.name = 'john'
+ model.name = "john"
model.set_exists(False)
self.assertTrue(model.push())
@@ -480,15 +510,21 @@ def test_push_no_relations(self):
def test_push_empty_one_relation(self):
model = flexmock(Model())
- query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()))
+ query = flexmock(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
builder = Builder(query)
- builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1)
- model.should_receive('new_query').once().and_return(builder)
- model.should_receive('_update_timestamps').once()
+ builder.get_query().should_receive("insert_get_id").once().with_args(
+ {"name": "john"}, "id"
+ ).and_return(1)
+ model.should_receive("new_query").once().and_return(builder)
+ model.should_receive("_update_timestamps").once()
- model.name = 'john'
+ model.name = "john"
model.set_exists(False)
- model.set_relation('relation_one', None)
+ model.set_relation("relation_one", None)
self.assertTrue(model.push())
self.assertEqual(1, model.id)
@@ -497,26 +533,40 @@ def test_push_empty_one_relation(self):
def test_push_one_relation(self):
related1 = flexmock(Model())
- query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()))
+ query = flexmock(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
builder = Builder(query)
- builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'related1'}, 'id').and_return(2)
- related1.should_receive('new_query').once().and_return(builder)
- related1.should_receive('_update_timestamps').once()
+ builder.get_query().should_receive("insert_get_id").once().with_args(
+ {"name": "related1"}, "id"
+ ).and_return(2)
+ related1.should_receive("new_query").once().and_return(builder)
+ related1.should_receive("_update_timestamps").once()
- related1.name = 'related1'
+ related1.name = "related1"
related1.set_exists(False)
model = flexmock(Model())
- model.should_receive('resolve_connection').and_return(MockConnection().prepare_mock())
- query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()))
+ model.should_receive("resolve_connection").and_return(
+ MockConnection().prepare_mock()
+ )
+ query = flexmock(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
builder = Builder(query)
- builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1)
- model.should_receive('new_query').once().and_return(builder)
- model.should_receive('_update_timestamps').once()
+ builder.get_query().should_receive("insert_get_id").once().with_args(
+ {"name": "john"}, "id"
+ ).and_return(1)
+ model.should_receive("new_query").once().and_return(builder)
+ model.should_receive("_update_timestamps").once()
- model.name = 'john'
+ model.name = "john"
model.set_exists(False)
- model.set_relation('relation_one', related1)
+ model.set_relation("relation_one", related1)
self.assertTrue(model.push())
self.assertEqual(1, model.id)
@@ -528,15 +578,21 @@ def test_push_one_relation(self):
def test_push_empty_many_relation(self):
model = flexmock(Model())
- query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()))
+ query = flexmock(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
builder = Builder(query)
- builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1)
- model.should_receive('new_query').once().and_return(builder)
- model.should_receive('_update_timestamps').once()
+ builder.get_query().should_receive("insert_get_id").once().with_args(
+ {"name": "john"}, "id"
+ ).and_return(1)
+ model.should_receive("new_query").once().and_return(builder)
+ model.should_receive("_update_timestamps").once()
- model.name = 'john'
+ model.name = "john"
model.set_exists(False)
- model.set_relation('relation_many', Collection([]))
+ model.set_relation("relation_many", Collection([]))
self.assertTrue(model.push())
self.assertEqual(1, model.id)
@@ -545,51 +601,71 @@ def test_push_empty_many_relation(self):
def test_push_many_relation(self):
related1 = flexmock(Model())
- query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()))
+ query = flexmock(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
builder = Builder(query)
- builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'related1'}, 'id').and_return(2)
- related1.should_receive('new_query').once().and_return(builder)
- related1.should_receive('_update_timestamps').once()
+ builder.get_query().should_receive("insert_get_id").once().with_args(
+ {"name": "related1"}, "id"
+ ).and_return(2)
+ related1.should_receive("new_query").once().and_return(builder)
+ related1.should_receive("_update_timestamps").once()
- related1.name = 'related1'
+ related1.name = "related1"
related1.set_exists(False)
related2 = flexmock(Model())
- query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()))
+ query = flexmock(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
builder = Builder(query)
- builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'related2'}, 'id').and_return(3)
- related2.should_receive('new_query').once().and_return(builder)
- related2.should_receive('_update_timestamps').once()
+ builder.get_query().should_receive("insert_get_id").once().with_args(
+ {"name": "related2"}, "id"
+ ).and_return(3)
+ related2.should_receive("new_query").once().and_return(builder)
+ related2.should_receive("_update_timestamps").once()
- related2.name = 'related2'
+ related2.name = "related2"
related2.set_exists(False)
model = flexmock(Model())
- model.should_receive('resolve_connection').and_return(MockConnection().prepare_mock())
- query = flexmock(QueryBuilder(MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()))
+ model.should_receive("resolve_connection").and_return(
+ MockConnection().prepare_mock()
+ )
+ query = flexmock(
+ QueryBuilder(
+ MockConnection().prepare_mock(), QueryGrammar(), QueryProcessor()
+ )
+ )
builder = Builder(query)
- builder.get_query().should_receive('insert_get_id').once().with_args({'name': 'john'}, 'id').and_return(1)
- model.should_receive('new_query').once().and_return(builder)
- model.should_receive('_update_timestamps').once()
+ builder.get_query().should_receive("insert_get_id").once().with_args(
+ {"name": "john"}, "id"
+ ).and_return(1)
+ model.should_receive("new_query").once().and_return(builder)
+ model.should_receive("_update_timestamps").once()
- model.name = 'john'
+ model.name = "john"
model.set_exists(False)
- model.set_relation('relation_many', Collection([related1, related2]))
+ model.set_relation("relation_many", Collection([related1, related2]))
self.assertTrue(model.push())
self.assertEqual(1, model.id)
self.assertTrue(model.exists)
self.assertEqual(2, len(model.relation_many))
- self.assertEqual([2, 3], model.relation_many.lists('id'))
+ self.assertEqual([2, 3], model.relation_many.lists("id"))
def test_new_query_returns_orator_query_builder(self):
conn = flexmock(Connection)
grammar = flexmock(QueryGrammar)
processor = flexmock(QueryProcessor)
- conn.should_receive('get_query_grammar').and_return(grammar)
- conn.should_receive('get_post_processor').and_return(processor)
+ conn.should_receive("get_query_grammar").and_return(grammar)
+ conn.should_receive("get_post_processor").and_return(processor)
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').and_return(Connection(None))
+ resolver.should_receive("connection").and_return(Connection(None))
OrmModelStub.set_connection_resolver(DatabaseManager({}))
model = OrmModelStub()
@@ -598,175 +674,174 @@ def test_new_query_returns_orator_query_builder(self):
def test_get_and_set_table(self):
model = OrmModelStub()
- self.assertEqual('stub', model.get_table())
- model.set_table('foo')
- self.assertEqual('foo', model.get_table())
+ self.assertEqual("stub", model.get_table())
+ model.set_table("foo")
+ self.assertEqual("foo", model.get_table())
def test_get_key_returns_primary_key_value(self):
model = OrmModelStub()
model.id = 1
self.assertEqual(1, model.get_key())
- self.assertEqual('id', model.get_key_name())
+ self.assertEqual("id", model.get_key_name())
def test_connection_management(self):
resolver = flexmock(DatabaseManager)
- resolver.should_receive('connection').once().with_args('foo').and_return('bar')
+ resolver.should_receive("connection").once().with_args("foo").and_return("bar")
OrmModelStub.set_connection_resolver(DatabaseManager({}))
model = OrmModelStub()
- model.set_connection('foo')
+ model.set_connection("foo")
- self.assertEqual('bar', model.get_connection())
+ self.assertEqual("bar", model.get_connection())
def test_serialize(self):
model = OrmModelStub()
- model.name = 'foo'
+ model.name = "foo"
model.age = None
- model.password = 'password1'
- model.set_hidden(['password'])
- model.set_relation('names', Collection([OrmModelStub(bar='baz'), OrmModelStub(bam='boom')]))
- model.set_relation('partner', OrmModelStub(name='jane'))
- model.set_relation('group', None)
- model.set_relation('multi', Collection())
+ model.password = "password1"
+ model.set_hidden(["password"])
+ model.set_relation(
+ "names", Collection([OrmModelStub(bar="baz"), OrmModelStub(bam="boom")])
+ )
+ model.set_relation("partner", OrmModelStub(name="jane"))
+ model.set_relation("group", None)
+ model.set_relation("multi", Collection())
d = model.serialize()
self.assertIsInstance(d, dict)
- self.assertEqual('foo', d['name'])
- self.assertEqual('baz', d['names'][0]['bar'])
- self.assertEqual('boom', d['names'][1]['bam'])
- self.assertEqual('jane', d['partner']['name'])
- self.assertIsNone(d['group'])
- self.assertEqual([], d['multi'])
- self.assertIsNone(d['age'])
- self.assertNotIn('password', d)
-
- model.set_appends(['appendable'])
+ self.assertEqual("foo", d["name"])
+ self.assertEqual("baz", d["names"][0]["bar"])
+ self.assertEqual("boom", d["names"][1]["bam"])
+ self.assertEqual("jane", d["partner"]["name"])
+ self.assertIsNone(d["group"])
+ self.assertEqual([], d["multi"])
+ self.assertIsNone(d["age"])
+ self.assertNotIn("password", d)
+
+ model.set_appends(["appendable"])
d = model.to_dict()
- self.assertEqual('appended', d['appendable'])
+ self.assertEqual("appended", d["appendable"])
def test_to_dict_includes_default_formatted_timestamps(self):
model = Model()
- model.set_raw_attributes({
- 'created_at': '2015-03-24',
- 'updated_at': '2015-03-25'
- })
+ model.set_raw_attributes(
+ {"created_at": "2015-03-24", "updated_at": "2015-03-25"}
+ )
d = model.to_dict()
- self.assertEqual('2015-03-24T00:00:00+00:00', d['created_at'])
- self.assertEqual('2015-03-25T00:00:00+00:00', d['updated_at'])
+ self.assertEqual("2015-03-24T00:00:00+00:00", d["created_at"])
+ self.assertEqual("2015-03-25T00:00:00+00:00", d["updated_at"])
def test_to_dict_includes_custom_formatted_timestamps(self):
class Stub(Model):
-
def get_date_format(self):
- return '%d-%m-%-y'
+ return "%d-%m-%-y"
- flexmock(Stub).should_receive('_boot_columns').and_return(['created_at', 'updated_at'])
+ flexmock(Stub).should_receive("_boot_columns").and_return(
+ ["created_at", "updated_at"]
+ )
model = Stub()
- model.set_raw_attributes({
- 'created_at': '2015-03-24',
- 'updated_at': '2015-03-25'
- })
+ model.set_raw_attributes(
+ {"created_at": "2015-03-24", "updated_at": "2015-03-25"}
+ )
d = model.to_dict()
- self.assertEqual('24-03-15', d['created_at'])
- self.assertEqual('25-03-15', d['updated_at'])
+ self.assertEqual("24-03-15", d["created_at"])
+ self.assertEqual("25-03-15", d["updated_at"])
def test_visible_creates_dict_whitelist(self):
model = OrmModelStub()
- model.set_visible(['name'])
- model.name = 'John'
+ model.set_visible(["name"])
+ model.name = "John"
model.age = 28
d = model.to_dict()
- self.assertEqual({'name': 'John'}, d)
+ self.assertEqual({"name": "John"}, d)
def test_hidden_can_also_exclude_relationships(self):
model = OrmModelStub()
- model.name = 'John'
- model.set_relation('foo', ['bar'])
- model.set_hidden(['foo', 'list_items', 'password'])
+ model.name = "John"
+ model.set_relation("foo", ["bar"])
+ model.set_hidden(["foo", "list_items", "password"])
d = model.to_dict()
- self.assertEqual({'name': 'John'}, d)
+ self.assertEqual({"name": "John"}, d)
def test_to_dict_uses_mutators(self):
model = OrmModelStub()
model.list_items = [1, 2, 3]
d = model.to_dict()
- self.assertEqual([1, 2, 3], d['list_items'])
+ self.assertEqual([1, 2, 3], d["list_items"])
model = OrmModelStub(list_items=[1, 2, 3])
d = model.to_dict()
- self.assertEqual([1, 2, 3], d['list_items'])
+ self.assertEqual([1, 2, 3], d["list_items"])
def test_hidden_are_ignored_when_visible(self):
- model = OrmModelStub(name='john', age=28, id='foo')
- model.set_visible(['name', 'id'])
- model.set_hidden(['name', 'age'])
+ model = OrmModelStub(name="john", age=28, id="foo")
+ model.set_visible(["name", "id"])
+ model.set_hidden(["name", "age"])
d = model.to_dict()
- self.assertIn('name', d)
- self.assertIn('id', d)
- self.assertNotIn('age', d)
+ self.assertIn("name", d)
+ self.assertIn("id", d)
+ self.assertNotIn("age", d)
def test_fillable(self):
model = OrmModelStub()
- model.fillable(['name', 'age'])
- model.fill(name='foo', age=28)
- self.assertEqual('foo', model.name)
+ model.fillable(["name", "age"])
+ model.fill(name="foo", age=28)
+ self.assertEqual("foo", model.name)
self.assertEqual(28, model.age)
def test_fill_with_dict(self):
model = OrmModelStub()
- model.fill({'name': 'foo', 'age': 28})
- self.assertEqual('foo', model.name)
+ model.fill({"name": "foo", "age": 28})
+ self.assertEqual("foo", model.name)
self.assertEqual(28, model.age)
def test_unguard_allows_anything(self):
model = OrmModelStub()
model.unguard()
- model.guard(['*'])
- model.fill(name='foo', age=28)
- self.assertEqual('foo', model.name)
+ model.guard(["*"])
+ model.fill(name="foo", age=28)
+ self.assertEqual("foo", model.name)
self.assertEqual(28, model.age)
model.reguard()
def test_underscore_properties_are_not_filled(self):
model = OrmModelStub()
- model.fill(_foo='bar')
+ model.fill(_foo="bar")
self.assertEqual({}, model.get_attributes())
def test_guarded(self):
model = OrmModelStub()
- model.guard(['name', 'age'])
- model.fill(name='foo', age='bar', foo='bar')
- self.assertFalse(hasattr(model, 'name'))
- self.assertFalse(hasattr(model, 'age'))
- self.assertEqual('bar', model.foo)
+ model.guard(["name", "age"])
+ model.fill(name="foo", age="bar", foo="bar")
+ self.assertFalse(hasattr(model, "name"))
+ self.assertFalse(hasattr(model, "age"))
+ self.assertEqual("bar", model.foo)
def test_fillable_overrides_guarded(self):
model = OrmModelStub()
- model.guard(['name', 'age'])
- model.fillable(['age', 'foo'])
- model.fill(name='foo', age='bar', foo='bar')
- self.assertFalse(hasattr(model, 'name'))
- self.assertEqual('bar', model.age)
- self.assertEqual('bar', model.foo)
+ model.guard(["name", "age"])
+ model.fillable(["age", "foo"])
+ model.fill(name="foo", age="bar", foo="bar")
+ self.assertFalse(hasattr(model, "name"))
+ self.assertEqual("bar", model.age)
+ self.assertEqual("bar", model.foo)
def test_global_guarded(self):
model = OrmModelStub()
- model.guard(['*'])
+ model.guard(["*"])
self.assertRaises(
- MassAssignmentError,
- model.fill,
- name='foo', age='bar', foo='bar'
+ MassAssignmentError, model.fill, name="foo", age="bar", foo="bar"
)
# TODO: test relations
@@ -774,18 +849,16 @@ def test_global_guarded(self):
def test_models_assumes_their_name(self):
model = OrmModelNoTableStub()
- self.assertEqual('orm_model_no_table_stubs', model.get_table())
+ self.assertEqual("orm_model_no_table_stubs", model.get_table())
def test_mutator_cache_is_populated(self):
model = OrmModelStub()
- expected_attributes = sorted([
- 'list_items',
- 'password',
- 'appendable'
- ])
+ expected_attributes = sorted(["list_items", "password", "appendable"])
- self.assertEqual(expected_attributes, sorted(list(model._get_mutated_attributes().keys())))
+ self.assertEqual(
+ expected_attributes, sorted(list(model._get_mutated_attributes().keys()))
+ )
def test_fresh_method(self):
model = flexmock(OrmModelStub())
@@ -794,14 +867,14 @@ def test_fresh_method(self):
flexmock(Builder)
q = flexmock(QueryBuilder(None, None, None))
query = flexmock(Builder(q))
- query.should_receive('where').and_return(query)
- query.get_query().should_receive('take').and_return(query)
- query.should_receive('get').and_return(Collection([]))
- model.should_receive('with_').once().with_args('foo', 'bar').and_return(query)
+ query.should_receive("where").and_return(query)
+ query.get_query().should_receive("take").and_return(query)
+ query.should_receive("get").and_return(Collection([]))
+ model.should_receive("with_").once().with_args("foo", "bar").and_return(query)
- model.fresh(['foo', 'bar'])
+ model.fresh(["foo", "bar"])
- model.should_receive('with_').once().with_args().and_return(query)
+ model.should_receive("with_").once().with_args().and_return(query)
model.fresh()
@@ -809,33 +882,33 @@ def test_clone_model_makes_a_fresh_copy(self):
model = OrmModelStub()
model.id = 1
model.set_exists(True)
- model.first = 'john'
- model.last = 'doe'
+ model.first = "john"
+ model.last = "doe"
model.created_at = model.fresh_timestamp()
model.updated_at = model.fresh_timestamp()
# TODO: relation
clone = model.replicate()
- self.assertFalse(hasattr(clone, 'id'))
+ self.assertFalse(hasattr(clone, "id"))
self.assertFalse(clone.exists)
- self.assertEqual('john', clone.first)
- self.assertEqual('doe', clone.last)
- self.assertFalse(hasattr(clone, 'created_at'))
- self.assertFalse(hasattr(clone, 'updated_at'))
+ self.assertEqual("john", clone.first)
+ self.assertEqual("doe", clone.last)
+ self.assertFalse(hasattr(clone, "created_at"))
+ self.assertFalse(hasattr(clone, "updated_at"))
# TODO: relation
- clone.first = 'jane'
+ clone.first = "jane"
- self.assertEqual('john', model.first)
- self.assertEqual('jane', clone.first)
+ self.assertEqual("john", model.first)
+ self.assertEqual("jane", clone.first)
def test_get_attribute_raise_attribute_error(self):
model = OrmModelStub()
try:
relation = model.incorrect_relation
- self.fail('AttributeError not raised')
+ self.fail("AttributeError not raised")
except AttributeError:
pass
@@ -845,14 +918,14 @@ def test_increment(self):
model = OrmModelStub()
model.set_exists(True)
model.id = 1
- model.sync_original_attribute('id')
+ model.sync_original_attribute("id")
model.foo = 2
- model_mock.should_receive('new_query').and_return(query)
- query.should_receive('where').and_return(query)
- query.should_receive('increment')
+ model_mock.should_receive("new_query").and_return(query)
+ query.should_receive("where").and_return(query)
+ query.should_receive("increment")
- model.public_increment('foo')
+ model.public_increment("foo")
self.assertEqual(3, model.foo)
self.assertFalse(model.is_dirty())
@@ -863,31 +936,33 @@ def test_increment(self):
def test_timestamps_are_not_updated_with_timestamps_false_save_option(self):
query = flexmock(Builder)
- query.should_receive('where').once().with_args('id', 1)
- query.should_receive('update').once().with_args({'name': 'john'})
+ query.should_receive("where").once().with_args("id", 1)
+ query.should_receive("update").once().with_args({"name": "john"})
model = OrmModelStub()
- model.new_query = mock.MagicMock(return_value=Builder(QueryBuilder(None, None, None)))
+ model.new_query = mock.MagicMock(
+ return_value=Builder(QueryBuilder(None, None, None))
+ )
model.id = 1
model.sync_original()
- model.name = 'john'
+ model.name = "john"
model.set_exists(True)
- self.assertTrue(model.save({'timestamps': False}))
- self.assertFalse(hasattr(model, 'updated_at'))
+ self.assertTrue(model.save({"timestamps": False}))
+ self.assertFalse(hasattr(model, "updated_at"))
model.new_query.assert_called_once_with()
def test_casts(self):
model = OrmModelCastingStub()
- model.first = '3'
- model.second = '4.0'
+ model.first = "3"
+ model.second = "4.0"
model.third = 2.5
model.fourth = 1
model.fifth = 0
- model.sixth = {'foo': 'bar'}
- model.seventh = ['foo', 'bar']
- model.eighth = {'foo': 'bar'}
+ model.sixth = {"foo": "bar"}
+ model.seventh = ["foo", "bar"]
+ model.eighth = {"foo": "bar"}
self.assertIsInstance(model.first, int)
self.assertIsInstance(model.second, float)
@@ -899,25 +974,25 @@ def test_casts(self):
self.assertIsInstance(model.eighth, dict)
self.assertTrue(model.fourth)
self.assertFalse(model.fifth)
- self.assertEqual({'foo': 'bar'}, model.sixth)
- self.assertEqual({'foo': 'bar'}, model.eighth)
- self.assertEqual(['foo', 'bar'], model.seventh)
+ self.assertEqual({"foo": "bar"}, model.sixth)
+ self.assertEqual({"foo": "bar"}, model.eighth)
+ self.assertEqual(["foo", "bar"], model.seventh)
d = model.to_dict()
- self.assertIsInstance(d['first'], int)
- self.assertIsInstance(d['second'], float)
- self.assertIsInstance(d['third'], basestring)
- self.assertIsInstance(d['fourth'], bool)
- self.assertIsInstance(d['fifth'], bool)
- self.assertIsInstance(d['sixth'], dict)
- self.assertIsInstance(d['seventh'], list)
- self.assertIsInstance(d['eighth'], dict)
- self.assertTrue(d['fourth'])
- self.assertFalse(d['fifth'])
- self.assertEqual({'foo': 'bar'}, d['sixth'])
- self.assertEqual({'foo': 'bar'}, d['eighth'])
- self.assertEqual(['foo', 'bar'], d['seventh'])
+ self.assertIsInstance(d["first"], int)
+ self.assertIsInstance(d["second"], float)
+ self.assertIsInstance(d["third"], basestring)
+ self.assertIsInstance(d["fourth"], bool)
+ self.assertIsInstance(d["fifth"], bool)
+ self.assertIsInstance(d["sixth"], dict)
+ self.assertIsInstance(d["seventh"], list)
+ self.assertIsInstance(d["eighth"], dict)
+ self.assertTrue(d["fourth"])
+ self.assertFalse(d["fifth"])
+ self.assertEqual({"foo": "bar"}, d["sixth"])
+ self.assertEqual({"foo": "bar"}, d["eighth"])
+ self.assertEqual(["foo", "bar"], d["seventh"])
def test_casts_preserve_null(self):
model = OrmModelCastingStub()
@@ -941,57 +1016,57 @@ def test_casts_preserve_null(self):
d = model.to_dict()
- self.assertIsNone(d['first'])
- self.assertIsNone(d['second'])
- self.assertIsNone(d['third'])
- self.assertIsNone(d['fourth'])
- self.assertIsNone(d['fifth'])
- self.assertIsNone(d['sixth'])
- self.assertIsNone(d['seventh'])
- self.assertIsNone(d['eighth'])
+ self.assertIsNone(d["first"])
+ self.assertIsNone(d["second"])
+ self.assertIsNone(d["third"])
+ self.assertIsNone(d["fourth"])
+ self.assertIsNone(d["fifth"])
+ self.assertIsNone(d["sixth"])
+ self.assertIsNone(d["seventh"])
+ self.assertIsNone(d["eighth"])
def test_get_foreign_key(self):
model = OrmModelStub()
- model.set_table('stub')
+ model.set_table("stub")
- self.assertEqual('stub_id', model.get_foreign_key())
+ self.assertEqual("stub_id", model.get_foreign_key())
def test_default_values(self):
model = OrmModelDefaultAttributes()
- self.assertEqual('bar', model.foo)
+ self.assertEqual("bar", model.foo)
def test_get_morph_name(self):
model = OrmModelStub()
- self.assertEqual('stub', model.get_morph_name())
+ self.assertEqual("stub", model.get_morph_name())
class OrmModelStub(Model):
- __table__ = 'stub'
+ __table__ = "stub"
__guarded__ = []
@accessor
def list_items(self):
- return json.loads(self.get_raw_attribute('list_items'))
+ return json.loads(self.get_raw_attribute("list_items"))
@list_items.mutator
def set_list_items(self, value):
- self.set_raw_attribute('list_items', json.dumps(value))
+ self.set_raw_attribute("list_items", json.dumps(value))
@mutator
def password(self, value):
- self.set_raw_attribute('password_hash', hashlib.md5(value.encode()).hexdigest())
+ self.set_raw_attribute("password_hash", hashlib.md5(value.encode()).hexdigest())
@password.accessor
def get_password(self):
- return '******'
+ return "******"
@accessor
def appendable(self):
- return 'appended'
+ return "appended"
def public_increment(self, column, amount=1):
return self._increment(column, amount)
@@ -1001,24 +1076,22 @@ def get_dates(self):
class OrmModelHydrateRawStub(Model):
-
@classmethod
def hydrate(cls, items, connection=None):
- return 'hydrated'
+ return "hydrated"
class OrmModelWithStub(Model):
-
def new_query(self):
mock = flexmock(Builder(None))
- mock.should_receive('with_').once().with_args('foo', 'bar').and_return('foo')
+ mock.should_receive("with_").once().with_args("foo", "bar").and_return("foo")
return mock
class OrmModelSaveStub(Model):
- __table__ = 'save_stub'
+ __table__ = "save_stub"
__guarded__ = []
@@ -1033,42 +1106,42 @@ def get_saved(self):
class OrmModelFindStub(Model):
-
def new_query(self):
- flexmock(Builder).should_receive('find').once().with_args(1, ['*']).and_return('foo')
+ flexmock(Builder).should_receive("find").once().with_args(1, ["*"]).and_return(
+ "foo"
+ )
return Builder(None)
class OrmModelFindWithWriteConnectionStub(Model):
-
def new_query(self):
mock = flexmock(Builder)
mock_query = flexmock(QueryBuilder)
- mock_query.should_receive('use_write_connection').once().and_return(flexmock)
- mock.should_receive('find').once().with_args(1).and_return('foo')
+ mock_query.should_receive("use_write_connection").once().and_return(flexmock)
+ mock.should_receive("find").once().with_args(1).and_return("foo")
return Builder(QueryBuilder(None, None, None))
class OrmModelFindManyStub(Model):
-
def new_query(self):
mock = flexmock(Builder)
- mock.should_receive('find').once().with_args([1, 2], ['*']).and_return('foo')
+ mock.should_receive("find").once().with_args([1, 2], ["*"]).and_return("foo")
return Builder(QueryBuilder(None, None, None))
class OrmModelDestroyStub(Model):
-
def new_query(self):
mock = flexmock(Builder)
model = flexmock()
mock_query = flexmock(QueryBuilder)
- mock_query.should_receive('where_in').once().with_args('id', [1, 2, 3]).and_return(flexmock)
- mock.should_receive('get').once().and_return([model])
- model.should_receive('delete').once()
+ mock_query.should_receive("where_in").once().with_args(
+ "id", [1, 2, 3]
+ ).and_return(flexmock)
+ mock.should_receive("get").once().and_return([model])
+ model.should_receive("delete").once()
return Builder(QueryBuilder(None, None, None))
@@ -1081,28 +1154,27 @@ class OrmModelNoTableStub(Model):
class OrmModelCastingStub(Model):
__casts__ = {
- 'first': 'int',
- 'second': 'float',
- 'third': 'str',
- 'fourth': 'bool',
- 'fifth': 'boolean',
- 'sixth': 'dict',
- 'seventh': 'list',
- 'eighth': 'json'
+ "first": "int",
+ "second": "float",
+ "third": "str",
+ "fourth": "bool",
+ "fifth": "boolean",
+ "sixth": "dict",
+ "seventh": "list",
+ "eighth": "json",
}
+
class OrmModelCreatedAt(Model):
- __timestamps__ = ['created_at']
+ __timestamps__ = ["created_at"]
class OrmModelUpdatedAt(Model):
- __timestamps__ = ['updated_at']
+ __timestamps__ = ["updated_at"]
class OrmModelDefaultAttributes(Model):
- __attributes__ = {
- 'foo': 'bar'
- }
+ __attributes__ = {"foo": "bar"}
diff --git a/tests/orm/test_model_global_scopes.py b/tests/orm/test_model_global_scopes.py
index 41955706..30fd7f64 100644
--- a/tests/orm/test_model_global_scopes.py
+++ b/tests/orm/test_model_global_scopes.py
@@ -8,7 +8,6 @@
class ModelGlobalScopesTestCase(OratorTestCase):
-
@classmethod
def setUpClass(cls):
Model.set_connection_resolver(DatabaseConnectionResolver())
@@ -21,10 +20,7 @@ def test_global_scope_is_applied(self):
model = GlobalScopesModel()
query = model.new_query()
- self.assertEqual(
- 'SELECT * FROM "table" WHERE "active" = ?',
- query.to_sql()
- )
+ self.assertEqual('SELECT * FROM "table" WHERE "active" = ?', query.to_sql())
self.assertEqual([1], query.get_bindings())
@@ -32,10 +28,7 @@ def test_global_scope_can_be_removed(self):
model = GlobalScopesModel()
query = model.new_query().without_global_scope(ActiveScope)
- self.assertEqual(
- 'SELECT * FROM "table"',
- query.to_sql()
- )
+ self.assertEqual('SELECT * FROM "table"', query.to_sql())
self.assertEqual([], query.get_bindings())
@@ -45,19 +38,16 @@ def test_callable_global_scope_is_applied(self):
self.assertEqual(
'SELECT * FROM "table" WHERE "active" = ? ORDER BY "name" ASC',
- query.to_sql()
+ query.to_sql(),
)
self.assertEqual([1], query.get_bindings())
def test_callable_global_scope_can_be_removed(self):
model = CallableGlobalScopesModel()
- query = model.new_query().without_global_scope('active_scope')
+ query = model.new_query().without_global_scope("active_scope")
- self.assertEqual(
- 'SELECT * FROM "table" ORDER BY "name" ASC',
- query.to_sql()
- )
+ self.assertEqual('SELECT * FROM "table" ORDER BY "name" ASC', query.to_sql())
self.assertEqual([], query.get_bindings())
@@ -67,80 +57,76 @@ def test_global_scope_can_be_removed_after_query_is_executed(self):
self.assertEqual(
'SELECT * FROM "table" WHERE "active" = ? ORDER BY "name" ASC',
- query.to_sql()
+ query.to_sql(),
)
self.assertEqual([1], query.get_bindings())
- query.without_global_scope('active_scope')
+ query.without_global_scope("active_scope")
- self.assertEqual(
- 'SELECT * FROM "table" ORDER BY "name" ASC',
- query.to_sql()
- )
+ self.assertEqual('SELECT * FROM "table" ORDER BY "name" ASC', query.to_sql())
self.assertEqual([], query.get_bindings())
def test_all_global_scopes_can_be_removed(self):
model = CallableGlobalScopesModel()
query = model.new_query().without_global_scopes()
- self.assertEqual(
- 'SELECT * FROM "table"',
- query.to_sql()
- )
+ self.assertEqual('SELECT * FROM "table"', query.to_sql())
self.assertEqual([], query.get_bindings())
query = CallableGlobalScopesModel.without_global_scopes()
- self.assertEqual(
- 'SELECT * FROM "table"',
- query.to_sql()
- )
+ self.assertEqual('SELECT * FROM "table"', query.to_sql())
self.assertEqual([], query.get_bindings())
def test_global_scopes_with_or_where_conditions_are_nested(self):
model = CallableGlobalScopesModelWithOr()
- query = model.new_query().where('col1', 'val1').or_where('col2', 'val2')
+ query = model.new_query().where("col1", "val1").or_where("col2", "val2")
self.assertEqual(
'SELECT "email", "password" FROM "table" '
'WHERE ("col1" = ? OR "col2" = ?) AND ("email" = ? OR "email" = ?) '
'AND ("active" = ?) ORDER BY "name" ASC',
- query.to_sql()
+ query.to_sql(),
)
self.assertEqual(
- ['val1', 'val2', 'john@doe.com', 'someone@else.com', True],
- query.get_bindings()
+ ["val1", "val2", "john@doe.com", "someone@else.com", True],
+ query.get_bindings(),
)
class CallableGlobalScopesModel(Model):
- __table__ = 'table'
+ __table__ = "table"
@classmethod
def _boot(cls):
- cls.add_global_scope('active_scope', lambda query: query.where('active', 1))
+ cls.add_global_scope("active_scope", lambda query: query.where("active", 1))
- cls.add_global_scope(lambda query: query.order_by('name'))
+ cls.add_global_scope(lambda query: query.order_by("name"))
super(CallableGlobalScopesModel, cls)._boot()
class CallableGlobalScopesModelWithOr(CallableGlobalScopesModel):
- __table__ = 'table'
+ __table__ = "table"
@classmethod
def _boot(cls):
- cls.add_global_scope('or_scope', lambda q: q.where('email', 'john@doe.com').or_where('email', 'someone@else.com'))
+ cls.add_global_scope(
+ "or_scope",
+ lambda q: q.where("email", "john@doe.com").or_where(
+ "email", "someone@else.com"
+ ),
+ )
- cls.add_global_scope(lambda query: query.select('email', 'password'))
+ cls.add_global_scope(lambda query: query.select("email", "password"))
super(CallableGlobalScopesModelWithOr, cls)._boot()
class GlobalScopesModel(Model):
- __table__ = 'table'
+ __table__ = "table"
@classmethod
def _boot(cls):
@@ -150,9 +136,8 @@ def _boot(cls):
class ActiveScope(Scope):
-
def apply(self, builder, model):
- return builder.where('active', 1)
+ return builder.where("active", 1)
class DatabaseConnectionResolver(object):
@@ -163,12 +148,14 @@ def connection(self, name=None):
if self._connection:
return self._connection
- self._connection = SQLiteConnection(SQLiteConnector().connect({'database': ':memory:'}))
+ self._connection = SQLiteConnection(
+ SQLiteConnector().connect({"database": ":memory:"})
+ )
return self._connection
def get_default_connection(self):
- return 'default'
+ return "default"
def set_default_connection(self, name):
pass
diff --git a/tests/pagination/__init__.py b/tests/pagination/__init__.py
index 633f8661..40a96afc 100644
--- a/tests/pagination/__init__.py
+++ b/tests/pagination/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/tests/pagination/test_length_ware_paginator.py b/tests/pagination/test_length_ware_paginator.py
index bf961e4a..d06a45d5 100644
--- a/tests/pagination/test_length_ware_paginator.py
+++ b/tests/pagination/test_length_ware_paginator.py
@@ -5,26 +5,25 @@
class LengthAwarePaginatorTestCase(OratorTestCase):
-
def test_returns_relevant_context(self):
- p = LengthAwarePaginator(['item3', 'item4'], 4, 2, 2)
+ p = LengthAwarePaginator(["item3", "item4"], 4, 2, 2)
self.assertEqual(2, p.current_page)
self.assertEqual(2, p.last_page)
self.assertEqual(4, p.total)
self.assertTrue(p.has_pages())
self.assertFalse(p.has_more_pages())
- self.assertEqual(['item3', 'item4'], p.items)
+ self.assertEqual(["item3", "item4"], p.items)
self.assertEqual(2, p.per_page)
self.assertIsNone(p.next_page)
self.assertEqual(1, p.previous_page)
self.assertEqual(3, p.first_item)
self.assertEqual(4, p.last_item)
- self.assertEqual('item4', p[1])
+ self.assertEqual("item4", p[1])
def test_integer_division_for_last_page(self):
- p = LengthAwarePaginator(['item3', 'item4'], 5, 2, 2)
+ p = LengthAwarePaginator(["item3", "item4"], 5, 2, 2)
self.assertEqual(2, p.current_page)
self.assertEqual(3, p.last_page)
diff --git a/tests/pagination/test_paginator.py b/tests/pagination/test_paginator.py
index 86662fbd..a38a6a90 100644
--- a/tests/pagination/test_paginator.py
+++ b/tests/pagination/test_paginator.py
@@ -5,21 +5,20 @@
class PaginatorTestCase(OratorTestCase):
-
def test_returns_relevant_context(self):
- p = Paginator(['item3', 'item4', 'item5'], 2, 2)
+ p = Paginator(["item3", "item4", "item5"], 2, 2)
self.assertEqual(2, p.current_page)
self.assertTrue(p.has_pages())
self.assertTrue(p.has_more_pages())
- self.assertEqual(['item3', 'item4'], p.items)
+ self.assertEqual(["item3", "item4"], p.items)
self.assertEqual(2, p.per_page)
self.assertEqual(3, p.next_page)
self.assertEqual(1, p.previous_page)
self.assertEqual(3, p.first_item)
self.assertEqual(4, p.last_item)
- self.assertEqual('item4', p[1])
+ self.assertEqual("item4", p[1])
def test_current_page_resolver(self):
def current_page_resolver():
diff --git a/tests/query/__init__.py b/tests/query/__init__.py
index 633f8661..40a96afc 100644
--- a/tests/query/__init__.py
+++ b/tests/query/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/tests/query/test_query_builder.py b/tests/query/test_query_builder.py
index 6b07be2c..26c1c5e3 100644
--- a/tests/query/test_query_builder.py
+++ b/tests/query/test_query_builder.py
@@ -12,7 +12,7 @@
QueryGrammar,
PostgresQueryGrammar,
SQLiteQueryGrammar,
- MySQLQueryGrammar
+ MySQLQueryGrammar,
)
from orator.query.builder import QueryBuilder
from orator.query.expression import QueryExpression
@@ -21,1208 +21,1236 @@
class QueryBuilderTestCase(OratorTestCase):
-
def test_basic_select(self):
builder = self.get_builder()
- builder.select('*').from_('users')
+ builder.select("*").from_("users")
self.assertEqual('SELECT * FROM "users"', builder.to_sql())
def test_basic_select_use_write_connection(self):
builder = self.get_builder()
- builder.use_write_connection().select('*').from_('users').get()
+ builder.use_write_connection().select("*").from_("users").get()
builder.get_connection().select.assert_called_once_with(
- 'SELECT * FROM "users"',
- [],
- False
+ 'SELECT * FROM "users"', [], False
)
builder = self.get_builder()
- builder.select('*').from_('users').get()
+ builder.select("*").from_("users").get()
builder.get_connection().select.assert_called_once_with(
- 'SELECT * FROM "users"',
- [],
- True
+ 'SELECT * FROM "users"', [], True
)
def test_alias_wrapping_as_whole_constant(self):
builder = self.get_builder()
- builder.select('x.y as foo.bar').from_('baz')
+ builder.select("x.y as foo.bar").from_("baz")
- self.assertEqual(
- 'SELECT "x"."y" AS "foo.bar" FROM "baz"',
- builder.to_sql()
- )
+ self.assertEqual('SELECT "x"."y" AS "foo.bar" FROM "baz"', builder.to_sql())
def test_adding_selects(self):
builder = self.get_builder()
- builder.select('foo').add_select('bar').add_select('baz', 'boom').from_('users')
+ builder.select("foo").add_select("bar").add_select("baz", "boom").from_("users")
self.assertEqual(
- 'SELECT "foo", "bar", "baz", "boom" FROM "users"',
- builder.to_sql()
+ 'SELECT "foo", "bar", "baz", "boom" FROM "users"', builder.to_sql()
)
def test_basic_select_with_prefix(self):
builder = self.get_builder()
- builder.get_grammar().set_table_prefix('prefix_')
- builder.select('*').from_('users')
+ builder.get_grammar().set_table_prefix("prefix_")
+ builder.select("*").from_("users")
- self.assertEqual(
- 'SELECT * FROM "prefix_users"',
- builder.to_sql()
- )
+ self.assertEqual('SELECT * FROM "prefix_users"', builder.to_sql())
def test_basic_select_distinct(self):
builder = self.get_builder()
- builder.distinct().select('foo', 'bar').from_('users')
+ builder.distinct().select("foo", "bar").from_("users")
- self.assertEqual(
- 'SELECT DISTINCT "foo", "bar" FROM "users"',
- builder.to_sql()
- )
+ self.assertEqual('SELECT DISTINCT "foo", "bar" FROM "users"', builder.to_sql())
def test_basic_alias(self):
builder = self.get_builder()
- builder.select('foo as bar').from_('users')
+ builder.select("foo as bar").from_("users")
- self.assertEqual(
- 'SELECT "foo" AS "bar" FROM "users"',
- builder.to_sql()
- )
+ self.assertEqual('SELECT "foo" AS "bar" FROM "users"', builder.to_sql())
def test_mysql_wrapping_protects_quotation_marks(self):
builder = self.get_mysql_builder()
- builder.select('*').from_('some`table')
+ builder.select("*").from_("some`table")
- self.assertEqual(
- 'SELECT * FROM `some``table`',
- builder.to_sql()
- )
+ self.assertEqual("SELECT * FROM `some``table`", builder.to_sql())
def test_where_day_mysql(self):
builder = self.get_mysql_builder()
- builder.select('*').from_('users').where_day('created_at', '=', 1)
+ builder.select("*").from_("users").where_day("created_at", "=", 1)
self.assertEqual(
- 'SELECT * FROM `users` WHERE DAY(`created_at`) = %s',
- builder.to_sql()
+ "SELECT * FROM `users` WHERE DAY(`created_at`) = %s", builder.to_sql()
)
self.assertEqual([1], builder.get_bindings())
def test_where_month_mysql(self):
builder = self.get_mysql_builder()
- builder.select('*').from_('users').where_month('created_at', '=', 5)
+ builder.select("*").from_("users").where_month("created_at", "=", 5)
self.assertEqual(
- 'SELECT * FROM `users` WHERE MONTH(`created_at`) = %s',
- builder.to_sql()
+ "SELECT * FROM `users` WHERE MONTH(`created_at`) = %s", builder.to_sql()
)
self.assertEqual([5], builder.get_bindings())
def test_where_year_mysql(self):
builder = self.get_mysql_builder()
- builder.select('*').from_('users').where_year('created_at', '=', 2014)
+ builder.select("*").from_("users").where_year("created_at", "=", 2014)
self.assertEqual(
- 'SELECT * FROM `users` WHERE YEAR(`created_at`) = %s',
- builder.to_sql()
+ "SELECT * FROM `users` WHERE YEAR(`created_at`) = %s", builder.to_sql()
)
self.assertEqual([2014], builder.get_bindings())
def test_where_day_postgres(self):
builder = self.get_postgres_builder()
- builder.select('*').from_('users').where_day('created_at', '=', 1)
+ builder.select("*").from_("users").where_day("created_at", "=", 1)
self.assertEqual(
- 'SELECT * FROM "users" WHERE DAY("created_at") = %s',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE DAY("created_at") = %s', builder.to_sql()
)
self.assertEqual([1], builder.get_bindings())
def test_where_month_postgres(self):
builder = self.get_postgres_builder()
- builder.select('*').from_('users').where_month('created_at', '=', 5)
+ builder.select("*").from_("users").where_month("created_at", "=", 5)
self.assertEqual(
- 'SELECT * FROM "users" WHERE MONTH("created_at") = %s',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE MONTH("created_at") = %s', builder.to_sql()
)
self.assertEqual([5], builder.get_bindings())
def test_where_year_postgres(self):
builder = self.get_postgres_builder()
- builder.select('*').from_('users').where_year('created_at', '=', 2014)
+ builder.select("*").from_("users").where_year("created_at", "=", 2014)
self.assertEqual(
- 'SELECT * FROM "users" WHERE YEAR("created_at") = %s',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE YEAR("created_at") = %s', builder.to_sql()
)
self.assertEqual([2014], builder.get_bindings())
def test_where_day_sqlite(self):
builder = self.get_sqlite_builder()
- builder.select('*').from_('users').where_day('created_at', '=', 1)
+ builder.select("*").from_("users").where_day("created_at", "=", 1)
self.assertEqual(
'SELECT * FROM "users" WHERE strftime(\'%d\', "created_at") = ?',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1], builder.get_bindings())
def test_where_month_sqlite(self):
builder = self.get_sqlite_builder()
- builder.select('*').from_('users').where_month('created_at', '=', 5)
+ builder.select("*").from_("users").where_month("created_at", "=", 5)
self.assertEqual(
'SELECT * FROM "users" WHERE strftime(\'%m\', "created_at") = ?',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([5], builder.get_bindings())
def test_where_year_sqlite(self):
builder = self.get_sqlite_builder()
- builder.select('*').from_('users').where_year('created_at', '=', 2014)
+ builder.select("*").from_("users").where_year("created_at", "=", 2014)
self.assertEqual(
'SELECT * FROM "users" WHERE strftime(\'%Y\', "created_at") = ?',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([2014], builder.get_bindings())
def test_where_between(self):
builder = self.get_builder()
- builder.select('*').from_('users').where_between('id', [1, 2])
+ builder.select("*").from_("users").where_between("id", [1, 2])
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" BETWEEN ? AND ?',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" BETWEEN ? AND ?', builder.to_sql()
)
self.assertEqual([1, 2], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users').where_not_between('id', [1, 2])
+ builder.select("*").from_("users").where_not_between("id", [1, 2])
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" NOT BETWEEN ? AND ?',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" NOT BETWEEN ? AND ?', builder.to_sql()
)
self.assertEqual([1, 2], builder.get_bindings())
def test_basic_or_where(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1).or_where('email', '=', 'foo')
+ builder.select("*").from_("users").where("id", "=", 1).or_where(
+ "email", "=", "foo"
+ )
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" = ? OR "email" = ?',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" = ? OR "email" = ?', builder.to_sql()
)
- self.assertEqual([1, 'foo'], builder.get_bindings())
+ self.assertEqual([1, "foo"], builder.get_bindings())
def test_raw_wheres(self):
builder = self.get_builder()
- builder.select('*').from_('users').where_raw('id = ? or email = ?', [1, 'foo'])
+ builder.select("*").from_("users").where_raw("id = ? or email = ?", [1, "foo"])
self.assertEqual(
- 'SELECT * FROM "users" WHERE id = ? OR email = ?',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE id = ? OR email = ?', builder.to_sql()
)
- self.assertEqual([1, 'foo'], builder.get_bindings())
+ self.assertEqual([1, "foo"], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users').where_raw('id = ?', [1])
- self.assertEqual(
- 'SELECT * FROM "users" WHERE id = ?',
- builder.to_sql()
- )
+ builder.select("*").from_("users").where_raw("id = ?", [1])
+ self.assertEqual('SELECT * FROM "users" WHERE id = ?', builder.to_sql())
self.assertEqual([1], builder.get_bindings())
def test_raw_or_wheres(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1).or_where_raw('email = ?', ['foo'])
+ builder.select("*").from_("users").where("id", "=", 1).or_where_raw(
+ "email = ?", ["foo"]
+ )
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" = ? OR email = ?',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" = ? OR email = ?', builder.to_sql()
)
- self.assertEqual([1, 'foo'], builder.get_bindings())
+ self.assertEqual([1, "foo"], builder.get_bindings())
def test_basic_where_ins(self):
builder = self.get_builder()
- builder.select('*').from_('users').where_in('id', [1, 2, 3])
+ builder.select("*").from_("users").where_in("id", [1, 2, 3])
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" IN (?, ?, ?)',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" IN (?, ?, ?)', builder.to_sql()
)
self.assertEqual([1, 2, 3], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1).or_where_in('id', [1, 2, 3])
+ builder.select("*").from_("users").where("id", "=", 1).or_where_in(
+ "id", [1, 2, 3]
+ )
self.assertEqual(
'SELECT * FROM "users" WHERE "id" = ? OR "id" IN (?, ?, ?)',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1, 1, 2, 3], builder.get_bindings())
def test_basic_where_not_ins(self):
builder = self.get_builder()
- builder.select('*').from_('users').where_not_in('id', [1, 2, 3])
+ builder.select("*").from_("users").where_not_in("id", [1, 2, 3])
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" NOT IN (?, ?, ?)',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" NOT IN (?, ?, ?)', builder.to_sql()
)
self.assertEqual([1, 2, 3], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1).or_where_not_in('id', [1, 2, 3])
+ builder.select("*").from_("users").where("id", "=", 1).or_where_not_in(
+ "id", [1, 2, 3]
+ )
self.assertEqual(
'SELECT * FROM "users" WHERE "id" = ? OR "id" NOT IN (?, ?, ?)',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1, 1, 2, 3], builder.get_bindings())
def test_empty_where_ins(self):
builder = self.get_builder()
- builder.select('*').from_('users').where_in('id', [])
- self.assertEqual(
- 'SELECT * FROM "users" WHERE 0 = 1',
- builder.to_sql()
- )
+ builder.select("*").from_("users").where_in("id", [])
+ self.assertEqual('SELECT * FROM "users" WHERE 0 = 1', builder.to_sql())
self.assertEqual([], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1).or_where_in('id', [])
+ builder.select("*").from_("users").where("id", "=", 1).or_where_in("id", [])
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" = ? OR 0 = 1',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" = ? OR 0 = 1', builder.to_sql()
)
self.assertEqual([1], builder.get_bindings())
def test_empty_where_not_ins(self):
builder = self.get_builder()
- builder.select('*').from_('users').where_not_in('id', [])
- self.assertEqual(
- 'SELECT * FROM "users" WHERE 1 = 1',
- builder.to_sql()
- )
+ builder.select("*").from_("users").where_not_in("id", [])
+ self.assertEqual('SELECT * FROM "users" WHERE 1 = 1', builder.to_sql())
self.assertEqual([], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1).or_where_not_in('id', [])
+ builder.select("*").from_("users").where("id", "=", 1).or_where_not_in("id", [])
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" = ? OR 1 = 1',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" = ? OR 1 = 1', builder.to_sql()
)
self.assertEqual([1], builder.get_bindings())
def test_where_in_accepts_collections(self):
builder = self.get_builder()
- builder.select('*').from_('users').where_in('id', Collection([1, 2, 3]))
+ builder.select("*").from_("users").where_in("id", Collection([1, 2, 3]))
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" IN (?, ?, ?)',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" IN (?, ?, ?)', builder.to_sql()
)
self.assertEqual([1, 2, 3], builder.get_bindings())
def test_unions(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1)
- builder.union(self.get_builder().select('*').from_('users').where('id', '=', 2))
+ builder.select("*").from_("users").where("id", "=", 1)
+ builder.union(self.get_builder().select("*").from_("users").where("id", "=", 2))
self.assertEqual(
'SELECT * FROM "users" WHERE "id" = ? UNION SELECT * FROM "users" WHERE "id" = ?',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1, 2], builder.get_bindings())
builder = self.get_mysql_builder()
- builder.select('*').from_('users').where('id', '=', 1)
- builder.union(self.get_mysql_builder().select('*').from_('users').where('id', '=', 2))
+ builder.select("*").from_("users").where("id", "=", 1)
+ builder.union(
+ self.get_mysql_builder().select("*").from_("users").where("id", "=", 2)
+ )
self.assertEqual(
- '(SELECT * FROM `users` WHERE `id` = %s) UNION (SELECT * FROM `users` WHERE `id` = %s)',
- builder.to_sql()
+ "(SELECT * FROM `users` WHERE `id` = %s) UNION (SELECT * FROM `users` WHERE `id` = %s)",
+ builder.to_sql(),
)
self.assertEqual([1, 2], builder.get_bindings())
def test_union_alls(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1)
- builder.union_all(self.get_builder().select('*').from_('users').where('id', '=', 2))
+ builder.select("*").from_("users").where("id", "=", 1)
+ builder.union_all(
+ self.get_builder().select("*").from_("users").where("id", "=", 2)
+ )
self.assertEqual(
'SELECT * FROM "users" WHERE "id" = ? UNION ALL SELECT * FROM "users" WHERE "id" = ?',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1, 2], builder.get_bindings())
def test_multiple_unions(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1)
- builder.union(self.get_builder().select('*').from_('users').where('id', '=', 2))
- builder.union(self.get_builder().select('*').from_('users').where('id', '=', 3))
+ builder.select("*").from_("users").where("id", "=", 1)
+ builder.union(self.get_builder().select("*").from_("users").where("id", "=", 2))
+ builder.union(self.get_builder().select("*").from_("users").where("id", "=", 3))
self.assertEqual(
'SELECT * FROM "users" WHERE "id" = ? '
'UNION SELECT * FROM "users" WHERE "id" = ? '
'UNION SELECT * FROM "users" WHERE "id" = ?',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1, 2, 3], builder.get_bindings())
def test_multiple_union_alls(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1)
- builder.union_all(self.get_builder().select('*').from_('users').where('id', '=', 2))
- builder.union_all(self.get_builder().select('*').from_('users').where('id', '=', 3))
+ builder.select("*").from_("users").where("id", "=", 1)
+ builder.union_all(
+ self.get_builder().select("*").from_("users").where("id", "=", 2)
+ )
+ builder.union_all(
+ self.get_builder().select("*").from_("users").where("id", "=", 3)
+ )
self.assertEqual(
'SELECT * FROM "users" WHERE "id" = ? '
'UNION ALL SELECT * FROM "users" WHERE "id" = ? '
'UNION ALL SELECT * FROM "users" WHERE "id" = ?',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1, 2, 3], builder.get_bindings())
def test_union_order_bys(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1)
- builder.union(self.get_builder().select('*').from_('users').where('id', '=', 2))
- builder.order_by('id', 'desc')
+ builder.select("*").from_("users").where("id", "=", 1)
+ builder.union(self.get_builder().select("*").from_("users").where("id", "=", 2))
+ builder.order_by("id", "desc")
self.assertEqual(
'SELECT * FROM "users" WHERE "id" = ? '
'UNION SELECT * FROM "users" WHERE "id" = ? '
'ORDER BY "id" DESC',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1, 2], builder.get_bindings())
def test_union_limits_and_offsets(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1)
- builder.union(self.get_builder().select('*').from_('users').where('id', '=', 2))
+ builder.select("*").from_("users").where("id", "=", 1)
+ builder.union(self.get_builder().select("*").from_("users").where("id", "=", 2))
builder.skip(5).take(10)
self.assertEqual(
'SELECT * FROM "users" WHERE "id" = ? '
'UNION SELECT * FROM "users" WHERE "id" = ? '
- 'LIMIT 10 OFFSET 5',
- builder.to_sql()
+ "LIMIT 10 OFFSET 5",
+ builder.to_sql(),
)
self.assertEqual([1, 2], builder.get_bindings())
def test_mysql_union_order_bys(self):
builder = self.get_mysql_builder()
- builder.select('*').from_('users').where('id', '=', 1)
- builder.union(self.get_mysql_builder().select('*').from_('users').where('id', '=', 2))
- builder.order_by('id', 'desc')
+ builder.select("*").from_("users").where("id", "=", 1)
+ builder.union(
+ self.get_mysql_builder().select("*").from_("users").where("id", "=", 2)
+ )
+ builder.order_by("id", "desc")
self.assertEqual(
- '(SELECT * FROM `users` WHERE `id` = %s) '
- 'UNION (SELECT * FROM `users` WHERE `id` = %s) '
- 'ORDER BY `id` DESC',
- builder.to_sql()
+ "(SELECT * FROM `users` WHERE `id` = %s) "
+ "UNION (SELECT * FROM `users` WHERE `id` = %s) "
+ "ORDER BY `id` DESC",
+ builder.to_sql(),
)
self.assertEqual([1, 2], builder.get_bindings())
def test_mysql_union_limits_and_offsets(self):
builder = self.get_mysql_builder()
- builder.select('*').from_('users').where('id', '=', 1)
- builder.union(self.get_mysql_builder().select('*').from_('users').where('id', '=', 2))
+ builder.select("*").from_("users").where("id", "=", 1)
+ builder.union(
+ self.get_mysql_builder().select("*").from_("users").where("id", "=", 2)
+ )
builder.skip(5).take(10)
self.assertEqual(
- '(SELECT * FROM `users` WHERE `id` = %s) '
- 'UNION (SELECT * FROM `users` WHERE `id` = %s) '
- 'LIMIT 10 OFFSET 5',
- builder.to_sql()
+ "(SELECT * FROM `users` WHERE `id` = %s) "
+ "UNION (SELECT * FROM `users` WHERE `id` = %s) "
+ "LIMIT 10 OFFSET 5",
+ builder.to_sql(),
)
self.assertEqual([1, 2], builder.get_bindings())
def test_sub_select_where_in(self):
builder = self.get_builder()
- builder.select('*').from_('users').where_in(
- 'id',
- self.get_builder().select('id').from_('users').where('age', '>', 25).take(3)
+ builder.select("*").from_("users").where_in(
+ "id",
+ self.get_builder()
+ .select("id")
+ .from_("users")
+ .where("age", ">", 25)
+ .take(3),
)
self.assertEqual(
'SELECT * FROM "users" WHERE "id" IN (SELECT "id" FROM "users" WHERE "age" > ? LIMIT 3)',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([25], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users').where_not_in(
- 'id',
- self.get_builder().select('id').from_('users').where('age', '>', 25).take(3)
+ builder.select("*").from_("users").where_not_in(
+ "id",
+ self.get_builder()
+ .select("id")
+ .from_("users")
+ .where("age", ">", 25)
+ .take(3),
)
self.assertEqual(
'SELECT * FROM "users" WHERE "id" NOT IN (SELECT "id" FROM "users" WHERE "age" > ? LIMIT 3)',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([25], builder.get_bindings())
def test_basic_where_null(self):
builder = self.get_builder()
- builder.select('*').from_('users').where_null('id')
- self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" IS NULL',
- builder.to_sql()
- )
+ builder.select("*").from_("users").where_null("id")
+ self.assertEqual('SELECT * FROM "users" WHERE "id" IS NULL', builder.to_sql())
self.assertEqual([], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1).or_where_null('id')
+ builder.select("*").from_("users").where("id", "=", 1).or_where_null("id")
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" = ? OR "id" IS NULL',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" = ? OR "id" IS NULL', builder.to_sql()
)
self.assertEqual([1], builder.get_bindings())
def test_basic_where_not_null(self):
builder = self.get_builder()
- builder.select('*').from_('users').where_not_null('id')
+ builder.select("*").from_("users").where_not_null("id")
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" IS NOT NULL',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" IS NOT NULL', builder.to_sql()
)
self.assertEqual([], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users').where('id', '=', 1).or_where_not_null('id')
+ builder.select("*").from_("users").where("id", "=", 1).or_where_not_null("id")
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" = ? OR "id" IS NOT NULL',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" = ? OR "id" IS NOT NULL', builder.to_sql()
)
self.assertEqual([1], builder.get_bindings())
def test_group_bys(self):
builder = self.get_builder()
- builder.select('*').from_('users').group_by('id', 'email')
+ builder.select("*").from_("users").group_by("id", "email")
self.assertEqual(
- 'SELECT * FROM "users" GROUP BY "id", "email"',
- builder.to_sql()
+ 'SELECT * FROM "users" GROUP BY "id", "email"', builder.to_sql()
)
def test_order_bys(self):
builder = self.get_builder()
- builder.select('*').from_('users').order_by('email').order_by('age', 'desc')
+ builder.select("*").from_("users").order_by("email").order_by("age", "desc")
self.assertEqual(
- 'SELECT * FROM "users" ORDER BY "email" ASC, "age" DESC',
- builder.to_sql()
+ 'SELECT * FROM "users" ORDER BY "email" ASC, "age" DESC', builder.to_sql()
)
builder = self.get_builder()
- builder.select('*').from_('users').order_by('email').order_by_raw('"age" ? desc', ['foo'])
+ builder.select("*").from_("users").order_by("email").order_by_raw(
+ '"age" ? desc', ["foo"]
+ )
self.assertEqual(
- 'SELECT * FROM "users" ORDER BY "email" ASC, "age" ? DESC',
- builder.to_sql()
+ 'SELECT * FROM "users" ORDER BY "email" ASC, "age" ? DESC', builder.to_sql()
)
def test_havings(self):
builder = self.get_builder()
- builder.select('*').from_('users').having('email', '>', 1)
- self.assertEqual(
- 'SELECT * FROM "users" HAVING "email" > ?',
- builder.to_sql()
- )
+ builder.select("*").from_("users").having("email", ">", 1)
+ self.assertEqual('SELECT * FROM "users" HAVING "email" > ?', builder.to_sql())
self.assertEqual([1], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users')\
- .or_having('email', '=', 'foo@bar.com')\
- .or_having('email', '=', 'foo2@bar.com')
+ builder.select("*").from_("users").or_having(
+ "email", "=", "foo@bar.com"
+ ).or_having("email", "=", "foo2@bar.com")
self.assertEqual(
- 'SELECT * FROM "users" HAVING "email" = ? OR "email" = ?',
- builder.to_sql()
+ 'SELECT * FROM "users" HAVING "email" = ? OR "email" = ?', builder.to_sql()
)
- self.assertEqual(['foo@bar.com', 'foo2@bar.com'], builder.get_bindings())
+ self.assertEqual(["foo@bar.com", "foo2@bar.com"], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users')\
- .group_by('email')\
- .having('email', '>', 1)
+ builder.select("*").from_("users").group_by("email").having("email", ">", 1)
self.assertEqual(
'SELECT * FROM "users" GROUP BY "email" HAVING "email" > ?',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1], builder.get_bindings())
builder = self.get_builder()
- builder.select('email as foo_mail').from_('users')\
- .having('foo_mail', '>', 1)
+ builder.select("email as foo_mail").from_("users").having("foo_mail", ">", 1)
self.assertEqual(
'SELECT "email" AS "foo_mail" FROM "users" HAVING "foo_mail" > ?',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1], builder.get_bindings())
builder = self.get_builder()
- builder.select('category', QueryExpression('count(*) as "total"'))\
- .from_('item')\
- .where('department', '=', 'popular')\
- .group_by('category')\
- .having('total', '>', QueryExpression('3'))
+ builder.select("category", QueryExpression('count(*) as "total"')).from_(
+ "item"
+ ).where("department", "=", "popular").group_by("category").having(
+ "total", ">", QueryExpression("3")
+ )
self.assertEqual(
'SELECT "category", count(*) as "total" '
'FROM "item" '
'WHERE "department" = ? '
'GROUP BY "category" '
'HAVING "total" > 3',
- builder.to_sql()
+ builder.to_sql(),
)
- self.assertEqual(['popular'], builder.get_bindings())
+ self.assertEqual(["popular"], builder.get_bindings())
builder = self.get_builder()
- builder.select('category', QueryExpression('count(*) as "total"'))\
- .from_('item')\
- .where('department', '=', 'popular')\
- .group_by('category')\
- .having('total', '>', 3)
+ builder.select("category", QueryExpression('count(*) as "total"')).from_(
+ "item"
+ ).where("department", "=", "popular").group_by("category").having(
+ "total", ">", 3
+ )
self.assertEqual(
'SELECT "category", count(*) as "total" '
'FROM "item" '
'WHERE "department" = ? '
'GROUP BY "category" '
'HAVING "total" > ?',
- builder.to_sql()
+ builder.to_sql(),
)
- self.assertEqual(['popular', 3], builder.get_bindings())
+ self.assertEqual(["popular", 3], builder.get_bindings())
def test_having_followed_by_select_get(self):
builder = self.get_builder()
- query = 'SELECT "category", count(*) as "total" ' \
- 'FROM "item" ' \
- 'WHERE "department" = ? ' \
- 'GROUP BY "category" ' \
- 'HAVING "total" > ?'
- results = [{
- 'category': 'rock',
- 'total': 5
- }]
+ query = (
+ 'SELECT "category", count(*) as "total" '
+ 'FROM "item" '
+ 'WHERE "department" = ? '
+ 'GROUP BY "category" '
+ 'HAVING "total" > ?'
+ )
+ results = [{"category": "rock", "total": 5}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results: results)
- result = builder.select('category', QueryExpression('count(*) as "total"'))\
- .from_('item')\
- .where('department', '=', 'popular')\
- .group_by('category')\
- .having('total', '>', 3)\
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results: results
+ )
+ result = (
+ builder.select("category", QueryExpression('count(*) as "total"'))
+ .from_("item")
+ .where("department", "=", "popular")
+ .group_by("category")
+ .having("total", ">", 3)
.get()
+ )
builder.get_connection().select.assert_called_once_with(
- query,
- ['popular', 3],
- True
- )
- builder.get_processor().process_select.assert_called_once_with(
- builder,
- results
+ query, ["popular", 3], True
)
+ builder.get_processor().process_select.assert_called_once_with(builder, results)
self.assertEqual(results, result)
- self.assertEqual(['popular', 3], builder.get_bindings())
+ self.assertEqual(["popular", 3], builder.get_bindings())
# Using raw value
builder = self.get_builder()
- query = 'SELECT "category", count(*) as "total" ' \
- 'FROM "item" ' \
- 'WHERE "department" = ? ' \
- 'GROUP BY "category" ' \
- 'HAVING "total" > 3'
+ query = (
+ 'SELECT "category", count(*) as "total" '
+ 'FROM "item" '
+ 'WHERE "department" = ? '
+ 'GROUP BY "category" '
+ 'HAVING "total" > 3'
+ )
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results: results)
- result = builder.select('category', QueryExpression('count(*) as "total"'))\
- .from_('item')\
- .where('department', '=', 'popular')\
- .group_by('category')\
- .having('total', '>', QueryExpression('3'))\
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results: results
+ )
+ result = (
+ builder.select("category", QueryExpression('count(*) as "total"'))
+ .from_("item")
+ .where("department", "=", "popular")
+ .group_by("category")
+ .having("total", ">", QueryExpression("3"))
.get()
+ )
builder.get_connection().select.assert_called_once_with(
- query,
- ['popular'],
- True
- )
- builder.get_processor().process_select.assert_called_once_with(
- builder,
- results
+ query, ["popular"], True
)
+ builder.get_processor().process_select.assert_called_once_with(builder, results)
self.assertEqual(results, result)
- self.assertEqual(['popular'], builder.get_bindings())
+ self.assertEqual(["popular"], builder.get_bindings())
def test_raw_havings(self):
builder = self.get_builder()
- builder.select('*').from_('users').having_raw('user_foo < user_bar')
+ builder.select("*").from_("users").having_raw("user_foo < user_bar")
self.assertEqual(
- 'SELECT * FROM "users" HAVING user_foo < user_bar',
- builder.to_sql()
+ 'SELECT * FROM "users" HAVING user_foo < user_bar', builder.to_sql()
)
self.assertEqual([], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users').having('foo', '=', 1).or_having_raw('user_foo < user_bar')
+ builder.select("*").from_("users").having("foo", "=", 1).or_having_raw(
+ "user_foo < user_bar"
+ )
self.assertEqual(
'SELECT * FROM "users" HAVING "foo" = ? OR user_foo < user_bar',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1], builder.get_bindings())
def test_limits_and_offsets(self):
builder = self.get_builder()
- builder.select('*').from_('users').offset(5).limit(10)
+ builder.select("*").from_("users").offset(5).limit(10)
self.assertEqual('SELECT * FROM "users" LIMIT 10 OFFSET 5', builder.to_sql())
builder = self.get_builder()
- builder.select('*').from_('users').skip(5).take(10)
+ builder.select("*").from_("users").skip(5).take(10)
self.assertEqual('SELECT * FROM "users" LIMIT 10 OFFSET 5', builder.to_sql())
builder = self.get_builder()
- builder.select('*').from_('users').skip(-5).take(10)
+ builder.select("*").from_("users").skip(-5).take(10)
self.assertEqual('SELECT * FROM "users" LIMIT 10 OFFSET 0', builder.to_sql())
builder = self.get_builder()
- builder.select('*').from_('users').for_page(2, 15)
+ builder.select("*").from_("users").for_page(2, 15)
self.assertEqual('SELECT * FROM "users" LIMIT 15 OFFSET 15', builder.to_sql())
builder = self.get_builder()
- builder.select('*').from_('users').for_page(-2, 15)
+ builder.select("*").from_("users").for_page(-2, 15)
self.assertEqual('SELECT * FROM "users" LIMIT 15 OFFSET 0', builder.to_sql())
def test_where_shortcut(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('id', 1).or_where('name', 'foo')
+ builder.select("*").from_("users").where("id", 1).or_where("name", "foo")
self.assertEqual(
- 'SELECT * FROM "users" WHERE "id" = ? OR "name" = ?',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE "id" = ? OR "name" = ?', builder.to_sql()
)
- self.assertEqual([1, 'foo'], builder.get_bindings())
+ self.assertEqual([1, "foo"], builder.get_bindings())
def test_multiple_wheres_in_list(self):
builder = self.get_builder()
- builder.select('*').from_('users').where([['name', '=', 'bar'], ['age', '=', 25]])
+ builder.select("*").from_("users").where(
+ [["name", "=", "bar"], ["age", "=", 25]]
+ )
self.assertEqual(
- 'SELECT * FROM "users" WHERE ("name" = ? AND "age" = ?)',
- builder.to_sql()
+ 'SELECT * FROM "users" WHERE ("name" = ? AND "age" = ?)', builder.to_sql()
)
- self.assertEqual(['bar', 25], builder.get_bindings())
+ self.assertEqual(["bar", 25], builder.get_bindings())
def test_multiple_wheres_in_list_with_exception(self):
builder = self.get_builder()
try:
- builder.select('*').from_('users').where([['name', 'bar'], ['age', '=', 25]])
- self.fail('Builder has not raised Argument Error for invalid no. of values in where list')
+ builder.select("*").from_("users").where(
+ [["name", "bar"], ["age", "=", 25]]
+ )
+ self.fail(
+ "Builder has not raised Argument Error for invalid no. of values in where list"
+ )
except ArgumentError:
self.assertTrue(True)
try:
- builder.select('*').from_('users').where(['name', 'bar'])
- self.fail('Builder has not raised Argument Error for invalid datatype in where list')
+ builder.select("*").from_("users").where(["name", "bar"])
+ self.fail(
+ "Builder has not raised Argument Error for invalid datatype in where list"
+ )
except ArgumentError:
self.assertTrue(True)
-
def test_nested_wheres(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('email', '=', 'foo').or_where(
- builder.new_query().where('name', '=', 'bar').where('age', '=', 25)
+ builder.select("*").from_("users").where("email", "=", "foo").or_where(
+ builder.new_query().where("name", "=", "bar").where("age", "=", 25)
)
self.assertEqual(
'SELECT * FROM "users" WHERE "email" = ? OR ("name" = ? AND "age" = ?)',
- builder.to_sql()
+ builder.to_sql(),
)
- self.assertEqual(['foo', 'bar', 25], builder.get_bindings())
+ self.assertEqual(["foo", "bar", 25], builder.get_bindings())
def test_full_sub_selects(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('email', '=', 'foo').or_where(
- 'id', '=', builder.new_query().select(QueryExpression('max(id)')).from_('users').where('email', '=', 'bar')
+ builder.select("*").from_("users").where("email", "=", "foo").or_where(
+ "id",
+ "=",
+ builder.new_query()
+ .select(QueryExpression("max(id)"))
+ .from_("users")
+ .where("email", "=", "bar"),
)
self.assertEqual(
'SELECT * FROM "users" WHERE "email" = ? OR "id" = (SELECT max(id) FROM "users" WHERE "email" = ?)',
- builder.to_sql()
+ builder.to_sql(),
)
- self.assertEqual(['foo', 'bar'], builder.get_bindings())
+ self.assertEqual(["foo", "bar"], builder.get_bindings())
def test_where_exists(self):
builder = self.get_builder()
- builder.select('*').from_('orders').where_exists(
- self.get_builder().select('*').from_('products').where('products.id', '=', QueryExpression('"orders"."id"'))
+ builder.select("*").from_("orders").where_exists(
+ self.get_builder()
+ .select("*")
+ .from_("products")
+ .where("products.id", "=", QueryExpression('"orders"."id"'))
)
self.assertEqual(
'SELECT * FROM "orders" '
'WHERE EXISTS (SELECT * FROM "products" WHERE "products"."id" = "orders"."id")',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('orders').where_not_exists(
- self.get_builder().select('*').from_('products').where('products.id', '=', QueryExpression('"orders"."id"'))
+ builder.select("*").from_("orders").where_not_exists(
+ self.get_builder()
+ .select("*")
+ .from_("products")
+ .where("products.id", "=", QueryExpression('"orders"."id"'))
)
self.assertEqual(
'SELECT * FROM "orders" '
'WHERE NOT EXISTS (SELECT * FROM "products" WHERE "products"."id" = "orders"."id")',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('orders').where('id', '=', 1).or_where_exists(
- self.get_builder().select('*').from_('products').where('products.id', '=', QueryExpression('"orders"."id"'))
+ builder.select("*").from_("orders").where("id", "=", 1).or_where_exists(
+ self.get_builder()
+ .select("*")
+ .from_("products")
+ .where("products.id", "=", QueryExpression('"orders"."id"'))
)
self.assertEqual(
'SELECT * FROM "orders" WHERE "id" = ? '
'OR EXISTS (SELECT * FROM "products" WHERE "products"."id" = "orders"."id")',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('orders').where('id', '=', 1).or_where_not_exists(
- self.get_builder().select('*').from_('products').where('products.id', '=', QueryExpression('"orders"."id"'))
+ builder.select("*").from_("orders").where("id", "=", 1).or_where_not_exists(
+ self.get_builder()
+ .select("*")
+ .from_("products")
+ .where("products.id", "=", QueryExpression('"orders"."id"'))
)
self.assertEqual(
'SELECT * FROM "orders" WHERE "id" = ? '
'OR NOT EXISTS (SELECT * FROM "products" WHERE "products"."id" = "orders"."id")',
- builder.to_sql()
+ builder.to_sql(),
)
self.assertEqual([1], builder.get_bindings())
def test_basic_joins(self):
builder = self.get_builder()
- builder.select('*').from_('users')\
- .join('contacts', 'users.id', '=', 'contacts.id')\
- .left_join('photos', 'users.id', '=', 'photos.user_id')
+ builder.select("*").from_("users").join(
+ "contacts", "users.id", "=", "contacts.id"
+ ).left_join("photos", "users.id", "=", "photos.user_id")
self.assertEqual(
'SELECT * FROM "users" '
'INNER JOIN "contacts" ON "users"."id" = "contacts"."id" '
'LEFT JOIN "photos" ON "users"."id" = "photos"."user_id"',
- builder.to_sql()
+ builder.to_sql(),
)
builder = self.get_builder()
- builder.select('*').from_('users')\
- .left_join_where('photos', 'users.id', '=', 3)\
- .join_where('photos', 'users.id', '=', 'foo')
+ builder.select("*").from_("users").left_join_where(
+ "photos", "users.id", "=", 3
+ ).join_where("photos", "users.id", "=", "foo")
self.assertEqual(
'SELECT * FROM "users" '
'LEFT JOIN "photos" ON "users"."id" = ? '
'INNER JOIN "photos" ON "users"."id" = ?',
- builder.to_sql()
+ builder.to_sql(),
)
- self.assertEqual([3, 'foo'], builder.get_bindings())
+ self.assertEqual([3, "foo"], builder.get_bindings())
def test_complex_joins(self):
builder = self.get_builder()
- builder.select('*').from_('users').join(
- JoinClause('contacts')
- .on('users.id', '=', 'contacts.id')
- .or_on('users.name', '=', 'contacts.name')
+ builder.select("*").from_("users").join(
+ JoinClause("contacts")
+ .on("users.id", "=", "contacts.id")
+ .or_on("users.name", "=", "contacts.name")
)
self.assertEqual(
'SELECT * FROM "users" '
'INNER JOIN "contacts" ON "users"."id" = "contacts"."id" '
'OR "users"."name" = "contacts"."name"',
- builder.to_sql()
+ builder.to_sql(),
)
builder = self.get_builder()
- builder.select('*').from_('users').join(
- JoinClause('contacts')
- .where('users.id', '=', 'foo')
- .or_where('users.name', '=', 'bar')
+ builder.select("*").from_("users").join(
+ JoinClause("contacts")
+ .where("users.id", "=", "foo")
+ .or_where("users.name", "=", "bar")
)
self.assertEqual(
'SELECT * FROM "users" '
'INNER JOIN "contacts" ON "users"."id" = ? '
'OR "users"."name" = ?',
- builder.to_sql()
+ builder.to_sql(),
)
- self.assertEqual(['foo', 'bar'], builder.get_bindings())
+ self.assertEqual(["foo", "bar"], builder.get_bindings())
self.assertEqual(
'SELECT * FROM "users" '
'INNER JOIN "contacts" ON "users"."id" = ? '
'OR "users"."name" = ?',
- builder.to_sql()
+ builder.to_sql(),
)
- self.assertEqual(['foo', 'bar'], builder.get_bindings())
+ self.assertEqual(["foo", "bar"], builder.get_bindings())
builder = self.get_builder()
- builder.select('*').from_('users').left_join(
- JoinClause('contacts')
- .where('users.id', '=', 'foo')
- .or_where('users.name', '=', 'bar')
+ builder.select("*").from_("users").left_join(
+ JoinClause("contacts")
+ .where("users.id", "=", "foo")
+ .or_where("users.name", "=", "bar")
)
self.assertEqual(
'SELECT * FROM "users" '
'LEFT JOIN "contacts" ON "users"."id" = ? '
'OR "users"."name" = ?',
- builder.to_sql()
+ builder.to_sql(),
)
- self.assertEqual(['foo', 'bar'], builder.get_bindings())
+ self.assertEqual(["foo", "bar"], builder.get_bindings())
self.assertEqual(
'SELECT * FROM "users" '
'LEFT JOIN "contacts" ON "users"."id" = ? '
'OR "users"."name" = ?',
- builder.to_sql()
+ builder.to_sql(),
)
- self.assertEqual(['foo', 'bar'], builder.get_bindings())
+ self.assertEqual(["foo", "bar"], builder.get_bindings())
def test_join_where_null(self):
builder = self.get_builder()
- builder.select('*').from_('users').join(
- JoinClause('contacts')
- .on('users.id', '=', 'contacts.id')
- .where_null('contacts.deleted_at')
+ builder.select("*").from_("users").join(
+ JoinClause("contacts")
+ .on("users.id", "=", "contacts.id")
+ .where_null("contacts.deleted_at")
)
self.assertEqual(
'SELECT * FROM "users" '
'INNER JOIN "contacts" '
'ON "users"."id" = "contacts"."id" '
'AND "contacts"."deleted_at" IS NULL',
- builder.to_sql()
+ builder.to_sql(),
)
builder = self.get_builder()
- builder.select('*').from_('users').join(
- JoinClause('contacts')
- .on('users.id', '=', 'contacts.id')
- .or_where_null('contacts.deleted_at')
+ builder.select("*").from_("users").join(
+ JoinClause("contacts")
+ .on("users.id", "=", "contacts.id")
+ .or_where_null("contacts.deleted_at")
)
self.assertEqual(
'SELECT * FROM "users" '
'INNER JOIN "contacts" '
'ON "users"."id" = "contacts"."id" '
'OR "contacts"."deleted_at" IS NULL',
- builder.to_sql()
+ builder.to_sql(),
)
def test_join_where_not_null(self):
builder = self.get_builder()
- builder.select('*').from_('users').join(
- JoinClause('contacts')
- .on('users.id', '=', 'contacts.id')
- .where_not_null('contacts.deleted_at')
+ builder.select("*").from_("users").join(
+ JoinClause("contacts")
+ .on("users.id", "=", "contacts.id")
+ .where_not_null("contacts.deleted_at")
)
self.assertEqual(
'SELECT * FROM "users" '
'INNER JOIN "contacts" '
'ON "users"."id" = "contacts"."id" '
'AND "contacts"."deleted_at" IS NOT NULL',
- builder.to_sql()
+ builder.to_sql(),
)
builder = self.get_builder()
- builder.select('*').from_('users').join(
- JoinClause('contacts')
- .on('users.id', '=', 'contacts.id')
- .or_where_not_null('contacts.deleted_at')
+ builder.select("*").from_("users").join(
+ JoinClause("contacts")
+ .on("users.id", "=", "contacts.id")
+ .or_where_not_null("contacts.deleted_at")
)
self.assertEqual(
'SELECT * FROM "users" '
'INNER JOIN "contacts" '
'ON "users"."id" = "contacts"."id" '
'OR "contacts"."deleted_at" IS NOT NULL',
- builder.to_sql()
+ builder.to_sql(),
)
def test_raw_expression_in_select(self):
builder = self.get_builder()
- builder.select(QueryExpression('substr(foo, 6)')).from_('users')
+ builder.select(QueryExpression("substr(foo, 6)")).from_("users")
self.assertEqual('SELECT substr(foo, 6) FROM "users"', builder.to_sql())
def test_find_return_first_result_by_id(self):
builder = self.get_builder()
query = 'SELECT * FROM "users" WHERE "id" = ? LIMIT 1'
- results = [{
- 'foo': 'bar'
- }]
+ results = [{"foo": "bar"}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results)
- result = builder.from_('users').find(1)
- builder.get_connection().select.assert_called_once_with(
- query, [1], True
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results
)
+ result = builder.from_("users").find(1)
+ builder.get_connection().select.assert_called_once_with(query, [1], True)
builder.get_processor().process_select.assert_called_once_with(builder, results)
self.assertEqual(result, results[0])
def test_first_return_first_result(self):
builder = self.get_builder()
query = 'SELECT * FROM "users" WHERE "id" = ? LIMIT 1'
- results = [{
- 'foo': 'bar'
- }]
+ results = [{"foo": "bar"}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results)
- result = builder.from_('users').where('id', '=', 1).first()
- builder.get_connection().select.assert_called_once_with(
- query, [1], True
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results
)
+ result = builder.from_("users").where("id", "=", 1).first()
+ builder.get_connection().select.assert_called_once_with(query, [1], True)
builder.get_processor().process_select.assert_called_once_with(builder, results)
self.assertEqual(result, results[0])
    def test_list_methods_gets_list_of_column_values(self):
builder = self.get_builder()
- results = [
- {'foo': 'bar'}, {'foo': 'baz'}
- ]
+ results = [{"foo": "bar"}, {"foo": "baz"}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results)
- result = builder.from_('users').where('id', '=', 1).lists('foo')
- self.assertEqual(['bar', 'baz'], result)
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results
+ )
+ result = builder.from_("users").where("id", "=", 1).lists("foo")
+ self.assertEqual(["bar", "baz"], result)
builder = self.get_builder()
- results = [
- {'id': 1, 'foo': 'bar'}, {'id': 10, 'foo': 'baz'}
- ]
+ results = [{"id": 1, "foo": "bar"}, {"id": 10, "foo": "baz"}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results)
- result = builder.from_('users').where('id', '=', 1).lists('foo', 'id')
- self.assertEqual({1: 'bar', 10: 'baz'}, result)
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results
+ )
+ result = builder.from_("users").where("id", "=", 1).lists("foo", "id")
+ self.assertEqual({1: "bar", 10: "baz"}, result)
def test_implode(self):
builder = self.get_builder()
- results = [
- {'foo': 'bar'}, {'foo': 'baz'}
- ]
+ results = [{"foo": "bar"}, {"foo": "baz"}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results)
- result = builder.from_('users').where('id', '=', 1).implode('foo')
- self.assertEqual('barbaz', result)
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results
+ )
+ result = builder.from_("users").where("id", "=", 1).implode("foo")
+ self.assertEqual("barbaz", result)
builder = self.get_builder()
- results = [
- {'foo': 'bar'}, {'foo': 'baz'}
- ]
+ results = [{"foo": "bar"}, {"foo": "baz"}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results)
- result = builder.from_('users').where('id', '=', 1).implode('foo', ',')
- self.assertEqual('bar,baz', result)
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results
+ )
+ result = builder.from_("users").where("id", "=", 1).implode("foo", ",")
+ self.assertEqual("bar,baz", result)
def test_pluck_return_single_column(self):
builder = self.get_builder()
query = 'SELECT "foo" FROM "users" WHERE "id" = ? LIMIT 1'
- results = [{'foo': 'bar'}]
+ results = [{"foo": "bar"}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results)
- result = builder.from_('users').where('id', '=', 1).pluck('foo')
- builder.get_connection().select.assert_called_once_with(
- query, [1], True
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results
)
+ result = builder.from_("users").where("id", "=", 1).pluck("foo")
+ builder.get_connection().select.assert_called_once_with(query, [1], True)
builder.get_processor().process_select.assert_called_once_with(builder, results)
- self.assertEqual('bar', result)
+ self.assertEqual("bar", result)
    def test_aggregate_functions(self):
builder = self.get_builder()
query = 'SELECT COUNT(*) AS aggregate FROM "users"'
- results = [{'aggregate': 1}]
+ results = [{"aggregate": 1}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results)
- result = builder.from_('users').count()
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results
)
+ result = builder.from_("users").count()
+ builder.get_connection().select.assert_called_once_with(query, [], True)
builder.get_processor().process_select.assert_called_once_with(builder, results)
self.assertEqual(1, result)
builder = self.get_builder()
query = 'SELECT COUNT(*) AS aggregate FROM "users" LIMIT 1'
- results = [{'aggregate': 1}]
+ results = [{"aggregate": 1}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_)
- result = builder.from_('users').exists()
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results_
)
+ result = builder.from_("users").exists()
+ builder.get_connection().select.assert_called_once_with(query, [], True)
builder.get_processor().process_select.assert_called_once_with(builder, results)
self.assertTrue(result)
builder = self.get_builder()
query = 'SELECT MAX("id") AS aggregate FROM "users"'
- results = [{'aggregate': 1}]
+ results = [{"aggregate": 1}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_)
- result = builder.from_('users').max('id')
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results_
)
+ result = builder.from_("users").max("id")
+ builder.get_connection().select.assert_called_once_with(query, [], True)
builder.get_processor().process_select.assert_called_once_with(builder, results)
self.assertEqual(1, result)
builder = self.get_builder()
query = 'SELECT MIN("id") AS aggregate FROM "users"'
- results = [{'aggregate': 1}]
+ results = [{"aggregate": 1}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_)
- result = builder.from_('users').min('id')
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results_
)
+ result = builder.from_("users").min("id")
+ builder.get_connection().select.assert_called_once_with(query, [], True)
builder.get_processor().process_select.assert_called_once_with(builder, results)
self.assertEqual(1, result)
builder = self.get_builder()
query = 'SELECT SUM("id") AS aggregate FROM "users"'
- results = [{'aggregate': 1}]
+ results = [{"aggregate": 1}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_)
- result = builder.from_('users').sum('id')
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results_
)
+ result = builder.from_("users").sum("id")
+ builder.get_connection().select.assert_called_once_with(query, [], True)
builder.get_processor().process_select.assert_called_once_with(builder, results)
self.assertEqual(1, result)
builder = self.get_builder()
query = 'SELECT AVG("id") AS aggregate FROM "users"'
- results = [{'aggregate': 1}]
+ results = [{"aggregate": 1}]
builder.get_connection().select.return_value = results
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_)
- result = builder.from_('users').avg('id')
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results_
+ )
+ result = builder.from_("users").avg("id")
+ builder.get_connection().select.assert_called_once_with(query, [], True)
+ builder.get_processor().process_select.assert_called_once_with(builder, results)
+ self.assertEqual(1, result)
+
+ def test_distinct_count_with_column(self):
+ builder = self.get_builder()
+ query = 'SELECT COUNT(DISTINCT "id") AS aggregate FROM "users"'
+ results = [{"aggregate": 1}]
+ builder.get_connection().select.return_value = results
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results
)
+ result = builder.from_("users").distinct().count("id")
+ builder.get_connection().select.assert_called_once_with(query, [], True)
+ builder.get_processor().process_select.assert_called_once_with(builder, results)
+ self.assertEqual(1, result)
+
+ def test_distinct_count_with_select(self):
+ builder = self.get_builder()
+ query = 'SELECT COUNT(DISTINCT "id") AS aggregate FROM "users"'
+ results = [{"aggregate": 1}]
+ builder.get_connection().select.return_value = results
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results
+ )
+ result = builder.from_("users").distinct().select("id").count()
+ builder.get_connection().select.assert_called_once_with(query, [], True)
builder.get_processor().process_select.assert_called_once_with(builder, results)
self.assertEqual(1, result)
def test_aggregate_reset_followed_by_get(self):
builder = self.get_builder()
query = 'SELECT COUNT(*) AS aggregate FROM "users"'
- builder.get_connection().select.return_value = [{'aggregate': 1}]
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_)
- builder.from_('users').select('column1', 'column2')
+ builder.get_connection().select.return_value = [{"aggregate": 1}]
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results_
+ )
+ builder.from_("users").select("column1", "column2")
count = builder.count()
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_connection().select.assert_called_once_with(query, [], True)
+ builder.get_processor().process_select.assert_called_once_with(
+ builder, [{"aggregate": 1}]
)
- builder.get_processor().process_select.assert_called_once_with(builder, [{'aggregate': 1}])
self.assertEqual(1, count)
builder.get_connection().select.reset_mock()
builder.get_processor().process_select.reset_mock()
query = 'SELECT SUM("id") AS aggregate FROM "users"'
- builder.get_connection().select.return_value = [{'aggregate': 2}]
- sum_ = builder.sum('id')
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_connection().select.return_value = [{"aggregate": 2}]
+ sum_ = builder.sum("id")
+ builder.get_connection().select.assert_called_once_with(query, [], True)
+ builder.get_processor().process_select.assert_called_once_with(
+ builder, [{"aggregate": 2}]
)
- builder.get_processor().process_select.assert_called_once_with(builder, [{'aggregate': 2}])
self.assertEqual(2, sum_)
builder.get_connection().select.reset_mock()
builder.get_processor().process_select.reset_mock()
query = 'SELECT "column1", "column2" FROM "users"'
- builder.get_connection().select.return_value = [{'column1': 'foo', 'column2': 'bar'}]
+ builder.get_connection().select.return_value = [
+ {"column1": "foo", "column2": "bar"}
+ ]
result = builder.get()
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_connection().select.assert_called_once_with(query, [], True)
+ builder.get_processor().process_select.assert_called_once_with(
+ builder, [{"column1": "foo", "column2": "bar"}]
)
- builder.get_processor().process_select.assert_called_once_with(builder, [{'column1': 'foo', 'column2': 'bar'}])
- self.assertEqual([{'column1': 'foo', 'column2': 'bar'}], result)
+ self.assertEqual([{"column1": "foo", "column2": "bar"}], result)
def test_aggregate_reset_followed_by_select_get(self):
builder = self.get_builder()
query = 'SELECT COUNT("column1") AS aggregate FROM "users"'
- builder.get_connection().select.return_value = [{'aggregate': 1}]
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_)
- builder.from_('users')
- count = builder.count('column1')
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_connection().select.return_value = [{"aggregate": 1}]
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results_
+ )
+ builder.from_("users")
+ count = builder.count("column1")
+ builder.get_connection().select.assert_called_once_with(query, [], True)
+ builder.get_processor().process_select.assert_called_once_with(
+ builder, [{"aggregate": 1}]
)
- builder.get_processor().process_select.assert_called_once_with(builder, [{'aggregate': 1}])
self.assertEqual(1, count)
builder.get_connection().select.reset_mock()
builder.get_processor().process_select.reset_mock()
query = 'SELECT "column2", "column3" FROM "users"'
- builder.get_connection().select.return_value = [{'column2': 'foo', 'column3': 'bar'}]
- result = builder.select('column2', 'column3').get()
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_connection().select.return_value = [
+ {"column2": "foo", "column3": "bar"}
+ ]
+ result = builder.select("column2", "column3").get()
+ builder.get_connection().select.assert_called_once_with(query, [], True)
+ builder.get_processor().process_select.assert_called_once_with(
+ builder, [{"column2": "foo", "column3": "bar"}]
)
- builder.get_processor().process_select.assert_called_once_with(builder, [{'column2': 'foo', 'column3': 'bar'}])
- self.assertEqual([{'column2': 'foo', 'column3': 'bar'}], result)
+ self.assertEqual([{"column2": "foo", "column3": "bar"}], result)
def test_aggregate_reset_followed_by_get_with_columns(self):
builder = self.get_builder()
query = 'SELECT COUNT("column1") AS aggregate FROM "users"'
- builder.get_connection().select.return_value = [{'aggregate': 1}]
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_)
- builder.from_('users')
- count = builder.count('column1')
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_connection().select.return_value = [{"aggregate": 1}]
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results_
+ )
+ builder.from_("users")
+ count = builder.count("column1")
+ builder.get_connection().select.assert_called_once_with(query, [], True)
+ builder.get_processor().process_select.assert_called_once_with(
+ builder, [{"aggregate": 1}]
)
- builder.get_processor().process_select.assert_called_once_with(builder, [{'aggregate': 1}])
self.assertEqual(1, count)
builder.get_connection().select.reset_mock()
builder.get_processor().process_select.reset_mock()
query = 'SELECT "column2", "column3" FROM "users"'
- builder.get_connection().select.return_value = [{'column2': 'foo', 'column3': 'bar'}]
- result = builder.get(['column2', 'column3'])
- builder.get_connection().select.assert_called_once_with(
- query, [], True
+ builder.get_connection().select.return_value = [
+ {"column2": "foo", "column3": "bar"}
+ ]
+ result = builder.get(["column2", "column3"])
+ builder.get_connection().select.assert_called_once_with(query, [], True)
+ builder.get_processor().process_select.assert_called_once_with(
+ builder, [{"column2": "foo", "column3": "bar"}]
)
- builder.get_processor().process_select.assert_called_once_with(builder, [{'column2': 'foo', 'column3': 'bar'}])
- self.assertEqual([{'column2': 'foo', 'column3': 'bar'}], result)
+ self.assertEqual([{"column2": "foo", "column3": "bar"}], result)
def test_insert_method(self):
builder = self.get_builder()
query = 'INSERT INTO "users" ("email") VALUES (?)'
builder.get_connection().insert.return_value = True
- result = builder.from_('users').insert({'email': 'foo'})
- builder.get_connection().insert.assert_called_once_with(
- query, ['foo']
- )
+ result = builder.from_("users").insert({"email": "foo"})
+ builder.get_connection().insert.assert_called_once_with(query, ["foo"])
self.assertTrue(result)
def test_insert_method_with_keyword_arguments(self):
builder = self.get_builder()
query = 'INSERT INTO "users" ("email") VALUES (?)'
builder.get_connection().insert.return_value = True
- result = builder.from_('users').insert({'email': 'foo'})
- builder.get_connection().insert.assert_called_once_with(
- query, ['foo']
- )
+ result = builder.from_("users").insert({"email": "foo"})
+ builder.get_connection().insert.assert_called_once_with(query, ["foo"])
self.assertTrue(result)
def test_sqlite_multiple_insert(self):
builder = self.get_sqlite_builder()
- query = 'INSERT INTO "users" ("email", "name") ' \
- 'SELECT ? AS "email", ? AS "name" UNION ALL SELECT ? AS "email", ? AS "name"'
+ query = (
+ 'INSERT INTO "users" ("email", "name") '
+ 'SELECT ? AS "email", ? AS "name" UNION ALL SELECT ? AS "email", ? AS "name"'
+ )
builder.get_connection().insert.return_value = True
- result = builder.from_('users').insert([
- {'email': 'foo', 'name': 'john'},
- {'email': 'bar', 'name': 'jane'}
- ])
+ result = builder.from_("users").insert(
+ [{"email": "foo", "name": "john"}, {"email": "bar", "name": "jane"}]
+ )
builder.get_connection().insert.assert_called_once_with(
- query, ['foo', 'john', 'bar', 'jane']
+ query, ["foo", "john", "bar", "jane"]
)
self.assertTrue(result)
def test_insert_get_id_method(self):
builder = self.get_builder()
builder.get_processor().process_insert_get_id.return_value = 1
- result = builder.from_('users').insert_get_id({
- 'email': 'foo',
- 'bar': QueryExpression('bar')
- })
+ result = builder.from_("users").insert_get_id(
+ {"email": "foo", "bar": QueryExpression("bar")}
+ )
builder.get_processor().process_insert_get_id.assert_called_once_with(
- builder, 'INSERT INTO "users" ("bar", "email") VALUES (bar, ?)', ['foo'], None
+ builder,
+ 'INSERT INTO "users" ("bar", "email") VALUES (bar, ?)',
+ ["foo"],
+ None,
)
self.assertEqual(1, result)
def test_insert_get_id_with_sequence(self):
builder = self.get_builder()
builder.get_processor().process_insert_get_id.return_value = 1
- result = builder.from_('users').insert_get_id({
- 'email': 'foo',
- 'bar': QueryExpression('bar')
- }, 'id')
+ result = builder.from_("users").insert_get_id(
+ {"email": "foo", "bar": QueryExpression("bar")}, "id"
+ )
builder.get_processor().process_insert_get_id.assert_called_once_with(
- builder, 'INSERT INTO "users" ("bar", "email") VALUES (bar, ?)', ['foo'], 'id'
+ builder,
+ 'INSERT INTO "users" ("bar", "email") VALUES (bar, ?)',
+ ["foo"],
+ "id",
)
self.assertEqual(1, result)
def test_insert_get_id_respects_raw_bindings(self):
builder = self.get_builder()
builder.get_processor().process_insert_get_id.return_value = 1
- result = builder.from_('users').insert_get_id({
- 'email': QueryExpression('CURRENT_TIMESTAMP'),
- })
+ result = builder.from_("users").insert_get_id(
+ {"email": QueryExpression("CURRENT_TIMESTAMP")}
+ )
builder.get_processor().process_insert_get_id.assert_called_once_with(
- builder, 'INSERT INTO "users" ("email") VALUES (CURRENT_TIMESTAMP)', [], None
+ builder,
+ 'INSERT INTO "users" ("email") VALUES (CURRENT_TIMESTAMP)',
+ [],
+ None,
)
self.assertEqual(1, result)
@@ -1230,225 +1258,258 @@ def test_update(self):
builder = self.get_builder()
query = 'UPDATE "users" SET "email" = ?, "name" = ? WHERE "id" = ?'
builder.get_connection().update.return_value = 1
- result = builder.from_('users').where('id', '=', 1).update(email='foo', name='bar')
- builder.get_connection().update.assert_called_with(
- query, ['foo', 'bar', 1]
+ result = (
+ builder.from_("users").where("id", "=", 1).update(email="foo", name="bar")
)
+ builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1])
self.assertEqual(1, result)
builder = self.get_mysql_builder()
marker = builder.get_grammar().get_marker()
- query = 'UPDATE `users` SET `email` = %s, `name` = %s WHERE `id` = %s' % (marker, marker, marker)
+ query = "UPDATE `users` SET `email` = %s, `name` = %s WHERE `id` = %s" % (
+ marker,
+ marker,
+ marker,
+ )
builder.get_connection().update.return_value = 1
- result = builder.from_('users').where('id', '=', 1).update(email='foo', name='bar')
- builder.get_connection().update.assert_called_with(
- query, ['foo', 'bar', 1]
+ result = (
+ builder.from_("users").where("id", "=", 1).update(email="foo", name="bar")
)
+ builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1])
self.assertEqual(1, result)
def test_update_with_dictionaries(self):
builder = self.get_builder()
query = 'UPDATE "users" SET "email" = ?, "name" = ? WHERE "id" = ?'
builder.get_connection().update.return_value = 1
- result = builder.from_('users').where('id', '=', 1).update({'email': 'foo', 'name': 'bar'})
- builder.get_connection().update.assert_called_with(
- query, ['foo', 'bar', 1]
+ result = (
+ builder.from_("users")
+ .where("id", "=", 1)
+ .update({"email": "foo", "name": "bar"})
)
+ builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1])
self.assertEqual(1, result)
builder = self.get_builder()
query = 'UPDATE "users" SET "email" = ?, "name" = ? WHERE "id" = ?'
builder.get_connection().update.return_value = 1
- result = builder.from_('users').where('id', '=', 1).update({'email': 'foo'}, name='bar')
- builder.get_connection().update.assert_called_with(
- query, ['foo', 'bar', 1]
+ result = (
+ builder.from_("users")
+ .where("id", "=", 1)
+ .update({"email": "foo"}, name="bar")
)
+ builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1])
self.assertEqual(1, result)
+ def test_where_date(self):
+ builder = self.get_sqlite_builder()
+ builder.where_date("date", "=", "10-20-2018")
+
+ self.assertEqual(
+ builder.to_sql(),
+ 'SELECT * FROM "" WHERE strftime(\'%Y-%m-%d\', "date") = ?',
+ )
+
def test_update_with_joins(self):
builder = self.get_builder()
- query = 'UPDATE "users" ' \
- 'INNER JOIN "orders" ON "users"."id" = "orders"."user_id" ' \
- 'SET "email" = ?, "name" = ? WHERE "id" = ?'
+ query = (
+ 'UPDATE "users" '
+ 'INNER JOIN "orders" ON "users"."id" = "orders"."user_id" '
+ 'SET "email" = ?, "name" = ? WHERE "id" = ?'
+ )
builder.get_connection().update.return_value = 1
- result = builder.from_('users')\
- .join('orders', 'users.id', '=', 'orders.user_id')\
- .where('id', '=', 1)\
- .update(email='foo', name='bar')
- builder.get_connection().update.assert_called_with(
- query, ['foo', 'bar', 1]
+ result = (
+ builder.from_("users")
+ .join("orders", "users.id", "=", "orders.user_id")
+ .where("id", "=", 1)
+ .update(email="foo", name="bar")
)
+ builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1])
self.assertEqual(1, result)
def test_update_on_postgres(self):
builder = self.get_postgres_builder()
marker = builder.get_grammar().get_marker()
- query = 'UPDATE "users" SET "email" = %s, "name" = %s WHERE "id" = %s' % (marker, marker, marker)
+ query = 'UPDATE "users" SET "email" = %s, "name" = %s WHERE "id" = %s' % (
+ marker,
+ marker,
+ marker,
+ )
builder.get_connection().update.return_value = 1
- result = builder.from_('users').where('id', '=', 1).update(email='foo', name='bar')
- builder.get_connection().update.assert_called_with(
- query, ['foo', 'bar', 1]
+ result = (
+ builder.from_("users").where("id", "=", 1).update(email="foo", name="bar")
)
+ builder.get_connection().update.assert_called_with(query, ["foo", "bar", 1])
self.assertEqual(1, result)
def test_update_with_joins_on_postgres(self):
builder = self.get_postgres_builder()
marker = builder.get_grammar().get_marker()
- query = 'UPDATE "users" ' \
- 'SET "email" = %s, "name" = %s ' \
- 'FROM "orders" WHERE "id" = %s AND "users"."id" = "orders"."user_id"'\
- % (marker, marker, marker)
+ query = (
+ 'UPDATE "users" '
+ 'SET "email" = %s, "name" = %s '
+ 'FROM "orders" WHERE "id" = %s AND "users"."id" = "orders"."user_id"'
+ % (marker, marker, marker)
+ )
builder.get_connection().update.return_value = 1
- result = builder.from_('users')\
- .join('orders', 'users.id', '=', 'orders.user_id')\
- .where('id', '=', 1)\
- .update(email='foo', name='bar')
+ result = (
+ builder.from_("users")
+ .join("orders", "users.id", "=", "orders.user_id")
+ .where("id", "=", 1)
+ .update(email="foo", name="bar")
+ )
builder.get_connection().update.assert_called_once_with(
- query, ['foo', 'bar', 1]
+ query, ["foo", "bar", 1]
)
self.assertEqual(1, result)
def test_update_respects_raw(self):
builder = self.get_builder()
- marker = '?'
- query = 'UPDATE "users" SET "email" = foo, "name" = %s WHERE "id" = %s' % (marker, marker)
+ marker = "?"
+ query = 'UPDATE "users" SET "email" = foo, "name" = %s WHERE "id" = %s' % (
+ marker,
+ marker,
+ )
builder.get_connection().update.return_value = 1
- result = builder.from_('users').where('id', '=', 1).update(email=QueryExpression('foo'), name='bar')
- builder.get_connection().update.assert_called_once_with(
- query, ['bar', 1]
+ result = (
+ builder.from_("users")
+ .where("id", "=", 1)
+ .update(email=QueryExpression("foo"), name="bar")
)
+ builder.get_connection().update.assert_called_once_with(query, ["bar", 1])
self.assertEqual(1, result)
def test_delete(self):
builder = self.get_builder()
query = 'DELETE FROM "users" WHERE "email" = ?'
builder.get_connection().delete.return_value = 1
- result = builder.from_('users').where('email', '=', 'foo').delete()
- builder.get_connection().delete.assert_called_once_with(
- query, ['foo']
- )
+ result = builder.from_("users").where("email", "=", "foo").delete()
+ builder.get_connection().delete.assert_called_once_with(query, ["foo"])
self.assertEqual(1, result)
builder = self.get_builder()
query = 'DELETE FROM "users" WHERE "id" = ?'
builder.get_connection().delete.return_value = 1
- result = builder.from_('users').delete(1)
- builder.get_connection().delete.assert_called_once_with(
- query, [1]
- )
+ result = builder.from_("users").delete(1)
+ builder.get_connection().delete.assert_called_once_with(query, [1])
self.assertEqual(1, result)
def test_delete_with_join(self):
builder = self.get_mysql_builder()
marker = builder.get_grammar().get_marker()
- query = 'DELETE `users` FROM `users` ' \
- 'INNER JOIN `contacts` ON `users`.`id` = `contacts`.`id` WHERE `email` = %s' % marker
+ query = (
+ "DELETE `users` FROM `users` "
+ "INNER JOIN `contacts` ON `users`.`id` = `contacts`.`id` WHERE `email` = %s"
+ % marker
+ )
builder.get_connection().delete.return_value = 1
- result = builder.from_('users')\
- .join('contacts', 'users.id', '=', 'contacts.id')\
- .where('email', '=', 'foo')\
+ result = (
+ builder.from_("users")
+ .join("contacts", "users.id", "=", "contacts.id")
+ .where("email", "=", "foo")
.delete()
- builder.get_connection().delete.assert_called_once_with(
- query, ['foo']
)
+ builder.get_connection().delete.assert_called_once_with(query, ["foo"])
self.assertEqual(1, result)
builder = self.get_mysql_builder()
marker = builder.get_grammar().get_marker()
- query = 'DELETE `users` FROM `users` ' \
- 'INNER JOIN `contacts` ON `users`.`id` = `contacts`.`id` WHERE `id` = %s' % marker
+ query = (
+ "DELETE `users` FROM `users` "
+ "INNER JOIN `contacts` ON `users`.`id` = `contacts`.`id` WHERE `id` = %s"
+ % marker
+ )
builder.get_connection().delete.return_value = 1
- result = builder.from_('users')\
- .join('contacts', 'users.id', '=', 'contacts.id')\
+ result = (
+ builder.from_("users")
+ .join("contacts", "users.id", "=", "contacts.id")
.delete(1)
- builder.get_connection().delete.assert_called_once_with(
- query, [1]
)
+ builder.get_connection().delete.assert_called_once_with(query, [1])
self.assertEqual(1, result)
def test_truncate(self):
builder = self.get_builder()
query = 'TRUNCATE "users"'
- builder.from_('users').truncate()
- builder.get_connection().statement.assert_called_once_with(
- query, []
- )
+ builder.from_("users").truncate()
+ builder.get_connection().statement.assert_called_once_with(query, [])
builder = self.get_sqlite_builder()
- builder.from_('users')
- self.assertEqual({
- 'DELETE FROM sqlite_sequence WHERE name = ?': ['users'],
- 'DELETE FROM "users"': []
- }, builder.get_grammar().compile_truncate(builder))
+ builder.from_("users")
+ self.assertEqual(
+ {
+ "DELETE FROM sqlite_sequence WHERE name = ?": ["users"],
+ 'DELETE FROM "users"': [],
+ },
+ builder.get_grammar().compile_truncate(builder),
+ )
def test_postgres_insert_get_id(self):
builder = self.get_postgres_builder()
marker = builder.get_grammar().get_marker()
query = 'INSERT INTO "users" ("email") VALUES (%s) RETURNING "id"' % marker
builder.get_processor().process_insert_get_id.return_value = 1
- result = builder.from_('users').insert_get_id({'email': 'foo'}, 'id')
+ result = builder.from_("users").insert_get_id({"email": "foo"}, "id")
builder.get_processor().process_insert_get_id.assert_called_once_with(
- builder, query, ['foo'], 'id'
+ builder, query, ["foo"], "id"
)
self.assertEqual(1, result)
def test_mysql_wrapping(self):
builder = self.get_mysql_builder()
- builder.select('*').from_('users')
- self.assertEqual(
- 'SELECT * FROM `users`',
- builder.to_sql()
- )
+ builder.select("*").from_("users")
+ self.assertEqual("SELECT * FROM `users`", builder.to_sql())
def test_merge_wheres_can_merge_wheres_and_bindings(self):
builder = self.get_builder()
- builder.wheres = ['foo']
- builder.merge_wheres(['wheres'], ['foo', 'bar'])
- self.assertEqual(['foo', 'wheres'], builder.wheres)
- self.assertEqual(['foo', 'bar'], builder.get_bindings())
+ builder.wheres = ["foo"]
+ builder.merge_wheres(["wheres"], ["foo", "bar"])
+ self.assertEqual(["foo", "wheres"], builder.wheres)
+ self.assertEqual(["foo", "bar"], builder.get_bindings())
def test_where_with_null_second_parameter(self):
builder = self.get_builder()
- builder.select('*').from_('users').where('foo', None)
- self.assertEqual(
- 'SELECT * FROM "users" WHERE "foo" IS NULL',
- builder.to_sql()
- )
+ builder.select("*").from_("users").where("foo", None)
+ self.assertEqual('SELECT * FROM "users" WHERE "foo" IS NULL', builder.to_sql())
def test_dynamic_where(self):
- method = 'where_foo_bar_and_baz_or_boom'
- parameters = ['john', 'jane', 'bam']
+ method = "where_foo_bar_and_baz_or_boom"
+ parameters = ["john", "jane", "bam"]
builder = self.get_builder()
builder.where = mock.MagicMock(return_value=builder)
getattr(builder, method)(*parameters)
- builder.where.assert_has_calls([
- mock.call('foo_bar', '=', parameters[0], 'and'),
- mock.call('baz', '=', parameters[1], 'and'),
- mock.call('boom', '=', parameters[2], 'or')
- ])
+ builder.where.assert_has_calls(
+ [
+ mock.call("foo_bar", "=", parameters[0], "and"),
+ mock.call("baz", "=", parameters[1], "and"),
+ mock.call("boom", "=", parameters[2], "or"),
+ ]
+ )
def test_dynamic_where_is_not_greedy(self):
- method = 'where_ios_version_and_android_version_or_orientation'
- parameters = ['6.1', '4.2', 'Vertical']
+ method = "where_ios_version_and_android_version_or_orientation"
+ parameters = ["6.1", "4.2", "Vertical"]
builder = self.get_builder()
builder.where = mock.MagicMock(return_value=builder)
getattr(builder, method)(*parameters)
- builder.where.assert_has_calls([
- mock.call('ios_version', '=', parameters[0], 'and'),
- mock.call('android_version', '=', parameters[1], 'and'),
- mock.call('orientation', '=', parameters[2], 'or')
- ])
+ builder.where.assert_has_calls(
+ [
+ mock.call("ios_version", "=", parameters[0], "and"),
+ mock.call("android_version", "=", parameters[1], "and"),
+ mock.call("orientation", "=", parameters[2], "or"),
+ ]
+ )
def test_call_triggers_dynamic_where(self):
builder = self.get_builder()
- self.assertEqual(builder, builder.where_foo_and_bar('baz', 'boom'))
+ self.assertEqual(builder, builder.where_foo_and_bar("baz", "boom"))
self.assertEqual(2, len(builder.wheres))
def test_builder_raises_exception_with_undefined_method(self):
@@ -1456,171 +1517,164 @@ def test_builder_raises_exception_with_undefined_method(self):
try:
builder.do_not_exist()
- self.fail('Builder did not raise and AttributeError exception')
+            self.fail("Builder did not raise an AttributeError exception")
except AttributeError:
self.assertTrue(True)
def test_mysql_lock(self):
builder = self.get_mysql_builder()
marker = builder.get_grammar().get_marker()
- builder.select('*').from_('foo').where('bar', '=', 'baz').lock()
+ builder.select("*").from_("foo").where("bar", "=", "baz").lock()
self.assertEqual(
- 'SELECT * FROM `foo` WHERE `bar` = %s FOR UPDATE' % marker,
- builder.to_sql()
+ "SELECT * FROM `foo` WHERE `bar` = %s FOR UPDATE" % marker, builder.to_sql()
)
- self.assertEqual(['baz'], builder.get_bindings())
+ self.assertEqual(["baz"], builder.get_bindings())
builder = self.get_mysql_builder()
marker = builder.get_grammar().get_marker()
- builder.select('*').from_('foo').where('bar', '=', 'baz').lock(False)
+ builder.select("*").from_("foo").where("bar", "=", "baz").lock(False)
self.assertEqual(
- 'SELECT * FROM `foo` WHERE `bar` = %s LOCK IN SHARE MODE' % marker,
- builder.to_sql()
+ "SELECT * FROM `foo` WHERE `bar` = %s LOCK IN SHARE MODE" % marker,
+ builder.to_sql(),
)
- self.assertEqual(['baz'], builder.get_bindings())
+ self.assertEqual(["baz"], builder.get_bindings())
def test_postgres_lock(self):
builder = self.get_postgres_builder()
marker = builder.get_grammar().get_marker()
- builder.select('*').from_('foo').where('bar', '=', 'baz').lock()
+ builder.select("*").from_("foo").where("bar", "=", "baz").lock()
self.assertEqual(
- 'SELECT * FROM "foo" WHERE "bar" = %s FOR UPDATE' % marker,
- builder.to_sql()
+ 'SELECT * FROM "foo" WHERE "bar" = %s FOR UPDATE' % marker, builder.to_sql()
)
- self.assertEqual(['baz'], builder.get_bindings())
+ self.assertEqual(["baz"], builder.get_bindings())
builder = self.get_postgres_builder()
marker = builder.get_grammar().get_marker()
- builder.select('*').from_('foo').where('bar', '=', 'baz').lock(False)
+ builder.select("*").from_("foo").where("bar", "=", "baz").lock(False)
self.assertEqual(
- 'SELECT * FROM "foo" WHERE "bar" = %s FOR SHARE' % marker,
- builder.to_sql()
+ 'SELECT * FROM "foo" WHERE "bar" = %s FOR SHARE' % marker, builder.to_sql()
)
- self.assertEqual(['baz'], builder.get_bindings())
+ self.assertEqual(["baz"], builder.get_bindings())
def test_binding_order(self):
- expected_sql = 'SELECT * FROM "users" ' \
- 'INNER JOIN "othertable" ON "bar" = ? ' \
- 'WHERE "registered" = ? ' \
- 'GROUP BY "city" ' \
- 'HAVING "population" > ? ' \
- 'ORDER BY match ("foo") against(?)'
- expected_bindings = ['foo', True, 3, 'bar']
-
- builder = self.get_builder()
- builder.select('*').from_('users')\
- .order_by_raw('match ("foo") against(?)', ['bar'])\
- .where('registered', True)\
- .group_by('city')\
- .having('population', '>', 3)\
- .join(JoinClause('othertable').where('bar', '=', 'foo'))
+ expected_sql = (
+ 'SELECT * FROM "users" '
+ 'INNER JOIN "othertable" ON "bar" = ? '
+ 'WHERE "registered" = ? '
+ 'GROUP BY "city" '
+ 'HAVING "population" > ? '
+ 'ORDER BY match ("foo") against(?)'
+ )
+ expected_bindings = ["foo", True, 3, "bar"]
+
+ builder = self.get_builder()
+ builder.select("*").from_("users").order_by_raw(
+ 'match ("foo") against(?)', ["bar"]
+ ).where("registered", True).group_by("city").having("population", ">", 3).join(
+ JoinClause("othertable").where("bar", "=", "foo")
+ )
self.assertEqual(expected_sql, builder.to_sql())
self.assertEqual(expected_bindings, builder.get_bindings())
def test_add_binding_with_list_merges_bindings(self):
builder = self.get_builder()
- builder.add_binding(['foo', 'bar'])
- builder.add_binding(['baz'])
- self.assertEqual(['foo', 'bar', 'baz'], builder.get_bindings())
+ builder.add_binding(["foo", "bar"])
+ builder.add_binding(["baz"])
+ self.assertEqual(["foo", "bar", "baz"], builder.get_bindings())
def test_add_binding_with_list_merges_bindings_in_correct_order(self):
builder = self.get_builder()
- builder.add_binding(['bar', 'baz'], 'having')
- builder.add_binding(['foo'], 'where')
- self.assertEqual(['foo', 'bar', 'baz'], builder.get_bindings())
+ builder.add_binding(["bar", "baz"], "having")
+ builder.add_binding(["foo"], "where")
+ self.assertEqual(["foo", "bar", "baz"], builder.get_bindings())
def test_merge_builders(self):
builder = self.get_builder()
- builder.add_binding('foo', 'where')
- builder.add_binding('baz', 'having')
+ builder.add_binding("foo", "where")
+ builder.add_binding("baz", "having")
other_builder = self.get_builder()
- other_builder.add_binding('bar', 'where')
+ other_builder.add_binding("bar", "where")
builder.merge_bindings(other_builder)
- self.assertEqual(['foo', 'bar', 'baz'], builder.get_bindings())
+ self.assertEqual(["foo", "bar", "baz"], builder.get_bindings())
def test_sub_select(self):
builder = self.get_builder()
marker = builder.get_grammar().get_marker()
- expected_sql = 'SELECT "foo", "bar", (SELECT "baz" FROM "two" WHERE "subkey" = %s) AS "sub" ' \
- 'FROM "one" WHERE "key" = %s' % (marker, marker)
- expected_bindings = ['subval', 'val']
+ expected_sql = (
+ 'SELECT "foo", "bar", (SELECT "baz" FROM "two" WHERE "subkey" = %s) AS "sub" '
+ 'FROM "one" WHERE "key" = %s' % (marker, marker)
+ )
+ expected_bindings = ["subval", "val"]
- builder.from_('one').select('foo', 'bar').where('key', '=', 'val')
- builder.select_sub(builder.new_query().from_('two').select('baz').where('subkey', '=', 'subval'), 'sub')
+ builder.from_("one").select("foo", "bar").where("key", "=", "val")
+ builder.select_sub(
+ builder.new_query()
+ .from_("two")
+ .select("baz")
+ .where("subkey", "=", "subval"),
+ "sub",
+ )
self.assertEqual(expected_sql, builder.to_sql())
self.assertEqual(expected_bindings, builder.get_bindings())
def test_chunk(self):
builder = self.get_builder()
- results = [
- {'foo': 'bar'},
- {'foo': 'baz'},
- {'foo': 'bam'},
- {'foo': 'boom'}
- ]
+ results = [{"foo": "bar"}, {"foo": "baz"}, {"foo": "bam"}, {"foo": "boom"}]
def select(query, bindings, _):
- index = int(re.search('OFFSET (\d+)', query).group(1))
- limit = int(re.search('LIMIT (\d+)', query).group(1))
+            index = int(re.search(r"OFFSET (\d+)", query).group(1))
+            limit = int(re.search(r"LIMIT (\d+)", query).group(1))
if index >= len(results):
return []
- return results[index:index + limit]
+ return results[index : index + limit]
builder.get_connection().select.side_effect = select
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_)
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results_
+ )
i = 0
- for users in builder.from_('users').chunk(1):
+ for users in builder.from_("users").chunk(1):
self.assertEqual(users[0], results[i])
i += 1
builder = self.get_builder()
- results = [
- {'foo': 'bar'},
- {'foo': 'baz'},
- {'foo': 'bam'},
- {'foo': 'boom'}
- ]
+ results = [{"foo": "bar"}, {"foo": "baz"}, {"foo": "bam"}, {"foo": "boom"}]
builder.get_connection().select.side_effect = select
- builder.get_processor().process_select = mock.MagicMock(side_effect=lambda builder_, results_: results_)
+ builder.get_processor().process_select = mock.MagicMock(
+ side_effect=lambda builder_, results_: results_
+ )
- for users in builder.from_('users').chunk(2):
+ for users in builder.from_("users").chunk(2):
self.assertEqual(2, len(users))
    def test_not_specifying_columns_selects_all(self):
builder = self.get_builder()
- builder.from_('users')
+ builder.from_("users")
- self.assertEqual(
- 'SELECT * FROM "users"',
- builder.to_sql()
- )
+ self.assertEqual('SELECT * FROM "users"', builder.to_sql())
def test_merge(self):
b1 = self.get_builder()
- b1.from_('test').select('foo', 'bar').where('baz', 'boom')
+ b1.from_("test").select("foo", "bar").where("baz", "boom")
b2 = self.get_builder()
- b2.where('foo', 'bar')
+ b2.where("foo", "bar")
b1.merge(b2)
self.assertEqual(
- 'SELECT "foo", "bar" FROM "test" WHERE "baz" = ? AND "foo" = ?',
- b1.to_sql()
+ 'SELECT "foo", "bar" FROM "test" WHERE "baz" = ? AND "foo" = ?', b1.to_sql()
)
- self.assertEqual(
- ['boom', 'bar'],
- b1.get_bindings()
- )
+ self.assertEqual(["boom", "bar"], b1.get_bindings())
def get_mysql_builder(self):
grammar = MySQLQueryGrammar()
diff --git a/tests/schema/grammars/__init__.py b/tests/schema/grammars/__init__.py
index 633f8661..40a96afc 100644
--- a/tests/schema/grammars/__init__.py
+++ b/tests/schema/grammars/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/tests/schema/grammars/test_mysql_grammar.py b/tests/schema/grammars/test_mysql_grammar.py
index 80d50c19..b233d0e3 100644
--- a/tests/schema/grammars/test_mysql_grammar.py
+++ b/tests/schema/grammars/test_mysql_grammar.py
@@ -9,56 +9,57 @@
class MySQLSchemaGrammarTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_basic_create(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.create()
- blueprint.increments('id')
- blueprint.string('email')
+ blueprint.increments("id")
+ blueprint.string("email")
conn = self.get_connection()
- conn.should_receive('get_config').once().with_args('charset').and_return('utf8')
- conn.should_receive('get_config').once().with_args('collation').and_return('utf8_unicode_ci')
+ conn.should_receive("get_config").once().with_args("charset").and_return("utf8")
+ conn.should_receive("get_config").once().with_args("collation").and_return(
+ "utf8_unicode_ci"
+ )
statements = blueprint.to_sql(conn, self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'CREATE TABLE `users` ('
- '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, '
- '`email` VARCHAR(255) NOT NULL) '
- 'DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci',
- statements[0]
+ "CREATE TABLE `users` ("
+ "`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, "
+ "`email` VARCHAR(255) NOT NULL) "
+ "DEFAULT CHARACTER SET utf8 COLLATE utf8_unicode_ci",
+ statements[0],
)
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.create()
- blueprint.increments('id')
- blueprint.string('email')
+ blueprint.increments("id")
+ blueprint.string("email")
conn = self.get_connection()
- conn.should_receive('get_config').and_return(None)
+ conn.should_receive("get_config").and_return(None)
statements = blueprint.to_sql(conn, self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'CREATE TABLE `users` ('
- '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, '
- '`email` VARCHAR(255) NOT NULL)',
- statements[0]
+ "CREATE TABLE `users` ("
+ "`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, "
+ "`email` VARCHAR(255) NOT NULL)",
+ statements[0],
)
def test_charset_collation_create(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.create()
- blueprint.increments('id')
- blueprint.string('email')
- blueprint.charset = 'utf8mb4'
- blueprint.collation = 'utf8mb4_unicode_ci'
+ blueprint.increments("id")
+ blueprint.string("email")
+ blueprint.charset = "utf8mb4"
+ blueprint.collation = "utf8mb4_unicode_ci"
conn = self.get_connection()
@@ -66,540 +67,557 @@ def test_charset_collation_create(self):
self.assertEqual(1, len(statements))
self.assertEqual(
- 'CREATE TABLE `users` ('
- '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, '
- '`email` VARCHAR(255) NOT NULL) '
- 'DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci',
- statements[0]
+ "CREATE TABLE `users` ("
+ "`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, "
+ "`email` VARCHAR(255) NOT NULL) "
+ "DEFAULT CHARACTER SET utf8mb4 COLLATE utf8mb4_unicode_ci",
+ statements[0],
)
def test_basic_create_with_prefix(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.create()
- blueprint.increments('id')
- blueprint.string('email')
+ blueprint.increments("id")
+ blueprint.string("email")
grammar = self.get_grammar()
- grammar.set_table_prefix('prefix_')
+ grammar.set_table_prefix("prefix_")
conn = self.get_connection()
- conn.should_receive('get_config').and_return(None)
+ conn.should_receive("get_config").and_return(None)
statements = blueprint.to_sql(conn, grammar)
self.assertEqual(1, len(statements))
self.assertEqual(
- 'CREATE TABLE `prefix_users` ('
- '`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, '
- '`email` VARCHAR(255) NOT NULL)',
- statements[0]
+ "CREATE TABLE `prefix_users` ("
+ "`id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY, "
+ "`email` VARCHAR(255) NOT NULL)",
+ statements[0],
)
def test_drop_table(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.drop()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('DROP TABLE `users`', statements[0])
+ self.assertEqual("DROP TABLE `users`", statements[0])
def test_drop_table_if_exists(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.drop_if_exists()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('DROP TABLE IF EXISTS `users`', statements[0])
+ self.assertEqual("DROP TABLE IF EXISTS `users`", statements[0])
def test_drop_column(self):
- blueprint = Blueprint('users')
- blueprint.drop_column('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_column("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('ALTER TABLE `users` DROP `foo`', statements[0])
+ self.assertEqual("ALTER TABLE `users` DROP `foo`", statements[0])
- blueprint = Blueprint('users')
- blueprint.drop_column('foo', 'bar')
+ blueprint = Blueprint("users")
+ blueprint.drop_column("foo", "bar")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('ALTER TABLE `users` DROP `foo`, DROP `bar`', statements[0])
+ self.assertEqual("ALTER TABLE `users` DROP `foo`, DROP `bar`", statements[0])
def test_drop_primary(self):
- blueprint = Blueprint('users')
- blueprint.drop_primary('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_primary("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('ALTER TABLE `users` DROP PRIMARY KEY', statements[0])
+ self.assertEqual("ALTER TABLE `users` DROP PRIMARY KEY", statements[0])
def test_drop_unique(self):
- blueprint = Blueprint('users')
- blueprint.drop_unique('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_unique("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('ALTER TABLE `users` DROP INDEX foo', statements[0])
+ self.assertEqual("ALTER TABLE `users` DROP INDEX foo", statements[0])
def test_drop_index(self):
- blueprint = Blueprint('users')
- blueprint.drop_index('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_index("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('ALTER TABLE `users` DROP INDEX foo', statements[0])
+ self.assertEqual("ALTER TABLE `users` DROP INDEX foo", statements[0])
def test_drop_foreign(self):
- blueprint = Blueprint('users')
- blueprint.drop_foreign('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_foreign("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('ALTER TABLE `users` DROP FOREIGN KEY foo', statements[0])
+ self.assertEqual("ALTER TABLE `users` DROP FOREIGN KEY foo", statements[0])
def test_drop_timestamps(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.drop_timestamps()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('ALTER TABLE `users` DROP `created_at`, DROP `updated_at`', statements[0])
+ self.assertEqual(
+ "ALTER TABLE `users` DROP `created_at`, DROP `updated_at`", statements[0]
+ )
def test_rename_table(self):
- blueprint = Blueprint('users')
- blueprint.rename('foo')
+ blueprint = Blueprint("users")
+ blueprint.rename("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('RENAME TABLE `users` TO `foo`', statements[0])
+ self.assertEqual("RENAME TABLE `users` TO `foo`", statements[0])
def test_adding_primary_key(self):
- blueprint = Blueprint('users')
- blueprint.primary('foo', 'bar')
+ blueprint = Blueprint("users")
+ blueprint.primary("foo", "bar")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('ALTER TABLE `users` ADD PRIMARY KEY bar(`foo`)', statements[0])
+ self.assertEqual(
+ "ALTER TABLE `users` ADD PRIMARY KEY bar(`foo`)", statements[0]
+ )
def test_adding_foreign_key(self):
- blueprint = Blueprint('users')
- blueprint.foreign('order_id').references('id').on('orders')
+ blueprint = Blueprint("users")
+ blueprint.foreign("order_id").references("id").on("orders")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
expected = [
- 'ALTER TABLE `users` ADD CONSTRAINT users_order_id_foreign '
- 'FOREIGN KEY (`order_id`) REFERENCES `orders` (`id`)'
+ "ALTER TABLE `users` ADD CONSTRAINT users_order_id_foreign "
+ "FOREIGN KEY (`order_id`) REFERENCES `orders` (`id`)"
]
self.assertEqual(expected, statements)
def test_adding_unique_key(self):
- blueprint = Blueprint('users')
- blueprint.unique('foo', 'bar')
+ blueprint = Blueprint("users")
+ blueprint.unique("foo", "bar")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'ALTER TABLE `users` ADD UNIQUE bar(`foo`)',
- statements[0]
- )
+ self.assertEqual("ALTER TABLE `users` ADD UNIQUE bar(`foo`)", statements[0])
def test_adding_index(self):
- blueprint = Blueprint('users')
- blueprint.index(['foo', 'bar'], 'baz')
+ blueprint = Blueprint("users")
+ blueprint.index(["foo", "bar"], "baz")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD INDEX baz(`foo`, `bar`)',
- statements[0]
+ "ALTER TABLE `users` ADD INDEX baz(`foo`, `bar`)", statements[0]
)
def test_adding_incrementing_id(self):
- blueprint = Blueprint('users')
- blueprint.increments('id')
+ blueprint = Blueprint("users")
+ blueprint.increments("id")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY',
- statements[0]
+ "ALTER TABLE `users` ADD `id` INT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY",
+ statements[0],
)
def test_adding_big_incrementing_id(self):
- blueprint = Blueprint('users')
- blueprint.big_increments('id')
+ blueprint = Blueprint("users")
+ blueprint.big_increments("id")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY',
- statements[0]
+ "ALTER TABLE `users` ADD `id` BIGINT UNSIGNED NOT NULL AUTO_INCREMENT PRIMARY KEY",
+ statements[0],
)
-
+
def test_adding_column_after_another(self):
- blueprint = Blueprint('users')
- blueprint.string('name').after('foo')
-
+ blueprint = Blueprint("users")
+ blueprint.string("name").after("foo")
+
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL AFTER `foo`',
- statements[0]
+ "ALTER TABLE `users` ADD `name` VARCHAR(255) NOT NULL AFTER `foo`",
+ statements[0],
)
def test_adding_string(self):
- blueprint = Blueprint('users')
- blueprint.string('foo')
+ blueprint = Blueprint("users")
+ blueprint.string("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` VARCHAR(255) NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` VARCHAR(255) NOT NULL", statements[0]
)
- blueprint = Blueprint('users')
- blueprint.string('foo', 100)
+ blueprint = Blueprint("users")
+ blueprint.string("foo", 100)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` VARCHAR(100) NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` VARCHAR(100) NOT NULL", statements[0]
)
- blueprint = Blueprint('users')
- blueprint.string('foo', 100).nullable().default('bar')
+ blueprint = Blueprint("users")
+ blueprint.string("foo", 100).nullable().default("bar")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` VARCHAR(100) NULL DEFAULT \'bar\'',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` VARCHAR(100) NULL DEFAULT 'bar'",
+ statements[0],
)
def test_adding_text(self):
- blueprint = Blueprint('users')
- blueprint.text('foo')
+ blueprint = Blueprint("users")
+ blueprint.text("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` TEXT NOT NULL',
- statements[0]
- )
+ self.assertEqual("ALTER TABLE `users` ADD `foo` TEXT NOT NULL", statements[0])
def test_adding_big_integer(self):
- blueprint = Blueprint('users')
- blueprint.big_integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.big_integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` BIGINT NOT NULL',
- statements[0]
- )
+ self.assertEqual("ALTER TABLE `users` ADD `foo` BIGINT NOT NULL", statements[0])
- blueprint = Blueprint('users')
- blueprint.big_integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.big_integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` BIGINT NOT NULL AUTO_INCREMENT PRIMARY KEY",
+ statements[0],
)
def test_adding_integer(self):
- blueprint = Blueprint('users')
- blueprint.integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` INT NOT NULL',
- statements[0]
- )
+ self.assertEqual("ALTER TABLE `users` ADD `foo` INT NOT NULL", statements[0])
- blueprint = Blueprint('users')
- blueprint.integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` INT NOT NULL AUTO_INCREMENT PRIMARY KEY',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` INT NOT NULL AUTO_INCREMENT PRIMARY KEY",
+ statements[0],
)
def test_adding_medium_integer(self):
- blueprint = Blueprint('users')
- blueprint.medium_integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.medium_integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` MEDIUMINT NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` MEDIUMINT NOT NULL", statements[0]
)
- blueprint = Blueprint('users')
- blueprint.medium_integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.medium_integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` MEDIUMINT NOT NULL AUTO_INCREMENT PRIMARY KEY',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` MEDIUMINT NOT NULL AUTO_INCREMENT PRIMARY KEY",
+ statements[0],
)
def test_adding_tiny_integer(self):
- blueprint = Blueprint('users')
- blueprint.tiny_integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.tiny_integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` TINYINT NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` TINYINT NOT NULL", statements[0]
)
- blueprint = Blueprint('users')
- blueprint.tiny_integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.tiny_integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` TINYINT NOT NULL AUTO_INCREMENT PRIMARY KEY',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` TINYINT NOT NULL AUTO_INCREMENT PRIMARY KEY",
+ statements[0],
)
def test_adding_small_integer(self):
- blueprint = Blueprint('users')
- blueprint.small_integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.small_integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` SMALLINT NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` SMALLINT NOT NULL", statements[0]
)
- blueprint = Blueprint('users')
- blueprint.small_integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.small_integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` SMALLINT NOT NULL AUTO_INCREMENT PRIMARY KEY',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` SMALLINT NOT NULL AUTO_INCREMENT PRIMARY KEY",
+ statements[0],
)
def test_adding_float(self):
- blueprint = Blueprint('users')
- blueprint.float('foo', 5, 2)
+ blueprint = Blueprint("users")
+ blueprint.float("foo", 5, 2)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` DOUBLE(5, 2) NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` DOUBLE(5, 2) NOT NULL", statements[0]
)
def test_adding_double(self):
- blueprint = Blueprint('users')
- blueprint.double('foo')
+ blueprint = Blueprint("users")
+ blueprint.double("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` DOUBLE NOT NULL',
- statements[0]
- )
+ self.assertEqual("ALTER TABLE `users` ADD `foo` DOUBLE NOT NULL", statements[0])
def test_adding_double_with_precision(self):
- blueprint = Blueprint('users')
- blueprint.double('foo', 15, 8)
+ blueprint = Blueprint("users")
+ blueprint.double("foo", 15, 8)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` DOUBLE(15, 8) NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` DOUBLE(15, 8) NOT NULL", statements[0]
)
def test_adding_decimal(self):
- blueprint = Blueprint('users')
- blueprint.decimal('foo', 5, 2)
+ blueprint = Blueprint("users")
+ blueprint.decimal("foo", 5, 2)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` DECIMAL(5, 2) NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` DECIMAL(5, 2) NOT NULL", statements[0]
)
def test_adding_boolean(self):
- blueprint = Blueprint('users')
- blueprint.boolean('foo')
+ blueprint = Blueprint("users")
+ blueprint.boolean("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` TINYINT(1) NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` TINYINT(1) NOT NULL", statements[0]
)
def test_adding_enum(self):
- blueprint = Blueprint('users')
- blueprint.enum('foo', ['bar', 'baz'])
+ blueprint = Blueprint("users")
+ blueprint.enum("foo", ["bar", "baz"])
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` ENUM(\'bar\', \'baz\') NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` ENUM('bar', 'baz') NOT NULL", statements[0]
)
def test_adding_date(self):
- blueprint = Blueprint('users')
- blueprint.date('foo')
+ blueprint = Blueprint("users")
+ blueprint.date("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` DATE NOT NULL',
- statements[0]
- )
+ self.assertEqual("ALTER TABLE `users` ADD `foo` DATE NOT NULL", statements[0])
def test_adding_datetime(self):
- blueprint = Blueprint('users')
- blueprint.datetime('foo')
+ blueprint = Blueprint("users")
+ blueprint.datetime("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` DATETIME NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` DATETIME NOT NULL", statements[0]
)
def test_adding_time(self):
- blueprint = Blueprint('users')
- blueprint.time('foo')
+ blueprint = Blueprint("users")
+ blueprint.time("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
+ self.assertEqual(1, len(statements))
+ self.assertEqual("ALTER TABLE `users` ADD `foo` TIME NOT NULL", statements[0])
+
+ def test_adding_timestamp_mysql_lt_564(self):
+ blueprint = Blueprint("users")
+ blueprint.timestamp("foo")
+ statements = blueprint.to_sql(
+ self.get_connection(), self.get_grammar((5, 6, 0, ""))
+ )
+
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` TIME NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` TIMESTAMP NOT NULL", statements[0]
)
- def test_adding_timestamp(self):
- blueprint = Blueprint('users')
- blueprint.timestamp('foo')
- statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
+ def test_adding_timestamp_mysql_gte_564(self):
+ blueprint = Blueprint("users")
+ blueprint.timestamp("foo")
+ statements = blueprint.to_sql(
+ self.get_connection(), self.get_grammar((5, 6, 4, ""))
+ )
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` TIMESTAMP NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` TIMESTAMP(6) NOT NULL", statements[0]
)
- def test_adding_timestamp_with_current(self):
- blueprint = Blueprint('users')
- blueprint.timestamp('foo').use_current()
- statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
+ def test_adding_timestamp_with_current_mysql_lt_564(self):
+ blueprint = Blueprint("users")
+ blueprint.timestamp("foo").use_current()
+ statements = blueprint.to_sql(
+ self.get_connection(), self.get_grammar((5, 6, 0, ""))
+ )
+
+ self.assertEqual(1, len(statements))
+ self.assertEqual(
+ "ALTER TABLE `users` ADD `foo` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL",
+ statements[0],
+ )
+
+ def test_adding_timestamp_with_current_mysql_gte_564(self):
+ blueprint = Blueprint("users")
+ blueprint.timestamp("foo").use_current()
+ statements = blueprint.to_sql(
+ self.get_connection(), self.get_grammar((5, 6, 4, ""))
+ )
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL',
- statements[0]
+ "ALTER TABLE `users` ADD `foo` TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP(6) NOT NULL",
+ statements[0],
)
- def test_adding_timestamps(self):
- blueprint = Blueprint('users')
+ def test_adding_timestamps_mysql_lt_564(self):
+ blueprint = Blueprint("users")
blueprint.timestamps()
- statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
+ statements = blueprint.to_sql(
+ self.get_connection(), self.get_grammar((5, 6, 0, ""))
+ )
self.assertEqual(1, len(statements))
expected = [
- 'ALTER TABLE `users` ADD `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, '
- 'ADD `updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL'
+ "ALTER TABLE `users` ADD `created_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL, "
+ "ADD `updated_at` TIMESTAMP DEFAULT CURRENT_TIMESTAMP NOT NULL"
]
- self.assertEqual(
- expected[0],
- statements[0]
+ self.assertEqual(expected[0], statements[0])
+
+ def test_adding_timestamps_mysql_gte_564(self):
+ blueprint = Blueprint("users")
+ blueprint.timestamps()
+ statements = blueprint.to_sql(
+ self.get_connection(), self.get_grammar((5, 6, 4, ""))
)
- def test_adding_timestamps_not_current(self):
- blueprint = Blueprint('users')
+ self.assertEqual(1, len(statements))
+ expected = [
+ "ALTER TABLE `users` ADD `created_at` TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP(6) NOT NULL, "
+ "ADD `updated_at` TIMESTAMP(6) DEFAULT CURRENT_TIMESTAMP(6) NOT NULL"
+ ]
+ self.assertEqual(expected[0], statements[0])
+
+ def test_adding_timestamps_not_current_mysql_lt_564(self):
+ blueprint = Blueprint("users")
blueprint.timestamps(use_current=False)
- statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
+ statements = blueprint.to_sql(
+ self.get_connection(), self.get_grammar((5, 6, 0, ""))
+ )
self.assertEqual(1, len(statements))
expected = [
- 'ALTER TABLE `users` ADD `created_at` TIMESTAMP NOT NULL, '
- 'ADD `updated_at` TIMESTAMP NOT NULL'
+ "ALTER TABLE `users` ADD `created_at` TIMESTAMP NOT NULL, "
+ "ADD `updated_at` TIMESTAMP NOT NULL"
]
- self.assertEqual(
- expected[0],
- statements[0]
+ self.assertEqual(expected[0], statements[0])
+
+ def test_adding_timestamps_not_current_mysql_gte_564(self):
+ blueprint = Blueprint("users")
+ blueprint.timestamps(use_current=False)
+ statements = blueprint.to_sql(
+ self.get_connection(), self.get_grammar((5, 6, 4, ""))
)
+ self.assertEqual(1, len(statements))
+ expected = [
+ "ALTER TABLE `users` ADD `created_at` TIMESTAMP(6) NOT NULL, "
+ "ADD `updated_at` TIMESTAMP(6) NOT NULL"
+ ]
+ self.assertEqual(expected[0], statements[0])
+
def test_adding_binary(self):
- blueprint = Blueprint('users')
- blueprint.binary('foo')
+ blueprint = Blueprint("users")
+ blueprint.binary("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` BLOB NOT NULL',
- statements[0]
- )
+ self.assertEqual("ALTER TABLE `users` ADD `foo` BLOB NOT NULL", statements[0])
def test_adding_json(self):
- blueprint = Blueprint('users')
- blueprint.json('foo')
+ blueprint = Blueprint("users")
+ blueprint.json("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` JSON NOT NULL',
- statements[0]
- )
+ self.assertEqual("ALTER TABLE `users` ADD `foo` JSON NOT NULL", statements[0])
def test_adding_json_mysql_56(self):
- blueprint = Blueprint('users')
- blueprint.json('foo')
+ blueprint = Blueprint("users")
+ blueprint.json("foo")
- statements = blueprint.to_sql(self.get_connection(), self.get_grammar((5, 6, 0, '')))
+ statements = blueprint.to_sql(
+ self.get_connection(), self.get_grammar((5, 6, 0, ""))
+ )
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` TEXT NOT NULL',
- statements[0]
- )
+ self.assertEqual("ALTER TABLE `users` ADD `foo` TEXT NOT NULL", statements[0])
def test_adding_json_mariadb(self):
- blueprint = Blueprint('users')
- blueprint.json('foo')
+ blueprint = Blueprint("users")
+ blueprint.json("foo")
- statements = blueprint.to_sql(self.get_connection(), self.get_grammar((10, 6, 0, 'mariadb')))
+ statements = blueprint.to_sql(
+ self.get_connection(), self.get_grammar((10, 6, 0, "mariadb"))
+ )
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'ALTER TABLE `users` ADD `foo` TEXT NOT NULL',
- statements[0]
- )
+ self.assertEqual("ALTER TABLE `users` ADD `foo` TEXT NOT NULL", statements[0])
def get_connection(self, version=None):
if version is None:
- version = (5, 7, 11, '')
+ version = (5, 7, 11, "")
connector = flexmock(MySQLConnector())
- connector.should_receive('get_server_version').and_return(version)
+ connector.should_receive("get_server_version").and_return(version)
conn = flexmock(Connection(connector))
return conn
diff --git a/tests/schema/grammars/test_postgres_grammar.py b/tests/schema/grammars/test_postgres_grammar.py
index f7778d7a..e1bf70b1 100644
--- a/tests/schema/grammars/test_postgres_grammar.py
+++ b/tests/schema/grammars/test_postgres_grammar.py
@@ -8,26 +8,25 @@
class PostgresSchemaGrammarTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_basic_create(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.create()
- blueprint.increments('id')
- blueprint.string('email')
+ blueprint.increments("id")
+ blueprint.string("email")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'CREATE TABLE "users" ("id" SERIAL PRIMARY KEY NOT NULL, "email" VARCHAR(255) NOT NULL)',
- statements[0]
+ statements[0],
)
- blueprint = Blueprint('users')
- blueprint.increments('id')
- blueprint.string('email')
+ blueprint = Blueprint("users")
+ blueprint.increments("id")
+ blueprint.string("email")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
@@ -38,7 +37,7 @@ def test_basic_create(self):
self.assertEqual(expected[0], statements[0])
def test_drop_table(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.drop()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
@@ -46,7 +45,7 @@ def test_drop_table(self):
self.assertEqual('DROP TABLE "users"', statements[0])
def test_drop_table_if_exists(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.drop_if_exists()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
@@ -54,82 +53,89 @@ def test_drop_table_if_exists(self):
self.assertEqual('DROP TABLE IF EXISTS "users"', statements[0])
def test_drop_column(self):
- blueprint = Blueprint('users')
- blueprint.drop_column('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_column("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual('ALTER TABLE "users" DROP COLUMN "foo"', statements[0])
- blueprint = Blueprint('users')
- blueprint.drop_column('foo', 'bar')
+ blueprint = Blueprint("users")
+ blueprint.drop_column("foo", "bar")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('ALTER TABLE "users" DROP COLUMN "foo", DROP COLUMN "bar"', statements[0])
+ self.assertEqual(
+ 'ALTER TABLE "users" DROP COLUMN "foo", DROP COLUMN "bar"', statements[0]
+ )
def test_drop_primary(self):
- blueprint = Blueprint('users')
- blueprint.drop_primary('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_primary("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('ALTER TABLE "users" DROP CONSTRAINT users_pkey', statements[0])
+ self.assertEqual(
+ 'ALTER TABLE "users" DROP CONSTRAINT users_pkey', statements[0]
+ )
def test_drop_unique(self):
- blueprint = Blueprint('users')
- blueprint.drop_unique('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_unique("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual('ALTER TABLE "users" DROP CONSTRAINT foo', statements[0])
def test_drop_index(self):
- blueprint = Blueprint('users')
- blueprint.drop_index('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_index("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('DROP INDEX foo', statements[0])
+ self.assertEqual("DROP INDEX foo", statements[0])
def test_drop_foreign(self):
- blueprint = Blueprint('users')
- blueprint.drop_unique('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_unique("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual('ALTER TABLE "users" DROP CONSTRAINT foo', statements[0])
def test_drop_timestamps(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.drop_timestamps()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('ALTER TABLE "users" DROP COLUMN "created_at", DROP COLUMN "updated_at"', statements[0])
+ self.assertEqual(
+ 'ALTER TABLE "users" DROP COLUMN "created_at", DROP COLUMN "updated_at"',
+ statements[0],
+ )
def test_rename_table(self):
- blueprint = Blueprint('users')
- blueprint.rename('foo')
+ blueprint = Blueprint("users")
+ blueprint.rename("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual('ALTER TABLE "users" RENAME TO "foo"', statements[0])
def test_adding_primary_key(self):
- blueprint = Blueprint('users')
- blueprint.primary('foo')
+ blueprint = Blueprint("users")
+ blueprint.primary("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual('ALTER TABLE "users" ADD PRIMARY KEY ("foo")', statements[0])
def test_adding_foreign_key(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.create()
- blueprint.string('foo').primary()
- blueprint.string('order_id')
- blueprint.foreign('order_id').references('id').on('orders')
+ blueprint.string("foo").primary()
+ blueprint.string("order_id")
+ blueprint.foreign("order_id").references("id").on("orders")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(3, len(statements))
@@ -137,362 +143,341 @@ def test_adding_foreign_key(self):
'CREATE TABLE "users" ("foo" VARCHAR(255) NOT NULL, "order_id" VARCHAR(255) NOT NULL)',
'ALTER TABLE "users" ADD CONSTRAINT users_order_id_foreign'
' FOREIGN KEY ("order_id") REFERENCES "orders" ("id")',
- 'ALTER TABLE "users" ADD PRIMARY KEY ("foo")'
+ 'ALTER TABLE "users" ADD PRIMARY KEY ("foo")',
]
self.assertEqual(expected, statements)
def test_adding_unique_key(self):
- blueprint = Blueprint('users')
- blueprint.unique('foo', 'bar')
+ blueprint = Blueprint("users")
+ blueprint.unique("foo", "bar")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD CONSTRAINT bar UNIQUE ("foo")',
- statements[0]
+ 'ALTER TABLE "users" ADD CONSTRAINT bar UNIQUE ("foo")', statements[0]
)
def test_adding_index(self):
- blueprint = Blueprint('users')
- blueprint.index(['foo', 'bar'], 'baz')
+ blueprint = Blueprint("users")
+ blueprint.index(["foo", "bar"], "baz")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'CREATE INDEX baz ON "users" ("foo", "bar")',
- statements[0]
- )
+ self.assertEqual('CREATE INDEX baz ON "users" ("foo", "bar")', statements[0])
def test_adding_incrementing_id(self):
- blueprint = Blueprint('users')
- blueprint.increments('id')
+ blueprint = Blueprint("users")
+ blueprint.increments("id")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "id" SERIAL PRIMARY KEY NOT NULL',
- statements[0]
+ statements[0],
)
def test_adding_big_incrementing_id(self):
- blueprint = Blueprint('users')
- blueprint.big_increments('id')
+ blueprint = Blueprint("users")
+ blueprint.big_increments("id")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "id" BIGSERIAL PRIMARY KEY NOT NULL',
- statements[0]
+ statements[0],
)
def test_adding_string(self):
- blueprint = Blueprint('users')
- blueprint.string('foo')
+ blueprint = Blueprint("users")
+ blueprint.string("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(255) NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(255) NOT NULL', statements[0]
)
- blueprint = Blueprint('users')
- blueprint.string('foo', 100)
+ blueprint = Blueprint("users")
+ blueprint.string("foo", 100)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(100) NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(100) NOT NULL', statements[0]
)
- blueprint = Blueprint('users')
- blueprint.string('foo', 100).nullable().default('bar')
+ blueprint = Blueprint("users")
+ blueprint.string("foo", 100).nullable().default("bar")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(100) NULL DEFAULT \'bar\'',
- statements[0]
+ statements[0],
)
def test_adding_text(self):
- blueprint = Blueprint('users')
- blueprint.text('foo')
+ blueprint = Blueprint("users")
+ blueprint.text("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL', statements[0]
)
def test_adding_big_integer(self):
- blueprint = Blueprint('users')
- blueprint.big_integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.big_integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" BIGINT NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" BIGINT NOT NULL', statements[0]
)
- blueprint = Blueprint('users')
- blueprint.big_integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.big_integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" BIGSERIAL PRIMARY KEY NOT NULL',
- statements[0]
+ statements[0],
)
def test_adding_integer(self):
- blueprint = Blueprint('users')
- blueprint.integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0]
)
- blueprint = Blueprint('users')
- blueprint.integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" SERIAL PRIMARY KEY NOT NULL',
- statements[0]
+ statements[0],
)
def test_adding_medium_integer(self):
- blueprint = Blueprint('users')
- blueprint.medium_integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.medium_integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0]
)
- blueprint = Blueprint('users')
- blueprint.medium_integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.medium_integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" SERIAL PRIMARY KEY NOT NULL',
- statements[0]
+ statements[0],
)
def test_adding_tiny_integer(self):
- blueprint = Blueprint('users')
- blueprint.tiny_integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.tiny_integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" SMALLINT NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" SMALLINT NOT NULL', statements[0]
)
- blueprint = Blueprint('users')
- blueprint.tiny_integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.tiny_integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" SMALLSERIAL PRIMARY KEY NOT NULL',
- statements[0]
+ statements[0],
)
def test_adding_small_integer(self):
- blueprint = Blueprint('users')
- blueprint.small_integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.small_integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" SMALLINT NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" SMALLINT NOT NULL', statements[0]
)
- blueprint = Blueprint('users')
- blueprint.small_integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.small_integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" SMALLSERIAL PRIMARY KEY NOT NULL',
- statements[0]
+ statements[0],
)
def test_adding_float(self):
- blueprint = Blueprint('users')
- blueprint.float('foo', 5, 2)
+ blueprint = Blueprint("users")
+ blueprint.float("foo", 5, 2)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" DOUBLE PRECISION NOT NULL',
- statements[0]
+ statements[0],
)
def test_adding_double(self):
- blueprint = Blueprint('users')
- blueprint.double('foo', 15, 8)
+ blueprint = Blueprint("users")
+ blueprint.double("foo", 15, 8)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" DOUBLE PRECISION NOT NULL',
- statements[0]
+ statements[0],
)
def test_adding_decimal(self):
- blueprint = Blueprint('users')
- blueprint.decimal('foo', 5, 2)
+ blueprint = Blueprint("users")
+ blueprint.decimal("foo", 5, 2)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" DECIMAL(5, 2) NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" DECIMAL(5, 2) NOT NULL', statements[0]
)
def test_adding_boolean(self):
- blueprint = Blueprint('users')
- blueprint.boolean('foo')
+ blueprint = Blueprint("users")
+ blueprint.boolean("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" BOOLEAN NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" BOOLEAN NOT NULL', statements[0]
)
def test_adding_enum(self):
- blueprint = Blueprint('users')
- blueprint.enum('foo', ['bar', 'baz'])
+ blueprint = Blueprint("users")
+ blueprint.enum("foo", ["bar", "baz"])
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR(255) CHECK ("foo" IN (\'bar\', \'baz\')) NOT NULL',
- statements[0]
+ statements[0],
)
def test_adding_date(self):
- blueprint = Blueprint('users')
- blueprint.date('foo')
+ blueprint = Blueprint("users")
+ blueprint.date("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" DATE NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" DATE NOT NULL', statements[0]
)
def test_adding_datetime(self):
- blueprint = Blueprint('users')
- blueprint.datetime('foo')
+ blueprint = Blueprint("users")
+ blueprint.datetime("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(6) WITHOUT TIME ZONE NOT NULL',
+ statements[0],
)
def test_adding_time(self):
- blueprint = Blueprint('users')
- blueprint.time('foo')
+ blueprint = Blueprint("users")
+ blueprint.time("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" TIME(0) WITHOUT TIME ZONE NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" TIME(6) WITHOUT TIME ZONE NOT NULL',
+ statements[0],
)
def test_adding_timestamp(self):
- blueprint = Blueprint('users')
- blueprint.timestamp('foo')
+ blueprint = Blueprint("users")
+ blueprint.timestamp("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(6) WITHOUT TIME ZONE NOT NULL',
+ statements[0],
)
def test_adding_timestamp_with_current(self):
- blueprint = Blueprint('users')
- blueprint.timestamp('foo').use_current()
+ blueprint = Blueprint("users")
+ blueprint.timestamp("foo").use_current()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(0) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(0) NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" TIMESTAMP(6) WITHOUT TIME ZONE '
+ "DEFAULT CURRENT_TIMESTAMP(6) NOT NULL",
+ statements[0],
)
def test_adding_timestamps(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.timestamps()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
expected = [
- 'ALTER TABLE "users" ADD COLUMN "created_at" TIMESTAMP(0) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(0) NOT NULL, '
- 'ADD COLUMN "updated_at" TIMESTAMP(0) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(0) NOT NULL'
+ 'ALTER TABLE "users" ADD COLUMN "created_at" TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(6) NOT NULL, '
+ 'ADD COLUMN "updated_at" TIMESTAMP(6) WITHOUT TIME ZONE DEFAULT CURRENT_TIMESTAMP(6) NOT NULL'
]
- self.assertEqual(
- expected[0],
- statements[0]
- )
+
+ self.assertEqual(expected[0], statements[0])
def test_adding_timestamps_not_current(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.timestamps(use_current=False)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
expected = [
- 'ALTER TABLE "users" ADD COLUMN "created_at" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL, '
- 'ADD COLUMN "updated_at" TIMESTAMP(0) WITHOUT TIME ZONE NOT NULL'
+ 'ALTER TABLE "users" ADD COLUMN "created_at" TIMESTAMP(6) WITHOUT TIME ZONE NOT NULL, '
+ 'ADD COLUMN "updated_at" TIMESTAMP(6) WITHOUT TIME ZONE NOT NULL'
]
- self.assertEqual(
- expected[0],
- statements[0]
- )
+ self.assertEqual(expected[0], statements[0])
def test_adding_binary(self):
- blueprint = Blueprint('users')
- blueprint.binary('foo')
+ blueprint = Blueprint("users")
+ blueprint.binary("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" BYTEA NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" BYTEA NOT NULL', statements[0]
)
def test_adding_json(self):
- blueprint = Blueprint('users')
- blueprint.json('foo')
+ blueprint = Blueprint("users")
+ blueprint.json("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" JSON NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" JSON NOT NULL', statements[0]
)
def get_connection(self):
diff --git a/tests/schema/grammars/test_sqlite_grammar.py b/tests/schema/grammars/test_sqlite_grammar.py
index b646f55c..a5ed7cf0 100644
--- a/tests/schema/grammars/test_sqlite_grammar.py
+++ b/tests/schema/grammars/test_sqlite_grammar.py
@@ -8,37 +8,36 @@
class SqliteSchemaGrammarTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_basic_create(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.create()
- blueprint.increments('id')
- blueprint.string('email')
+ blueprint.increments("id")
+ blueprint.string("email")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'CREATE TABLE "users" ("id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT, "email" VARCHAR NOT NULL)',
- statements[0]
+ statements[0],
)
- blueprint = Blueprint('users')
- blueprint.increments('id')
- blueprint.string('email')
+ blueprint = Blueprint("users")
+ blueprint.increments("id")
+ blueprint.string("email")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(2, len(statements))
expected = [
'ALTER TABLE "users" ADD COLUMN "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT',
- 'ALTER TABLE "users" ADD COLUMN "email" VARCHAR NOT NULL'
+ 'ALTER TABLE "users" ADD COLUMN "email" VARCHAR NOT NULL',
]
self.assertEqual(expected, statements)
def test_drop_table(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.drop()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
@@ -46,7 +45,7 @@ def test_drop_table(self):
self.assertEqual('DROP TABLE "users"', statements[0])
def test_drop_table_if_exists(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.drop_if_exists()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
@@ -54,364 +53,335 @@ def test_drop_table_if_exists(self):
self.assertEqual('DROP TABLE IF EXISTS "users"', statements[0])
def test_drop_unique(self):
- blueprint = Blueprint('users')
- blueprint.drop_unique('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_unique("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('DROP INDEX foo', statements[0])
+ self.assertEqual("DROP INDEX foo", statements[0])
def test_drop_index(self):
- blueprint = Blueprint('users')
- blueprint.drop_index('foo')
+ blueprint = Blueprint("users")
+ blueprint.drop_index("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual('DROP INDEX foo', statements[0])
+ self.assertEqual("DROP INDEX foo", statements[0])
def test_rename_table(self):
- blueprint = Blueprint('users')
- blueprint.rename('foo')
+ blueprint = Blueprint("users")
+ blueprint.rename("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual('ALTER TABLE "users" RENAME TO "foo"', statements[0])
def test_adding_foreign_key(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.create()
- blueprint.string('foo').primary()
- blueprint.string('order_id')
- blueprint.foreign('order_id').references('id').on('orders')
+ blueprint.string("foo").primary()
+ blueprint.string("order_id")
+ blueprint.foreign("order_id").references("id").on("orders")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- expected = 'CREATE TABLE "users" ("foo" VARCHAR NOT NULL, "order_id" VARCHAR NOT NULL, ' \
- 'FOREIGN KEY("order_id") REFERENCES "orders"("id"), PRIMARY KEY ("foo"))'
+ expected = (
+ 'CREATE TABLE "users" ("foo" VARCHAR NOT NULL, "order_id" VARCHAR NOT NULL, '
+ 'FOREIGN KEY("order_id") REFERENCES "orders"("id"), PRIMARY KEY ("foo"))'
+ )
self.assertEqual(expected, statements[0])
def test_adding_unique_key(self):
- blueprint = Blueprint('users')
- blueprint.unique('foo', 'bar')
+ blueprint = Blueprint("users")
+ blueprint.unique("foo", "bar")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'CREATE UNIQUE INDEX bar ON "users" ("foo")',
- statements[0]
- )
+ self.assertEqual('CREATE UNIQUE INDEX bar ON "users" ("foo")', statements[0])
def test_adding_index(self):
- blueprint = Blueprint('users')
- blueprint.index('foo', 'bar')
+ blueprint = Blueprint("users")
+ blueprint.index("foo", "bar")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
- self.assertEqual(
- 'CREATE INDEX bar ON "users" ("foo")',
- statements[0]
- )
+ self.assertEqual('CREATE INDEX bar ON "users" ("foo")', statements[0])
def test_adding_incrementing_id(self):
- blueprint = Blueprint('users')
- blueprint.increments('id')
+ blueprint = Blueprint("users")
+ blueprint.increments("id")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT',
- statements[0]
+ statements[0],
)
def test_adding_big_incrementing_id(self):
- blueprint = Blueprint('users')
- blueprint.big_increments('id')
+ blueprint = Blueprint("users")
+ blueprint.big_increments("id")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "id" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT',
- statements[0]
+ statements[0],
)
def test_adding_string(self):
- blueprint = Blueprint('users')
- blueprint.string('foo')
+ blueprint = Blueprint("users")
+ blueprint.string("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', statements[0]
)
- blueprint = Blueprint('users')
- blueprint.string('foo', 100)
+ blueprint = Blueprint("users")
+ blueprint.string("foo", 100)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', statements[0]
)
- blueprint = Blueprint('users')
- blueprint.string('foo', 100).nullable().default('bar')
+ blueprint = Blueprint("users")
+ blueprint.string("foo", 100).nullable().default("bar")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NULL DEFAULT \'bar\'',
- statements[0]
+ statements[0],
)
def test_adding_text(self):
- blueprint = Blueprint('users')
- blueprint.text('foo')
+ blueprint = Blueprint("users")
+ blueprint.text("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL', statements[0]
)
def test_adding_big_integer(self):
- blueprint = Blueprint('users')
- blueprint.big_integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.big_integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0]
)
- blueprint = Blueprint('users')
- blueprint.big_integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.big_integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT',
- statements[0]
+ statements[0],
)
def test_adding_integer(self):
- blueprint = Blueprint('users')
- blueprint.integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0]
)
- blueprint = Blueprint('users')
- blueprint.integer('foo', True)
+ blueprint = Blueprint("users")
+ blueprint.integer("foo", True)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL PRIMARY KEY AUTOINCREMENT',
- statements[0]
+ statements[0],
)
def test_adding_medium_integer(self):
- blueprint = Blueprint('users')
- blueprint.integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0]
)
def test_adding_tiny_integer(self):
- blueprint = Blueprint('users')
- blueprint.integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0]
)
def test_adding_small_integer(self):
- blueprint = Blueprint('users')
- blueprint.integer('foo')
+ blueprint = Blueprint("users")
+ blueprint.integer("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" INTEGER NOT NULL', statements[0]
)
def test_adding_float(self):
- blueprint = Blueprint('users')
- blueprint.float('foo', 5, 2)
+ blueprint = Blueprint("users")
+ blueprint.float("foo", 5, 2)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" FLOAT NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" FLOAT NOT NULL', statements[0]
)
def test_adding_double(self):
- blueprint = Blueprint('users')
- blueprint.double('foo', 15, 8)
+ blueprint = Blueprint("users")
+ blueprint.double("foo", 15, 8)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" FLOAT NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" FLOAT NOT NULL', statements[0]
)
def test_adding_decimal(self):
- blueprint = Blueprint('users')
- blueprint.decimal('foo', 5, 2)
+ blueprint = Blueprint("users")
+ blueprint.decimal("foo", 5, 2)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" NUMERIC NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" NUMERIC NOT NULL', statements[0]
)
def test_adding_boolean(self):
- blueprint = Blueprint('users')
- blueprint.boolean('foo')
+ blueprint = Blueprint("users")
+ blueprint.boolean("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" TINYINT NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" TINYINT NOT NULL', statements[0]
)
def test_adding_enum(self):
- blueprint = Blueprint('users')
- blueprint.enum('foo', ['bar', 'baz'])
+ blueprint = Blueprint("users")
+ blueprint.enum("foo", ["bar", "baz"])
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" VARCHAR NOT NULL', statements[0]
)
def test_adding_date(self):
- blueprint = Blueprint('users')
- blueprint.date('foo')
+ blueprint = Blueprint("users")
+ blueprint.date("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" DATE NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" DATE NOT NULL', statements[0]
)
def test_adding_datetime(self):
- blueprint = Blueprint('users')
- blueprint.datetime('foo')
+ blueprint = Blueprint("users")
+ blueprint.datetime("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" DATETIME NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" DATETIME NOT NULL', statements[0]
)
def test_adding_time(self):
- blueprint = Blueprint('users')
- blueprint.time('foo')
+ blueprint = Blueprint("users")
+ blueprint.time("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" TIME NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" TIME NOT NULL', statements[0]
)
def test_adding_timestamp(self):
- blueprint = Blueprint('users')
- blueprint.timestamp('foo')
+ blueprint = Blueprint("users")
+ blueprint.timestamp("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" DATETIME NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" DATETIME NOT NULL', statements[0]
)
def test_adding_timestamp_with_current(self):
- blueprint = Blueprint('users')
- blueprint.timestamp('foo').use_current()
+ blueprint = Blueprint("users")
+ blueprint.timestamp("foo").use_current()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
'ALTER TABLE "users" ADD COLUMN "foo" DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL',
- statements[0]
+ statements[0],
)
def test_adding_timestamps(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.timestamps()
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(2, len(statements))
expected = [
'ALTER TABLE "users" ADD COLUMN "created_at" DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL',
- 'ALTER TABLE "users" ADD COLUMN "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL'
+ 'ALTER TABLE "users" ADD COLUMN "updated_at" DATETIME DEFAULT CURRENT_TIMESTAMP NOT NULL',
]
- self.assertEqual(
- expected,
- statements
- )
+ self.assertEqual(expected, statements)
def test_adding_timestamps_not_current(self):
- blueprint = Blueprint('users')
+ blueprint = Blueprint("users")
blueprint.timestamps(use_current=False)
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(2, len(statements))
expected = [
'ALTER TABLE "users" ADD COLUMN "created_at" DATETIME NOT NULL',
- 'ALTER TABLE "users" ADD COLUMN "updated_at" DATETIME NOT NULL'
+ 'ALTER TABLE "users" ADD COLUMN "updated_at" DATETIME NOT NULL',
]
- self.assertEqual(
- expected,
- statements
- )
+ self.assertEqual(expected, statements)
def test_adding_binary(self):
- blueprint = Blueprint('users')
- blueprint.binary('foo')
+ blueprint = Blueprint("users")
+ blueprint.binary("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" BLOB NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" BLOB NOT NULL', statements[0]
)
def test_adding_json(self):
- blueprint = Blueprint('users')
- blueprint.json('foo')
+ blueprint = Blueprint("users")
+ blueprint.json("foo")
statements = blueprint.to_sql(self.get_connection(), self.get_grammar())
self.assertEqual(1, len(statements))
self.assertEqual(
- 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL',
- statements[0]
+ 'ALTER TABLE "users" ADD COLUMN "foo" TEXT NOT NULL', statements[0]
)
def get_connection(self):
diff --git a/tests/schema/integrations/__init__.py b/tests/schema/integrations/__init__.py
index 3823997c..b8724bb2 100644
--- a/tests/schema/integrations/__init__.py
+++ b/tests/schema/integrations/__init__.py
@@ -1,13 +1,19 @@
# -*- coding: utf-8 -*-
from orator import Model
-from orator.orm import has_one, has_many, belongs_to, belongs_to_many, morph_to, morph_many
+from orator.orm import (
+ has_one,
+ has_many,
+ belongs_to,
+ belongs_to_many,
+ morph_to,
+ morph_many,
+)
from orator import QueryExpression
from orator.dbal.exceptions import ColumnDoesNotExist
class IntegrationTestCase(object):
-
@classmethod
def setUpClass(cls):
Model.set_connection_resolver(cls.get_connection_resolver())
@@ -22,173 +28,205 @@ def tearDownClass(cls):
def setUp(self):
with self.connection().transaction():
- self.schema().drop_if_exists('photos')
- self.schema().drop_if_exists('posts')
- self.schema().drop_if_exists('friends')
- self.schema().drop_if_exists('users')
-
- with self.schema().create('users') as table:
- table.increments('id')
- table.string('email').unique()
- table.integer('votes').default(0)
+ self.schema().drop_if_exists("photos")
+ self.schema().drop_if_exists("posts")
+ self.schema().drop_if_exists("friends")
+ self.schema().drop_if_exists("users")
+
+ with self.schema().create("users") as table:
+ table.increments("id")
+ table.string("email").unique()
+ table.integer("votes").default(0)
table.timestamps(use_current=True)
- with self.schema().create('friends') as table:
- table.integer('user_id', unsigned=True)
- table.integer('friend_id', unsigned=True)
-
- table.foreign('user_id').references('id').on('users').on_delete('cascade')
- table.foreign('friend_id').references('id').on('users').on_delete('cascade')
-
- with self.schema().create('posts') as table:
- table.increments('id')
- table.integer('user_id', unsigned=True)
- table.string('name').unique()
- table.enum('status', ['draft', 'published']).default('draft').nullable()
- table.string('default').default(0)
- table.string('tag').nullable().default('tag')
+ with self.schema().create("friends") as table:
+ table.integer("user_id", unsigned=True)
+ table.integer("friend_id", unsigned=True)
+
+ table.foreign("user_id").references("id").on("users").on_delete(
+ "cascade"
+ )
+ table.foreign("friend_id").references("id").on("users").on_delete(
+ "cascade"
+ )
+
+ with self.schema().create("posts") as table:
+ table.increments("id")
+ table.integer("user_id", unsigned=True)
+ table.string("name").unique()
+ table.enum("status", ["draft", "published"]).default("draft").nullable()
+ table.string("default").default(0)
+ table.string("tag").nullable().default("tag")
table.timestamps(use_current=True)
- table.foreign('user_id', 'users_foreign_key').references('id').on('users')
+ table.foreign("user_id", "users_foreign_key").references("id").on(
+ "users"
+ )
- with self.schema().create('photos') as table:
- table.increments('id')
- table.morphs('imageable')
- table.string('name')
+ with self.schema().create("photos") as table:
+ table.increments("id")
+ table.morphs("imageable")
+ table.string("name")
table.timestamps(use_current=True)
for i in range(10):
- user = User.create(email='user%d@foo.com' % (i + 1))
+ user = User.create(email="user%d@foo.com" % (i + 1))
for j in range(10):
- post = Post(name='User %d Post %d' % (user.id, j + 1))
+ post = Post(name="User %d Post %d" % (user.id, j + 1))
user.posts().save(post)
def tearDown(self):
- self.schema().drop('photos')
- self.schema().drop('posts')
- self.schema().drop('friends')
- self.schema().drop('users')
+ self.schema().drop("photos")
+ self.schema().drop("posts")
+ self.schema().drop("friends")
+ self.schema().drop("users")
def test_foreign_keys_creation(self):
- posts_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts')
- friends_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('friends')
+ posts_foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("posts")
+ )
+ friends_foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("friends")
+ )
- self.assertEqual('users_foreign_key', posts_foreign_keys[0].get_name())
+ self.assertEqual("users_foreign_key", posts_foreign_keys[0].get_name())
self.assertEqual(
- ['friends_friend_id_foreign', 'friends_user_id_foreign'],
- sorted([f.get_name() for f in friends_foreign_keys])
+ ["friends_friend_id_foreign", "friends_user_id_foreign"],
+ sorted([f.get_name() for f in friends_foreign_keys]),
)
def test_add_columns(self):
- with self.schema().table('posts') as table:
- table.text('content').nullable()
- table.integer('votes').default(QueryExpression(0))
+ with self.schema().table("posts") as table:
+ table.text("content").nullable()
+ table.integer("votes").default(QueryExpression(0))
user = User.find(1)
- post = user.posts().order_by('id', 'asc').first()
+ post = user.posts().order_by("id", "asc").first()
- self.assertEqual('User 1 Post 1', post.name)
+ self.assertEqual("User 1 Post 1", post.name)
self.assertEqual(0, post.votes)
def test_remove_columns(self):
- with self.schema().table('posts') as table:
- table.drop_column('name')
+ with self.schema().table("posts") as table:
+ table.drop_column("name")
- self.assertRaises(ColumnDoesNotExist, self.connection().get_column, 'posts', 'name')
+ self.assertRaises(
+ ColumnDoesNotExist, self.connection().get_column, "posts", "name"
+ )
user = User.find(1)
- post = user.posts().order_by('id', 'asc').first()
+ post = user.posts().order_by("id", "asc").first()
- self.assertFalse(hasattr(post, 'name'))
+ self.assertFalse(hasattr(post, "name"))
def test_rename_columns(self):
- old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts')
+ old_foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("posts")
+ )
- with self.schema().table('posts') as table:
- table.rename_column('name', 'title')
+ with self.schema().table("posts") as table:
+ table.rename_column("name", "title")
- self.assertRaises(ColumnDoesNotExist, self.connection().get_column, 'posts', 'name')
- self.assertIsNotNone(self.connection().get_column('posts', 'title'))
+ self.assertRaises(
+ ColumnDoesNotExist, self.connection().get_column, "posts", "name"
+ )
+ self.assertIsNotNone(self.connection().get_column("posts", "title"))
- foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts')
+ foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("posts")
+ )
self.assertEqual(len(foreign_keys), len(old_foreign_keys))
user = User.find(1)
- post = user.posts().order_by('id', 'asc').first()
+ post = user.posts().order_by("id", "asc").first()
- self.assertEqual('User 1 Post 1', post.title)
+ self.assertEqual("User 1 Post 1", post.title)
def test_rename_columns_with_index(self):
- indexes = self.connection().get_schema_manager().list_table_indexes('users')
- old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts')
+ indexes = self.connection().get_schema_manager().list_table_indexes("users")
+ old_foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("posts")
+ )
- index = indexes['users_email_unique']
- self.assertEqual(['email'], index.get_columns())
+ index = indexes["users_email_unique"]
+ self.assertEqual(["email"], index.get_columns())
self.assertTrue(index.is_unique())
- with self.schema().table('users') as table:
- table.rename_column('email', 'email_address')
+ with self.schema().table("users") as table:
+ table.rename_column("email", "email_address")
- self.assertRaises(ColumnDoesNotExist, self.connection().get_column, 'users', 'email')
- self.assertIsNotNone(self.connection().get_column('users', 'email_address'))
- foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts')
+ self.assertRaises(
+ ColumnDoesNotExist, self.connection().get_column, "users", "email"
+ )
+ self.assertIsNotNone(self.connection().get_column("users", "email_address"))
+ foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("posts")
+ )
self.assertEqual(len(foreign_keys), len(old_foreign_keys))
- indexes = self.connection().get_schema_manager().list_table_indexes('users')
+ indexes = self.connection().get_schema_manager().list_table_indexes("users")
- index = indexes['users_email_unique']
- self.assertEqual('users_email_unique', index.get_name())
- self.assertEqual(['email_address'], index.get_columns())
+ index = indexes["users_email_unique"]
+ self.assertEqual("users_email_unique", index.get_name())
+ self.assertEqual(["email_address"], index.get_columns())
self.assertTrue(index.is_unique())
def test_rename_columns_with_foreign_keys(self):
- old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts')
+ old_foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("posts")
+ )
old_foreign_key = old_foreign_keys[0]
- self.assertEqual(['user_id'], old_foreign_key.get_local_columns())
- self.assertEqual(['id'], old_foreign_key.get_foreign_columns())
- self.assertEqual('users', old_foreign_key.get_foreign_table_name())
+ self.assertEqual(["user_id"], old_foreign_key.get_local_columns())
+ self.assertEqual(["id"], old_foreign_key.get_foreign_columns())
+ self.assertEqual("users", old_foreign_key.get_foreign_table_name())
- with self.schema().table('posts') as table:
- table.rename_column('user_id', 'my_user_id')
+ with self.schema().table("posts") as table:
+ table.rename_column("user_id", "my_user_id")
- self.assertRaises(ColumnDoesNotExist, self.connection().get_column, 'posts', 'user_id')
- self.assertIsNotNone(self.connection().get_column('posts', 'my_user_id'))
- foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts')
+ self.assertRaises(
+ ColumnDoesNotExist, self.connection().get_column, "posts", "user_id"
+ )
+ self.assertIsNotNone(self.connection().get_column("posts", "my_user_id"))
+ foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("posts")
+ )
self.assertEqual(len(foreign_keys), len(old_foreign_keys))
foreign_key = foreign_keys[0]
- self.assertEqual(['my_user_id'], foreign_key.get_local_columns())
- self.assertEqual(['id'], foreign_key.get_foreign_columns())
- self.assertEqual('users', foreign_key.get_foreign_table_name())
+ self.assertEqual(["my_user_id"], foreign_key.get_local_columns())
+ self.assertEqual(["id"], foreign_key.get_foreign_columns())
+ self.assertEqual("users", foreign_key.get_foreign_table_name())
def test_change_columns(self):
- with self.schema().table('posts') as table:
- table.integer('votes').default(0)
+ with self.schema().table("posts") as table:
+ table.integer("votes").default(0)
- indexes = self.connection().get_schema_manager().list_table_indexes('posts')
- old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts')
+ indexes = self.connection().get_schema_manager().list_table_indexes("posts")
+ old_foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("posts")
+ )
- self.assertIn('posts_name_unique', indexes)
- self.assertEqual(['name'], indexes['posts_name_unique'].get_columns())
- self.assertTrue(indexes['posts_name_unique'].is_unique())
+ self.assertIn("posts_name_unique", indexes)
+ self.assertEqual(["name"], indexes["posts_name_unique"].get_columns())
+ self.assertTrue(indexes["posts_name_unique"].is_unique())
post = Post.find(1)
self.assertEqual(0, post.votes)
- with self.schema().table('posts') as table:
- table.string('name').nullable().change()
- table.string('votes').default('0').change()
- table.string('tag').default('new').change()
+ with self.schema().table("posts") as table:
+ table.string("name").nullable().change()
+ table.string("votes").default("0").change()
+ table.string("tag").default("new").change()
- name_column = self.connection().get_column('posts', 'name')
- votes_column = self.connection().get_column('posts', 'votes')
- status_column = self.connection().get_column('posts', 'status')
- tag_column = self.connection().get_column('posts', 'tag')
+ name_column = self.connection().get_column("posts", "name")
+ votes_column = self.connection().get_column("posts", "votes")
+ status_column = self.connection().get_column("posts", "status")
+ tag_column = self.connection().get_column("posts", "tag")
self.assertFalse(name_column.get_notnull())
self.assertTrue(votes_column.get_notnull())
self.assertEqual("0", votes_column.get_default())
@@ -199,44 +237,50 @@ def test_change_columns(self):
self.assertFalse(tag_column.get_notnull())
self.assertEqual("new", tag_column.get_default())
- foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('posts')
+ foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("posts")
+ )
self.assertEqual(len(foreign_keys), len(old_foreign_keys))
- indexes = self.connection().get_schema_manager().list_table_indexes('posts')
+ indexes = self.connection().get_schema_manager().list_table_indexes("posts")
- self.assertIn('posts_name_unique', indexes)
- self.assertEqual(['name'], indexes['posts_name_unique'].get_columns())
- self.assertTrue(indexes['posts_name_unique'].is_unique())
+ self.assertIn("posts_name_unique", indexes)
+ self.assertEqual(["name"], indexes["posts_name_unique"].get_columns())
+ self.assertTrue(indexes["posts_name_unique"].is_unique())
post = Post.find(1)
- self.assertEqual('0', post.votes)
+ self.assertEqual("0", post.votes)
- with self.schema().table('users') as table:
- table.big_integer('votes').change()
+ with self.schema().table("users") as table:
+ table.big_integer("votes").change()
def test_cascading(self):
- user = User.create(email='john@doe.com')
- friend = User.create(email='jane@doe.com')
- another_friend = User.create(email='another@doe.com')
+ user = User.create(email="john@doe.com")
+ friend = User.create(email="jane@doe.com")
+ another_friend = User.create(email="another@doe.com")
user.friends().attach(friend)
user.friends().attach(another_friend)
user.delete()
- self.assertEqual(0, user.get_connection_resolver().connection().table('friends').count())
+ self.assertEqual(
+ 0, user.get_connection_resolver().connection().table("friends").count()
+ )
# Altering users table
- with self.schema().table('users') as table:
- table.string('email', 50).change()
+ with self.schema().table("users") as table:
+ table.string("email", 50).change()
- user = User.create(email='john@doe.com')
+ user = User.create(email="john@doe.com")
user.friends().attach(friend)
user.friends().attach(another_friend)
user.delete()
- self.assertEqual(0, user.get_connection_resolver().connection().table('friends').count())
+ self.assertEqual(
+ 0, user.get_connection_resolver().connection().table("friends").count()
+ )
def grammar(self):
return self.connection().get_default_query_grammar()
@@ -252,19 +296,19 @@ class User(Model):
__guarded__ = []
- @belongs_to_many('friends', 'user_id', 'friend_id')
+ @belongs_to_many("friends", "user_id", "friend_id")
def friends(self):
return User
- @has_many('user_id')
+ @has_many("user_id")
def posts(self):
return Post
- @has_one('user_id')
+ @has_one("user_id")
def post(self):
return Post
- @morph_many('imageable')
+ @morph_many("imageable")
def photos(self):
return Photo
@@ -273,11 +317,11 @@ class Post(Model):
__guarded__ = []
- @belongs_to('user_id')
+ @belongs_to("user_id")
def user(self):
return User
- @morph_many('imageable')
+ @morph_many("imageable")
def photos(self):
return Photo
diff --git a/tests/schema/integrations/test_mysql.py b/tests/schema/integrations/test_mysql.py
index 1e1c93e1..d1e1b7f8 100644
--- a/tests/schema/integrations/test_mysql.py
+++ b/tests/schema/integrations/test_mysql.py
@@ -8,7 +8,6 @@
class SchemaBuilderMySQLIntegrationTestCase(IntegrationTestCase, OratorTestCase):
-
@classmethod
def get_connection_resolver(cls):
return DatabaseIntegrationConnectionResolver()
@@ -22,28 +21,26 @@ def connection(self, name=None):
if self._connection:
return self._connection
- ci = os.environ.get('CI', False)
+ ci = os.environ.get("CI", False)
if ci:
- database = 'orator_test'
- user = 'root'
- password = ''
+ database = "orator_test"
+ user = "root"
+ password = ""
else:
- database = 'orator_test'
- user = 'orator'
- password = 'orator'
+ database = "orator_test"
+ user = "orator"
+ password = "orator"
self._connection = MySQLConnection(
- MySQLConnector().connect({
- 'database': database,
- 'user': user,
- 'password': password
- })
+ MySQLConnector().connect(
+ {"database": database, "user": user, "password": password}
+ )
)
return self._connection
def get_default_connection(self):
- return 'default'
+ return "default"
def set_default_connection(self, name):
pass
diff --git a/tests/schema/integrations/test_postgres.py b/tests/schema/integrations/test_postgres.py
index a54688dc..fe141e35 100644
--- a/tests/schema/integrations/test_postgres.py
+++ b/tests/schema/integrations/test_postgres.py
@@ -8,7 +8,6 @@
class SchemaBuilderPostgresIntegrationTestCase(IntegrationTestCase, OratorTestCase):
-
@classmethod
def get_connection_resolver(cls):
return DatabaseIntegrationConnectionResolver()
@@ -22,28 +21,26 @@ def connection(self, name=None):
if self._connection:
return self._connection
- ci = os.environ.get('CI', False)
+ ci = os.environ.get("CI", False)
if ci:
- database = 'orator_test'
- user = 'postgres'
+ database = "orator_test"
+ user = "postgres"
password = None
else:
- database = 'orator_test'
- user = 'orator'
- password = 'orator'
+ database = "orator_test"
+ user = "orator"
+ password = "orator"
self._connection = PostgresConnection(
- PostgresConnector().connect({
- 'database': database,
- 'user': user,
- 'password': password
- })
+ PostgresConnector().connect(
+ {"database": database, "user": user, "password": password}
+ )
)
return self._connection
def get_default_connection(self):
- return 'default'
+ return "default"
def set_default_connection(self, name):
pass
diff --git a/tests/schema/integrations/test_sqlite.py b/tests/schema/integrations/test_sqlite.py
index 17ffab80..5bb2e6c8 100644
--- a/tests/schema/integrations/test_sqlite.py
+++ b/tests/schema/integrations/test_sqlite.py
@@ -8,7 +8,6 @@
class SchemaBuilderSQLiteIntegrationTestCase(IntegrationTestCase, OratorTestCase):
-
@classmethod
def get_connection_resolver(cls):
return DatabaseIntegrationConnectionResolver()
@@ -17,78 +16,87 @@ def test_foreign_keys_creation(self):
pass
def test_rename_columns_with_foreign_keys(self):
- super(SchemaBuilderSQLiteIntegrationTestCase, self).test_rename_columns_with_foreign_keys()
+ super(
+ SchemaBuilderSQLiteIntegrationTestCase, self
+ ).test_rename_columns_with_foreign_keys()
- old_foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('friends')
+ old_foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("friends")
+ )
- with self.schema().table('friends') as table:
- table.rename_column('user_id', 'my_user_id')
+ with self.schema().table("friends") as table:
+ table.rename_column("user_id", "my_user_id")
- foreign_keys = self.connection().get_schema_manager().list_table_foreign_keys('friends')
+ foreign_keys = (
+ self.connection().get_schema_manager().list_table_foreign_keys("friends")
+ )
self.assertEqual(len(old_foreign_keys), len(foreign_keys))
class SchemaBuilderSQLiteIntegrationCascadingTestCase(OratorTestCase):
-
@classmethod
def setUpClass(cls):
- Model.set_connection_resolver(DatabaseIntegrationConnectionWithoutForeignKeysResolver())
+ Model.set_connection_resolver(
+ DatabaseIntegrationConnectionWithoutForeignKeysResolver()
+ )
@classmethod
def tearDownClass(cls):
Model.unset_connection_resolver()
def setUp(self):
- with self.schema().create('users') as table:
- table.increments('id')
- table.string('email').unique()
+ with self.schema().create("users") as table:
+ table.increments("id")
+ table.string("email").unique()
table.timestamps()
- with self.schema().create('friends') as table:
- table.integer('user_id')
- table.integer('friend_id')
+ with self.schema().create("friends") as table:
+ table.integer("user_id")
+ table.integer("friend_id")
- table.foreign('user_id').references('id').on('users').on_delete('cascade')
- table.foreign('friend_id').references('id').on('users').on_delete('cascade')
+ table.foreign("user_id").references("id").on("users").on_delete("cascade")
+ table.foreign("friend_id").references("id").on("users").on_delete("cascade")
- with self.schema().create('posts') as table:
- table.increments('id')
- table.integer('user_id')
- table.string('name').unique()
+ with self.schema().create("posts") as table:
+ table.increments("id")
+ table.integer("user_id")
+ table.string("name").unique()
table.timestamps()
- table.foreign('user_id').references('id').on('users')
+ table.foreign("user_id").references("id").on("users")
- with self.schema().create('photos') as table:
- table.increments('id')
- table.morphs('imageable')
- table.string('name')
+ with self.schema().create("photos") as table:
+ table.increments("id")
+ table.morphs("imageable")
+ table.string("name")
table.timestamps()
for i in range(10):
- user = User.create(email='user%d@foo.com' % (i + 1))
+ user = User.create(email="user%d@foo.com" % (i + 1))
for j in range(10):
- post = Post(name='User %d Post %d' % (user.id, j + 1))
+ post = Post(name="User %d Post %d" % (user.id, j + 1))
user.posts().save(post)
def tearDown(self):
- self.schema().drop('photos')
- self.schema().drop('posts')
- self.schema().drop('friends')
- self.schema().drop('users')
+ self.schema().drop("photos")
+ self.schema().drop("posts")
+ self.schema().drop("friends")
+ self.schema().drop("users")
def test_cascading(self):
- user = User.create(email='john@doe.com')
- friend = User.create(email='jane@doe.com')
- another_friend = User.create(email='another@doe.com')
+ user = User.create(email="john@doe.com")
+ friend = User.create(email="jane@doe.com")
+ another_friend = User.create(email="another@doe.com")
user.friends().attach(friend)
user.friends().attach(another_friend)
user.delete()
- self.assertEqual(2, user.get_connection_resolver().connection().table('friends').count())
+ self.assertEqual(
+ 2, user.get_connection_resolver().connection().table("friends").count()
+ )
def connection(self):
return Model.get_connection_resolver().connection()
@@ -105,18 +113,22 @@ def connection(self, name=None):
if self._connection:
return self._connection
- self._connection = SQLiteConnection(SQLiteConnector().connect({'database': ':memory:'}))
+ self._connection = SQLiteConnection(
+ SQLiteConnector().connect({"database": ":memory:"})
+ )
return self._connection
def get_default_connection(self):
- return 'default'
+ return "default"
def set_default_connection(self, name):
pass
-class DatabaseIntegrationConnectionWithoutForeignKeysResolver(DatabaseIntegrationConnectionResolver):
+class DatabaseIntegrationConnectionWithoutForeignKeysResolver(
+ DatabaseIntegrationConnectionResolver
+):
_connection = None
@@ -125,10 +137,7 @@ def connection(self, name=None):
return self._connection
self._connection = SQLiteConnection(
- SQLiteConnector().connect(
- {'database': ':memory:',
- 'foreign_keys': False}
- )
+ SQLiteConnector().connect({"database": ":memory:", "foreign_keys": False})
)
return self._connection
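
Note on the two cascading tests in this file: SQLite only enforces ON DELETE CASCADE when the foreign_keys pragma is enabled on the connection. That is why the resolver above that passes "foreign_keys": False expects the friends rows to survive the user delete (count 2), while the inherited test_cascading (count 0) relies on enforcement being on. A minimal sketch of the underlying SQLite behaviour, using the standard sqlite3 module rather than Orator's connector; the users/friends schema below simply mirrors the test tables and is illustrative only:

    import sqlite3

    conn = sqlite3.connect(":memory:")
    # Enforcement is per-connection and off by default in SQLite.
    conn.execute("PRAGMA foreign_keys = ON")
    conn.executescript(
        """
        CREATE TABLE users (id INTEGER PRIMARY KEY);
        CREATE TABLE friends (
            user_id INTEGER REFERENCES users (id) ON DELETE CASCADE
        );
        INSERT INTO users (id) VALUES (1);
        INSERT INTO friends (user_id) VALUES (1);
        """
    )
    conn.execute("DELETE FROM users WHERE id = 1")
    # Prints 0 with the pragma ON; with it OFF the friends row would survive.
    print(conn.execute("SELECT COUNT(*) FROM friends").fetchone()[0])
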
diff --git a/tests/schema/test_blueprint.py b/tests/schema/test_blueprint.py
index 3eb6c737..94af46c0 100644
--- a/tests/schema/test_blueprint.py
+++ b/tests/schema/test_blueprint.py
@@ -8,38 +8,39 @@
class SchemaBuilderTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_to_sql_runs_commands_from_blueprint(self):
conn = flexmock(Connection(None))
- conn.should_receive('statement').once().with_args('foo')
- conn.should_receive('statement').once().with_args('bar')
+ conn.should_receive("statement").once().with_args("foo")
+ conn.should_receive("statement").once().with_args("bar")
grammar = flexmock(SchemaGrammar(conn))
- blueprint = flexmock(Blueprint('table'))
- blueprint.should_receive('to_sql').once().with_args(conn, grammar).and_return(['foo', 'bar'])
+ blueprint = flexmock(Blueprint("table"))
+ blueprint.should_receive("to_sql").once().with_args(conn, grammar).and_return(
+ ["foo", "bar"]
+ )
blueprint.build(conn, grammar)
def test_index_default_names(self):
- blueprint = Blueprint('users')
- blueprint.unique(['foo', 'bar'])
+ blueprint = Blueprint("users")
+ blueprint.unique(["foo", "bar"])
commands = blueprint.get_commands()
- self.assertEqual('users_foo_bar_unique', commands[0].index)
+ self.assertEqual("users_foo_bar_unique", commands[0].index)
- blueprint = Blueprint('users')
- blueprint.index('foo')
+ blueprint = Blueprint("users")
+ blueprint.index("foo")
commands = blueprint.get_commands()
- self.assertEqual('users_foo_index', commands[0].index)
+ self.assertEqual("users_foo_index", commands[0].index)
def test_drop_index_default_names(self):
- blueprint = Blueprint('users')
- blueprint.drop_unique(['foo', 'bar'])
+ blueprint = Blueprint("users")
+ blueprint.drop_unique(["foo", "bar"])
commands = blueprint.get_commands()
- self.assertEqual('users_foo_bar_unique', commands[0].index)
+ self.assertEqual("users_foo_bar_unique", commands[0].index)
- blueprint = Blueprint('users')
- blueprint.drop_index(['foo'])
+ blueprint = Blueprint("users")
+ blueprint.drop_index(["foo"])
commands = blueprint.get_commands()
- self.assertEqual('users_foo_index', commands[0].index)
+ self.assertEqual("users_foo_index", commands[0].index)
diff --git a/tests/schema/test_builder.py b/tests/schema/test_builder.py
index bf88f2cb..cd3e7c62 100644
--- a/tests/schema/test_builder.py
+++ b/tests/schema/test_builder.py
@@ -7,17 +7,18 @@
class SchemaBuilderTestCase(OratorTestCase):
-
def tearDown(self):
flexmock_teardown()
def test_has_table_correctly_calls_grammar(self):
connection = flexmock(Connection(None))
grammar = flexmock()
- connection.should_receive('get_schema_grammar').and_return(grammar)
+ connection.should_receive("get_schema_grammar").and_return(grammar)
builder = SchemaBuilder(connection)
- grammar.should_receive('compile_table_exists').once().and_return('sql')
- connection.should_receive('get_table_prefix').once().and_return('prefix_')
- connection.should_receive('select').once().with_args('sql', ['prefix_table']).and_return(['prefix_table'])
+ grammar.should_receive("compile_table_exists").once().and_return("sql")
+ connection.should_receive("get_table_prefix").once().and_return("prefix_")
+ connection.should_receive("select").once().with_args(
+ "sql", ["prefix_table"]
+ ).and_return(["prefix_table"])
- self.assertTrue(builder.has_table('table'))
+ self.assertTrue(builder.has_table("table"))
diff --git a/tests/seeds/__init__.py b/tests/seeds/__init__.py
index 633f8661..40a96afc 100644
--- a/tests/seeds/__init__.py
+++ b/tests/seeds/__init__.py
@@ -1,2 +1 @@
# -*- coding: utf-8 -*-
-
diff --git a/tests/seeds/test_seeder.py b/tests/seeds/test_seeder.py
index 1a9a6bbd..fdc7ba67 100644
--- a/tests/seeds/test_seeder.py
+++ b/tests/seeds/test_seeder.py
@@ -10,33 +10,32 @@
class SeederTestCase(OratorTestCase):
-
def tearDown(self):
super(SeederTestCase, self).tearDown()
flexmock_teardown()
def test_call_resolve_class_and_calls_run(self):
resolver_mock = flexmock(DatabaseManager)
- resolver_mock.should_receive('connection').and_return({})
+ resolver_mock.should_receive("connection").and_return({})
resolver = flexmock(DatabaseManager({}))
connection = flexmock(Connection(None))
- resolver.should_receive('connection').with_args(None).and_return(connection)
+ resolver.should_receive("connection").with_args(None).and_return(connection)
seeder = Seeder(resolver)
- command = flexmock(Command('foo'))
- command.should_receive('line').once()
+ command = flexmock(Command("foo"))
+ command.should_receive("line").once()
seeder.set_command(command)
child = flexmock()
- child.__name__ = 'foo'
- child.should_receive('set_command').once().with_args(command)
- child.should_receive('set_connection_resolver').once().with_args(resolver)
- child.should_receive('run').once()
+ child.__name__ = "foo"
+ child.should_receive("set_command").once().with_args(command)
+ child.should_receive("set_connection_resolver").once().with_args(resolver)
+ child.should_receive("run").once()
seeder.call(child)
class Command(BaseCommand):
- resolver = 'bar'
+ resolver = "bar"
def get_output(self):
- return 'foo'
+ return "foo"
diff --git a/tests/support/test_collection.py b/tests/support/test_collection.py
index 171c2ac7..376cda08 100644
--- a/tests/support/test_collection.py
+++ b/tests/support/test_collection.py
@@ -5,33 +5,32 @@
class CollectionTestCase(OratorTestCase):
-
def test_first_returns_first_item_in_collection(self):
- c = Collection(['foo', 'bar'])
+ c = Collection(["foo", "bar"])
- self.assertEqual('foo', c.first())
+ self.assertEqual("foo", c.first())
def test_last_returns_last_item_in_collection(self):
- c = Collection(['foo', 'bar'])
+ c = Collection(["foo", "bar"])
- self.assertEqual('bar', c.last())
+ self.assertEqual("bar", c.last())
def test_pop_removes_and_returns_last_item_or_specified_index(self):
- c = Collection(['foo', 'bar'])
+ c = Collection(["foo", "bar"])
- self.assertEqual('bar', c.pop())
- self.assertEqual('foo', c.last())
+ self.assertEqual("bar", c.pop())
+ self.assertEqual("foo", c.last())
- c = Collection(['foo', 'bar'])
+ c = Collection(["foo", "bar"])
- self.assertEqual('foo', c.pop(0))
- self.assertEqual('bar', c.first())
+ self.assertEqual("foo", c.pop(0))
+ self.assertEqual("bar", c.first())
def test_shift_removes_and_returns_first_item(self):
- c = Collection(['foo', 'bar'])
+ c = Collection(["foo", "bar"])
- self.assertEqual('foo', c.shift())
- self.assertEqual('bar', c.first())
+ self.assertEqual("foo", c.shift())
+ self.assertEqual("bar", c.first())
def test_empty_collection_is_empty(self):
c = Collection()
@@ -41,8 +40,8 @@ def test_empty_collection_is_empty(self):
self.assertTrue(c2.is_empty())
def test_collection_is_constructed(self):
- c = Collection('foo')
- self.assertEqual(['foo'], c.all())
+ c = Collection("foo")
+ self.assertEqual(["foo"], c.all())
c = Collection(2)
self.assertEqual([2], c.all())
@@ -57,25 +56,25 @@ def test_collection_is_constructed(self):
self.assertEqual([], c.all())
def test_offset_access(self):
- c = Collection(['foo', 'bar'])
- self.assertEqual('bar', c[1])
+ c = Collection(["foo", "bar"])
+ self.assertEqual("bar", c[1])
- c[1] = 'baz'
- self.assertEqual('baz', c[1])
+ c[1] = "baz"
+ self.assertEqual("baz", c[1])
del c[0]
- self.assertEqual('baz', c[0])
+ self.assertEqual("baz", c[0])
def test_forget(self):
- c = Collection(['foo', 'bar', 'boom'])
+ c = Collection(["foo", "bar", "boom"])
c.forget(0)
- self.assertEqual('bar', c[0])
+ self.assertEqual("bar", c[0])
c.forget(0, 1)
self.assertTrue(c.is_empty())
def test_get_avg_items_from_collection(self):
- c = Collection([{'foo': 10}, {'foo': 20}])
- self.assertEqual(15, c.avg('foo'))
+ c = Collection([{"foo": 10}, {"foo": 20}])
+ self.assertEqual(15, c.avg("foo"))
c = Collection([1, 2, 3, 4, 5])
self.assertEqual(3, c.avg())
@@ -103,31 +102,31 @@ def test_contains(self):
self.assertFalse(c.contains(lambda x: x > 5))
self.assertIn(3, c)
- c = Collection([{'v': 1}, {'v': 3}, {'v': 5}])
- self.assertTrue(c.contains('v', 1))
- self.assertFalse(c.contains('v', 2))
+ c = Collection([{"v": 1}, {"v": 3}, {"v": 5}])
+ self.assertTrue(c.contains("v", 1))
+ self.assertFalse(c.contains("v", 2))
- obj1 = type('lamdbaobject', (object,), {})()
+ obj1 = type("lamdbaobject", (object,), {})()
obj1.v = 1
- obj2 = type('lamdbaobject', (object,), {})()
+ obj2 = type("lamdbaobject", (object,), {})()
obj2.v = 3
- obj3 = type('lamdbaobject', (object,), {})()
+ obj3 = type("lamdbaobject", (object,), {})()
obj3.v = 5
- c = Collection([{'v': 1}, {'v': 3}, {'v': 5}])
- self.assertTrue(c.contains('v', 1))
- self.assertFalse(c.contains('v', 2))
+ c = Collection([{"v": 1}, {"v": 3}, {"v": 5}])
+ self.assertTrue(c.contains("v", 1))
+ self.assertFalse(c.contains("v", 2))
def test_countable(self):
- c = Collection(['foo', 'bar'])
+ c = Collection(["foo", "bar"])
self.assertEqual(2, c.count())
self.assertEqual(2, len(c))
def test_diff(self):
- c = Collection(['foo', 'bar'])
- self.assertEqual(['foo'], c.diff(Collection(['bar', 'baz'])).all())
+ c = Collection(["foo", "bar"])
+ self.assertEqual(["foo"], c.diff(Collection(["bar", "baz"])).all())
def test_each(self):
- original = ['foo', 'bar', 'baz']
+ original = ["foo", "bar", "baz"]
c = Collection(original)
result = []
@@ -141,35 +140,39 @@ def test_every(self):
self.assertEqual([2, 4, 6], c.every(2, 1).all())
def test_filter(self):
- c = Collection([{'id': 1, 'name': 'hello'}, {'id': 2, 'name': 'world'}])
- self.assertEqual([{'id': 2, 'name': 'world'}], c.filter(lambda item: item['id'] == 2).all())
+ c = Collection([{"id": 1, "name": "hello"}, {"id": 2, "name": "world"}])
+ self.assertEqual(
+ [{"id": 2, "name": "world"}], c.filter(lambda item: item["id"] == 2).all()
+ )
- c = Collection(['', 'hello', '', 'world'])
- self.assertEqual(['hello', 'world'], c.filter().all())
+ c = Collection(["", "hello", "", "world"])
+ self.assertEqual(["hello", "world"], c.filter().all())
def test_where(self):
- c = Collection([{'v': 1}, {'v': 3}, {'v': 2}, {'v': 3}, {'v': 4}])
- self.assertEqual([{'v': 3}, {'v': 3}], c.where('v', 3).all())
+ c = Collection([{"v": 1}, {"v": 3}, {"v": 2}, {"v": 3}, {"v": 4}])
+ self.assertEqual([{"v": 3}, {"v": 3}], c.where("v", 3).all())
def test_implode(self):
- obj1 = type('lamdbaobject', (object,), {})()
- obj1.name = 'john'
- obj1.email = 'foo'
- c = Collection([{'name': 'john', 'email': 'foo'}, {'name': 'jane', 'email': 'bar'}])
- self.assertEqual('foobar', c.implode('email'))
- self.assertEqual('foo,bar', c.implode('email', ','))
+ obj1 = type("lamdbaobject", (object,), {})()
+ obj1.name = "john"
+ obj1.email = "foo"
+ c = Collection(
+ [{"name": "john", "email": "foo"}, {"name": "jane", "email": "bar"}]
+ )
+ self.assertEqual("foobar", c.implode("email"))
+ self.assertEqual("foo,bar", c.implode("email", ","))
- c = Collection(['foo', 'bar'])
- self.assertEqual('foobar', c.implode(''))
- self.assertEqual('foo,bar', c.implode(','))
+ c = Collection(["foo", "bar"])
+ self.assertEqual("foobar", c.implode(""))
+ self.assertEqual("foo,bar", c.implode(","))
def test_lists(self):
- obj1 = type('lamdbaobject', (object,), {})()
- obj1.name = 'john'
- obj1.email = 'foo'
- c = Collection([obj1, {'name': 'jane', 'email': 'bar'}])
- self.assertEqual({'john': 'foo', 'jane': 'bar'}, c.lists('email', 'name'))
- self.assertEqual(['foo', 'bar'], c.pluck('email').all())
+ obj1 = type("lamdbaobject", (object,), {})()
+ obj1.name = "john"
+ obj1.email = "foo"
+ c = Collection([obj1, {"name": "jane", "email": "bar"}])
+ self.assertEqual({"john": "foo", "jane": "bar"}, c.lists("email", "name"))
+ self.assertEqual(["foo", "bar"], c.pluck("email").all())
def test_map(self):
c = Collection([1, 2, 3, 4, 5])
@@ -247,12 +250,9 @@ def test_without(self):
self.assertEqual([1, 2, 3, 4, 5], c.all())
def test_flatten(self):
- c = Collection({'foo': [5, 6], 'bar': 7, 'baz': {'boom': [1, 2, 3, 4]}})
+ c = Collection({"foo": [5, 6], "bar": 7, "baz": {"boom": [1, 2, 3, 4]}})
- self.assertEqual(
- [1, 2, 3, 4, 5, 6, 7],
- c.flatten().sort().all()
- )
+ self.assertEqual([1, 2, 3, 4, 5, 6, 7], c.flatten().sort().all())
c = Collection([1, [2, 3], 4])
self.assertEqual([1, 2, 3, 4], c.flatten().all())
diff --git a/tests/support/test_fluent.py b/tests/support/test_fluent.py
index cf111cc6..4ae07f4d 100644
--- a/tests/support/test_fluent.py
+++ b/tests/support/test_fluent.py
@@ -5,33 +5,34 @@
class FluentTestCase(OratorTestCase):
-
def test_get_method_return_attributes(self):
- fluent = Fluent(name='john')
+ fluent = Fluent(name="john")
- self.assertEqual('john', fluent.get('name'))
- self.assertEqual('default', fluent.get('foo', 'default'))
- self.assertEqual('john', fluent.name)
+ self.assertEqual("john", fluent.get("name"))
+ self.assertEqual("default", fluent.get("foo", "default"))
+ self.assertEqual("john", fluent.name)
self.assertEqual(None, fluent.foo)
def test_set_attributes(self):
fluent = Fluent()
- fluent.name = 'john'
+ fluent.name = "john"
fluent.developer()
fluent.age(25)
- self.assertEqual('john', fluent.name)
+ self.assertEqual("john", fluent.name)
self.assertTrue(fluent.developer)
self.assertEqual(25, fluent.age)
- self.assertEqual({'name': 'john', 'developer': True, 'age': 25}, fluent.get_attributes())
+ self.assertEqual(
+ {"name": "john", "developer": True, "age": 25}, fluent.get_attributes()
+ )
def test_chained_attributes(self):
fluent = Fluent()
fluent.unsigned = False
- fluent.integer('status').unsigned()
+ fluent.integer("status").unsigned()
- self.assertEqual('status', fluent.integer)
+ self.assertEqual("status", fluent.integer)
self.assertTrue(fluent.unsigned)
diff --git a/tests/test_database_manager.py b/tests/test_database_manager.py
index b4bffc6a..1d783033 100644
--- a/tests/test_database_manager.py
+++ b/tests/test_database_manager.py
@@ -8,17 +8,14 @@
class ConnectionTestCase(OratorTestCase):
-
def test_connection_method_create_a_new_connection_if_needed(self):
manager = self._get_manager()
- manager.table('users')
+ manager.table("users")
- manager._make_connection.assert_called_once_with(
- 'sqlite'
- )
+ manager._make_connection.assert_called_once_with("sqlite")
manager._make_connection.reset_mock()
- manager.table('users')
+ manager.table("users")
self.assertFalse(manager._make_connection.called)
def test_manager_uses_factory_to_create_connections(self):
@@ -28,37 +25,30 @@ def test_manager_uses_factory_to_create_connections(self):
manager.connection()
manager._factory.make.assert_called_with(
- {
- 'name': 'sqlite',
- 'driver': 'sqlite',
- 'database': ':memory:'
- }, 'sqlite'
+ {"name": "sqlite", "driver": "sqlite", "database": ":memory:"}, "sqlite"
)
manager._factory.make = original_make
def test_connection_can_select_connections(self):
manager = self._get_manager()
- self.assertEqual(manager.connection(), manager.connection('sqlite'))
- self.assertNotEqual(manager.connection('sqlite'), manager.connection('sqlite2'))
+ self.assertEqual(manager.connection(), manager.connection("sqlite"))
+ self.assertNotEqual(manager.connection("sqlite"), manager.connection("sqlite2"))
def test_dynamic_attribute_gets_connection_attribute(self):
manager = self._get_manager()
- manager.statement('CREATE TABLE users')
+ manager.statement("CREATE TABLE users")
- manager.get_connections()['sqlite'].statement.assert_called_once_with(
- 'CREATE TABLE users'
+ manager.get_connections()["sqlite"].statement.assert_called_once_with(
+ "CREATE TABLE users"
)
def test_default_database_with_one_database(self):
- manager = MockManager({
- 'sqlite': {
- 'driver': 'sqlite',
- 'database': ':memory:'
- }
- }).prepare_mock()
+ manager = MockManager(
+ {"sqlite": {"driver": "sqlite", "database": ":memory:"}}
+ ).prepare_mock()
- self.assertEqual('sqlite', manager.get_default_connection())
+ self.assertEqual("sqlite", manager.get_default_connection())
def test_reconnect(self):
manager = self._get_real_manager()
@@ -73,31 +63,26 @@ def test_set_default_connection_with_none_should_not_overwrite(self):
manager.set_default_connection(None)
- self.assertEqual('sqlite', manager.get_default_connection())
+ self.assertEqual("sqlite", manager.get_default_connection())
def _get_manager(self):
- manager = MockManager({
- 'default': 'sqlite',
- 'sqlite': {
- 'driver': 'sqlite',
- 'database': ':memory:'
- },
- 'sqlite2': {
- 'driver': 'sqlite',
- 'database': ':memory:'
+ manager = MockManager(
+ {
+ "default": "sqlite",
+ "sqlite": {"driver": "sqlite", "database": ":memory:"},
+ "sqlite2": {"driver": "sqlite", "database": ":memory:"},
}
- }).prepare_mock()
+ ).prepare_mock()
return manager
def _get_real_manager(self):
- manager = DatabaseManager({
- 'default': 'sqlite',
- 'sqlite': {
- 'driver': 'sqlite',
- 'database': ':memory:'
+ manager = DatabaseManager(
+ {
+ "default": "sqlite",
+ "sqlite": {"driver": "sqlite", "database": ":memory:"},
}
- })
+ )
return manager
diff --git a/tests/utils.py b/tests/utils.py
index 8f3e09c8..91b0c400 100644
--- a/tests/utils.py
+++ b/tests/utils.py
@@ -10,7 +10,6 @@
class MockConnection(ConnectionInterface):
-
def __init__(self, name=None):
if name:
self.get_name = lambda: name
@@ -25,12 +24,12 @@ def prepare_mock(self, name=None):
self.update = mock.MagicMock()
self.delete = mock.MagicMock()
self.statement = mock.MagicMock()
+ self.select_many = mock.MagicMock()
return self
class MockProcessor(QueryProcessor):
-
def prepare_mock(self):
self.process_select = mock.MagicMock()
self.process_insert_get_id = mock.MagicMock()
@@ -39,7 +38,6 @@ def prepare_mock(self):
class MockManager(DatabaseManager):
-
def prepare_mock(self):
self._make_connection = mock.MagicMock(
side_effect=lambda name: MockConnection(name).prepare_mock()
@@ -49,7 +47,6 @@ def prepare_mock(self):
class MockFactory(ConnectionFactory):
-
def prepare_mock(self):
self.make = mock.MagicMock(return_value=MockConnection().prepare_mock())
@@ -57,18 +54,16 @@ def prepare_mock(self):
class MockQueryBuilder(QueryBuilder):
-
def prepare_mock(self):
- self.from__ = 'foo_table'
+ self.from__ = "foo_table"
return self
class MockModel(Model):
-
def prepare_mock(self):
- self.get_key_name = mock.MagicMock(return_value='foo')
- self.get_table = mock.MagicMock(return_value='foo_table')
- self.get_qualified_key_name = mock.MagicMock(return_value='foo_table.foo')
+ self.get_key_name = mock.MagicMock(return_value="foo")
+ self.get_table = mock.MagicMock(return_value="foo_table")
+ self.get_qualified_key_name = mock.MagicMock(return_value="foo_table.foo")
return self
diff --git a/tox.ini b/tox.ini
index 16805b78..61907978 100644
--- a/tox.ini
+++ b/tox.ini
@@ -1,15 +1,42 @@
[tox]
-envlist = py{27,35}-{pymysql,mysqlclient}
+isolated_build = true
+envlist = py{27,35,36,37}-{pymysql,mysqlclient}
[testenv]
-deps =
- -rtests-requirements.txt
- pymysql: pymysql
- mysqlclient: mysqlclient
-commands = py.test tests/ -sq
-
-[testenv:flake8]
-basepython=python
-deps=flake8
-commands=
- flake8 orator
+whitelist_externals = poetry
+
+[testenv:py27-pymysql]
+commands:
+ poetry run pip install -U pip
+ poetry install -E mysql-python -E pgsql -v
+ poetry run pytest tests/ -sq
+
+[testenv:py27-mysqlclient]
+commands:
+ poetry run pip install -U pip
+ poetry install -E mysql -E pgsql -v
+ poetry run pytest tests/ -sq
+
+[testenv:py35-pymysql]
+commands:
+ poetry run pip install -U pip
+ poetry install -E mysql-python -E pgsql -v
+ poetry run pytest tests/ -sq
+
+[testenv:py35-mysqlclient]
+commands:
+ poetry run pip install -U pip
+ poetry install -E mysql -E pgsql -v
+ poetry run pytest tests/ -sq
+
+[testenv:py36-pymysql]
+commands:
+ poetry run pip install -U pip
+ poetry install -E mysql-python -E pgsql -v
+ poetry run pytest tests/ -sq
+
+[testenv:py36-mysqlclient]
+commands:
+ poetry run pip install -U pip
+ poetry install -E mysql -E pgsql -v
+ poetry run pytest tests/ -sq